CompactAI's picture
Upload 2 files
bb2c123 verified
#!/usr/bin/env python3
"""Public-facing TMLM-Haiku interactive CLI.
Pulls models from the CompactAI-O HuggingFace collection:
https://huggingface.co/collections/CompactAI-O/tmlm-haiku-series
"""
from __future__ import annotations
#!/usr/bin/env python3
from __future__ import annotations
import hashlib
import json
import math
import os
import string
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Iterator, List, Optional, Sequence, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
# Display name -> HuggingFace repository id. Each repository id is the
# "CompactAI-O/" prefix followed by the display name itself.
HUGGINGFACE_MODELS = {
    model_name: f"CompactAI-O/{model_name}"
    for model_name in (
        "TMLM-Haiku-1",
        "TMLM-Haiku-1.3",
        "TMLM-Haiku-2",
        "TMLM-Haiku-2.3",
        "Glint-1",
    )
}
# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------
@dataclass
class ModelConfig:
    """Default architecture hyper-parameters, bundled as a plain dataclass.

    Field order is part of the positional constructor signature and must
    not be rearranged.
    """

    # Core transformer sizes.
    dim: int = 128
    n_unique_layers: int = 8
    n_logical_layers: int = 16
    n_heads: int = 4
    n_kv_heads: int = 2
    ffn_dim: int = 224
    dropout: float = 0.0
    # Sequence / attention-window settings.
    seq_len: int = 2048
    sliding_window_size: int = 512
    mtp_horizons: Tuple[int, ...] = (2, 3, 4)
    rope_fraction: float = 0.5
    embed_scale: bool = True
    # NOTE(review): a negative value presumably disables soft-capping;
    # the consumer is not visible here - confirm.
    logit_soft_cap: float = -1.0
    quantization: str = "nvfp4"

    @property
    def head_dim(self) -> int:
        """Return ``dim`` floor-divided by ``n_heads``."""
        per_head, _leftover = divmod(self.dim, self.n_heads)
        return per_head
# Module-level configuration instance built from the ModelConfig defaults.
model_config = ModelConfig()
# Named hyper-parameter presets, keyed by series name. Each preset is a
# plain dict of keyword-style settings.
MODEL_SERIES = {
    # Smallest preset; the only one that spells out "sliding_window_size"
    # and "rope_fraction" explicitly.
    "haiku": {
        "dim": 64,
        "n_unique_layers": 12,
        "n_logical_layers": 24,
        "n_heads": 4,
        "n_kv_heads": 2,
        "ffn_dim": 384,
        "dropout": 0.0,
        "seq_len": 2048,
        "sliding_window_size": 2048,
        "mtp_horizons": (),
        "rope_fraction": 0.5,
        "engram_dim": 8,
        "engram_heads": 2,
        "engram_table_size": 64,
        "engram_max_ngram": 2,
        "mhc_expansion": 2,
        "sleep_gate_cap": 0,
        "sleep_gate_heads": 4,
        "latent_think_layers": 0,
        "prelude_layers": 0,
        "coda_layers": 0,
        "recurrent_loops": 0,
        "recurrent_act_threshold": 0.9,
        "recurrent_lora_rank": 0,
        "recurrent_loop_embed_dim": 0,
    },
    # NOTE(review): unlike "haiku", this preset has no "sliding_window_size"
    # or "rope_fraction" key - presumably the consumer falls back to its own
    # defaults; confirm against the code that reads MODEL_SERIES.
    "sonnet": {
        "dim": 1024,
        "n_unique_layers": 20,
        "n_logical_layers": 40,
        "n_heads": 16,
        "n_kv_heads": 4,
        "ffn_dim": 4096,
        "dropout": 0.0,
        "seq_len": 2048,
        "mtp_horizons": (2,),
        "engram_dim": 32,
        "engram_heads": 8,
        "engram_table_size": 4096,
        "engram_max_ngram": 2,
        "mhc_expansion": 2,
        "sleep_gate_cap": 0,
        "sleep_gate_heads": 8,
        "latent_think_layers": 0,
        "prelude_layers": 0,
        "coda_layers": 0,
        "recurrent_loops": 0,
        "recurrent_act_threshold": 0.99,
        "recurrent_lora_rank": 0,
        "recurrent_loop_embed_dim": 0,
    },
    # NOTE(review): same missing keys as "sonnet" ("sliding_window_size",
    # "rope_fraction") - confirm the intended fallback.
    "opus": {
        "dim": 1536,
        "n_unique_layers": 18,
        "n_logical_layers": 36,
        "n_heads": 16,
        "n_kv_heads": 4,
        "ffn_dim": 5888,
        "dropout": 0.0,
        "seq_len": 2048,
        "mtp_horizons": (2,),
        "engram_dim": 64,
        "engram_heads": 8,
        "engram_table_size": 8192,
        "engram_max_ngram": 2,
        "mhc_expansion": 4,
        "sleep_gate_cap": 0,
        "sleep_gate_heads": 8,
        "latent_think_layers": 0,
        "prelude_layers": 0,
        "coda_layers": 0,
        "recurrent_loops": 0,
        "recurrent_act_threshold": 0.99,
        "recurrent_lora_rank": 0,
        "recurrent_loop_embed_dim": 0,
    },
}
# ---------------------------------------------------------------------------
# Tokenizer
# ---------------------------------------------------------------------------
# Default chat/format marker tokens. Each marker is its name wrapped in
# "<|" and "|>"; the list order is preserved because the tokenizer below
# places these tokens into its id table in this order.
FORMAT_TOKENS = [
    f"<|{marker}|>"
    for marker in (
        "user",
        "assistant",
        "system",
        "start_header_id",
        "end_header_id",
        "begin_of_thought",
        "end_of_thought",
        "begin_of_solution",
        "end_of_solution",
    )
]
class WordTokenizer:
    """Character-level tokenizer with atomic multi-character special tokens.

    The id table is laid out as: the four core specials, then the format
    tokens, then the sorted set of fallback characters (ASCII letters,
    digits, punctuation, space/newline/tab/carriage-return plus any
    ``extra_chars``). Text is split one character at a time, except that a
    special token found at the current position is emitted whole.
    """

    def __init__(
        self, extra_chars: str = "", format_tokens: Optional[List[str]] = None
    ) -> None:
        """Build the vocabulary.

        Args:
            extra_chars: Additional single characters to include.
            format_tokens: Format markers to register; ``None`` or an empty
                list selects the module-level ``FORMAT_TOKENS``.
        """
        alphabet = string.ascii_letters + string.digits + string.punctuation + " \n\t\r"
        self.core_special = ["<PAD>", "<BOS>", "<EOS>", "<UNK>"]
        if format_tokens:
            self.format_tokens = list(format_tokens)
        else:
            self.format_tokens = list(FORMAT_TOKENS)
        self.id_to_token: List[str] = [
            *self.core_special,
            *self.format_tokens,
            *sorted(set(alphabet + extra_chars)),
        ]
        self.token_to_id: Dict[str, int] = {
            token: index for index, token in enumerate(self.id_to_token)
        }
        self._refresh_special_index()
        self.dynamic_additions = 0

    def _refresh_special_index(self) -> None:
        """Recompute the special-token lists from core + format tokens."""
        self.special = [*self.core_special, *self.format_tokens]
        # Longest first, so a longer marker wins over a shorter one that
        # shares its prefix when scanning text.
        self.special_multi_tokens = sorted(
            (token for token in self.special if len(token) > 1),
            key=len,
            reverse=True,
        )
        # Alias kept pointing at the very same list object.
        self.multi_char_tokens = self.special_multi_tokens

    @property
    def pad_id(self) -> int:
        """Id of ``<PAD>``."""
        return self.token_to_id["<PAD>"]

    @property
    def bos_id(self) -> int:
        """Id of ``<BOS>``."""
        return self.token_to_id["<BOS>"]

    @property
    def eos_id(self) -> int:
        """Id of ``<EOS>``."""
        return self.token_to_id["<EOS>"]

    @property
    def unk_id(self) -> int:
        """Id of ``<UNK>``."""
        return self.token_to_id["<UNK>"]

    @property
    def vocab_size(self) -> int:
        """Number of entries currently in the id table."""
        return len(self.id_to_token)

    def maybe_add_char(self, ch: str) -> bool:
        """Append ``ch`` to the vocabulary unless it is already known.

        Returns:
            ``True`` when a new entry was added, ``False`` otherwise.
        """
        already_known = ch in self.token_to_id
        if not already_known:
            self.token_to_id[ch] = len(self.id_to_token)
            self.id_to_token.append(ch)
            self.dynamic_additions += 1
        return not already_known

    def iter_lexical_tokens(self, text: str) -> Iterator[str]:
        """Yield special tokens whole and every other character singly."""
        cursor = 0
        end = len(text)
        while cursor < end:
            hit = next(
                (
                    marker
                    for marker in self.special_multi_tokens
                    if text.startswith(marker, cursor)
                ),
                None,
            )
            piece = text[cursor] if hit is None else hit
            yield piece
            cursor += len(piece)

    def encode(
        self, text: str, add_bos: bool = False, add_eos: bool = False
    ) -> List[int]:
        """Convert ``text`` to ids; unknown pieces map to ``<UNK>``.

        Args:
            text: Input string.
            add_bos: Prepend the ``<BOS>`` id.
            add_eos: Append the ``<EOS>`` id.
        """
        ids: List[int] = [self.bos_id] if add_bos else []
        unknown = self.unk_id
        lookup = self.token_to_id
        ids.extend(lookup.get(piece, unknown) for piece in self.iter_lexical_tokens(text))
        if add_eos:
            ids.append(self.eos_id)
        return ids

    def decode(self, ids: Sequence[int], skip_special: bool = True) -> str:
        """Join the tokens for ``ids`` into a string.

        Out-of-range ids are dropped silently; special tokens are dropped
        when ``skip_special`` is true.
        """
        kept: List[str] = []
        for raw in ids:
            index = int(raw)
            if not 0 <= index < len(self.id_to_token):
                continue
            token = self.id_to_token[index]
            if skip_special and token in self.special:
                continue
            kept.append(token)
        return "".join(kept)

    @classmethod
    def load(cls, path: Path) -> WordTokenizer:
        """Restore a tokenizer from a JSON file.

        The file must contain ``"id_to_token"``; ``"format_tokens"`` is
        optional and defaults to the module-level ``FORMAT_TOKENS``.
        """
        with path.open("r", encoding="utf-8") as handle:
            payload = json.load(handle)
        restored = cls(
            extra_chars="",
            format_tokens=payload.get("format_tokens", FORMAT_TOKENS),
        )
        # The stored table replaces the freshly built one wholesale.
        restored.id_to_token = payload["id_to_token"]
        restored.token_to_id = {
            token: index for index, token in enumerate(restored.id_to_token)
        }
        restored._refresh_special_index()
        return restored


LetterTokenizer = WordTokenizer
# ---------------------------------------------------------------------------
# Model
# ---------------------------------------------------------------------------
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last axis with a learnable scale."""

    def __init__(self, dim: int, eps: float = 1e-6) -> None:
        """Create a scale vector of ``dim`` ones; ``eps`` stabilises the rsqrt."""
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalise ``x`` over its last axis and multiply by ``weight``."""
        fused = getattr(F, "rms_norm", None)
        if fused is not None:
            # Use the built-in kernel when this PyTorch build provides one.
            return fused(x, self.weight.shape, self.weight, self.eps)
        # Manual path: compute in float32, cast back, then apply the scale.
        promoted = x.float()
        mean_square = promoted.pow(2).mean(dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        normalised = (promoted * inv_rms).to(dtype=x.dtype)
        return normalised * self.weight
class RotaryEmbedding(nn.Module):
    """Produces rotary-position cos/sin tables for a given rotary width."""

    def __init__(self, dim: int, base: float = 10000.0) -> None:
        """Precompute ``1 / base ** (i / dim)`` for even ``i`` below ``dim``."""
        super().__init__()
        exponents = torch.arange(0, dim, 2).float() / dim
        # Non-persistent: rebuilt at construction, never stored in state_dict.
        self.register_buffer("inv_freq", 1.0 / (base**exponents), persistent=False)

    def cos_sin(
        self, seq_len: int, device: torch.device, dtype: torch.dtype
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(cos, sin)`` tables of shape ``(1, 1, seq_len, 2 * len(inv_freq))``.

        Angles are computed in ``inv_freq``'s dtype and only the final
        tables are cast to ``dtype``.
        """
        positions = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
        angles = torch.outer(positions, self.inv_freq)
        # Frequencies are laid out twice, matching the half-split rotation.
        table = torch.cat([angles, angles], dim=-1)
        cos_table = table.cos().unsqueeze(0).unsqueeze(0).to(dtype=dtype)
        sin_table = table.sin().unsqueeze(0).unsqueeze(0).to(dtype=dtype)
        return cos_table, sin_table
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Return ``cat(-second_half, first_half)`` along the last axis.

    The split point is ``x.shape[-1] // 2``, so for an odd width the
    second part is the longer one.
    """
    width = x.shape[-1]
    split_at = width // 2
    front, back = x.split([split_at, width - split_at], dim=-1)
    return torch.cat((back.neg(), front), dim=-1)
class CausalSelfAttention(nn.Module):
    """Grouped-query causal self-attention with partial rotary embedding.

    Queries use ``n_heads`` heads while keys/values use ``n_kv_heads`` and
    are repeated ``n_heads // n_kv_heads`` times. Queries and keys are
    RMS-normalised per head, the first ``rope_dim`` features of each head
    are rotated, and every head's output is scaled by a learned sigmoid
    gate before the output projection.
    """

    def __init__(
        self,
        dim: int,
        n_heads: int,
        n_kv_heads: int,
        head_dim: int,
        dropout: float,
        sliding_window: int,
        rope_fraction: float,
    ) -> None:
        super().__init__()
        self.dim = dim
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.head_dim = head_dim
        self.n_rep = n_heads // n_kv_heads
        self.dropout = dropout
        self.sliding_window = sliding_window
        self.wq = nn.Linear(dim, n_heads * head_dim, bias=False)
        self.wk = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
        self.wv = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
        self.wo = nn.Linear(n_heads * head_dim, dim, bias=False)
        # Even number of rotated features, at least 2.
        self.rope_dim = max(2, int(head_dim * rope_fraction) // 2 * 2)
        self.rope = RotaryEmbedding(self.rope_dim)
        self.q_norm = RMSNorm(head_dim)
        self.k_norm = RMSNorm(head_dim)
        self.output_gate = nn.Parameter(torch.ones(n_heads))

    @staticmethod
    def _causal_mask(
        past_len: int,
        n_query: int,
        n_key: int,
        device: torch.device,
        window: Optional[int] = None,
    ) -> torch.Tensor:
        """Boolean ``(1, 1, n_query, n_key)`` mask; ``True`` = may attend.

        Query ``i`` sits at absolute position ``past_len + i``. A key is
        visible when it is not in the future and, if ``window`` is given,
        lies fewer than ``window`` positions behind the query.
        """
        q_pos = torch.arange(past_len, past_len + n_query, device=device).unsqueeze(1)
        k_pos = torch.arange(n_key, device=device).unsqueeze(0)
        mask = q_pos >= k_pos
        if window is not None:
            mask = mask & ((q_pos - k_pos) < window)
        return mask.unsqueeze(0).unsqueeze(0)

    def forward(
        self,
        x: torch.Tensor,
        is_global: bool,
        past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """Attend over ``x`` (plus any cached keys/values).

        Args:
            x: Input of shape ``(B, T, dim)``.
            is_global: ``True`` for full causal attention, ``False`` for
                causal attention limited to ``sliding_window`` positions.
            past_kv: Cached ``(key, value)`` whose axis 2 is the cached
                sequence length, or ``None``.
            use_cache: When true, the concatenated keys/values (before
                head repetition) are returned for reuse.

        Returns:
            ``(output, new_kv)`` where ``output`` is ``(B, T, dim)`` and
            ``new_kv`` is ``None`` unless ``use_cache`` is set.
        """
        B, T, _ = x.shape
        q = self.wq(x).view(B, T, self.n_heads, self.head_dim)
        k = self.wk(x).view(B, T, self.n_kv_heads, self.head_dim)
        v = self.wv(x).view(B, T, self.n_kv_heads, self.head_dim)
        q = self.q_norm(q)
        k = self.k_norm(k)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # New tokens occupy absolute positions past_len .. past_len + T - 1.
        past_len = past_kv[0].shape[2] if past_kv is not None else 0
        cos, sin = self.rope.cos_sin(T + past_len, x.device, q.dtype)
        cos_slice = cos[:, :, past_len : past_len + T, :]
        sin_slice = sin[:, :, past_len : past_len + T, :]
        q_rope = q[..., : self.rope_dim]
        q_pass = q[..., self.rope_dim :]
        k_rope = k[..., : self.rope_dim]
        k_pass = k[..., self.rope_dim :]
        q_rope = (q_rope * cos_slice) + (_rotate_half(q_rope) * sin_slice)
        k_rope = (k_rope * cos_slice) + (_rotate_half(k_rope) * sin_slice)
        q = torch.cat([q_rope, q_pass], dim=-1)
        k = torch.cat([k_rope, k_pass], dim=-1)

        if past_kv is not None:
            k = torch.cat([past_kv[0], k], dim=2)
            v = torch.cat([past_kv[1], v], dim=2)
        new_kv = (k, v) if use_cache else None

        S = k.shape[2]
        if self.n_rep > 1:
            # Repeat each KV head n_rep times to line up with query heads.
            k = (
                k[:, :, None, :, :]
                .expand(B, self.n_kv_heads, self.n_rep, S, self.head_dim)
                .reshape(B, self.n_heads, S, self.head_dim)
            )
            v = (
                v[:, :, None, :, :]
                .expand(B, self.n_kv_heads, self.n_rep, S, self.head_dim)
                .reshape(B, self.n_heads, S, self.head_dim)
            )

        drop_p = self.dropout if (self.training and torch.is_grad_enabled()) else 0.0
        if is_global:
            if past_len == 0 and T > 1:
                # No cache: the built-in causal path applies directly.
                out = F.scaled_dot_product_attention(
                    q, k, v, is_causal=True, dropout_p=drop_p
                )
            elif T == 1:
                # A single new query is the last position, so every key
                # (cached or new) is at or before it - no mask required.
                out = F.scaled_dot_product_attention(q, k, v, dropout_p=drop_p)
            else:
                # Several new tokens on top of a cache: without a mask the
                # new tokens would see each other's future positions, so an
                # offset causal mask is required here.
                mask = self._causal_mask(past_len, T, S, q.device)
                out = F.scaled_dot_product_attention(
                    q, k, v, attn_mask=mask, dropout_p=drop_p
                )
        else:
            mask = self._causal_mask(
                past_len, q.shape[2], S, q.device, window=self.sliding_window
            )
            out = F.scaled_dot_product_attention(
                q, k, v, attn_mask=mask, dropout_p=drop_p
            )

        gate = torch.sigmoid(self.output_gate).view(1, self.n_heads, 1, 1)
        out = out * gate
        out = out.transpose(1, 2).contiguous().view(B, T, self.n_heads * self.head_dim)
        out = self.wo(out)
        return out, new_kv
class SwiGLU(nn.Module):
    """Gated feed-forward layer: ``down(silu(gate(x)) * up(x))``."""

    def __init__(self, dim: int, hidden_dim: int, dropout: float) -> None:
        """Create the three bias-free projections and initialise them.

        Each weight is drawn from a normal distribution whose std is the
        inverse square root of that projection's input width.
        """
        super().__init__()
        self.gate = nn.Linear(dim, hidden_dim, bias=False)
        self.up = nn.Linear(dim, hidden_dim, bias=False)
        self.down = nn.Linear(hidden_dim, dim, bias=False)
        self.drop = nn.Dropout(dropout)
        for layer, fan_in in ((self.gate, dim), (self.up, dim), (self.down, hidden_dim)):
            nn.init.normal_(layer.weight, std=fan_in**-0.5)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated projection; dropout runs only while training with grad."""
        projected = self.down(F.silu(self.gate(x)) * self.up(x))
        use_dropout = self.training and torch.is_grad_enabled()
        return self.drop(projected) if use_dropout else projected
def loop_index_embedding(h: torch.Tensor, loop_t: int, loop_dim: int, theta: float = 10000.0) -> torch.Tensor:
    """Add a sinusoidal encoding of the loop index to the leading features of ``h``.

    The encoded width is ``loop_dim`` capped at ``h``'s last dimension and
    rounded down to an even number. When that width is not positive, ``h``
    itself is returned (not a copy); otherwise a modified clone is returned.

    Args:
        h: Tensor whose last axis receives the encoding.
        loop_t: Loop index to encode.
        loop_dim: Requested number of leading features to modify.
        theta: Base of the frequency schedule.
    """
    width = min(loop_dim, h.shape[-1]) if loop_dim > 0 else 0
    width -= width % 2
    if width <= 0:
        return h
    steps = torch.arange(0, width, 2, device=h.device, dtype=h.dtype)
    inv_freq = 1.0 / (theta ** (steps / width))
    phase = torch.tensor(float(loop_t), device=h.device, dtype=h.dtype) * inv_freq
    # All sines first, then all cosines.
    signal = torch.cat([phase.sin(), phase.cos()], dim=0).view(1, 1, width)
    shifted = h.clone()
    shifted[..., :width] = shifted[..., :width] + signal
    return shifted
class DepthLoRAAdapter(nn.Module):
    """Low-rank residual term whose per-rank scale is selected by loop index.

    With a non-positive rank the adapter holds no weights and returns zeros.
    Otherwise it computes ``(down(x) * scale[loop]) @ B``; the scale table
    starts at zero, so a fresh adapter also outputs zeros.
    """

    def __init__(self, dim: int, rank: int, max_loops: int) -> None:
        super().__init__()
        self.rank = max(0, rank)
        if self.rank > 0:
            self.down = nn.Linear(dim, self.rank, bias=False)
            self.B = nn.Parameter(torch.randn(self.rank, dim) * 0.02)
            # One scale vector per loop index (at least one row).
            self.scale = nn.Embedding(max(1, max_loops), self.rank)
            nn.init.zeros_(self.scale.weight)
        else:
            self.down = None
            self.B = None
            self.scale = None

    def forward(self, x: torch.Tensor, loop_t: int) -> torch.Tensor:
        """Return the low-rank update for ``x`` at loop index ``loop_t``.

        Loop indices beyond the scale table reuse its last row.
        """
        missing_part = any(part is None for part in (self.down, self.B, self.scale))
        if self.rank <= 0 or missing_part:
            return torch.zeros_like(x)
        row = min(loop_t, self.scale.num_embeddings - 1)
        gain = self.scale(torch.tensor(row, device=x.device))
        return torch.matmul(self.down(x) * gain, self.B)
class StableRecurrentInjection(nn.Module):
    """Combines the previous state, the injected input and the block output.

    Output is ``A * h + B * e + transformer_out`` with per-feature
    ``A = exp(-exp(clamp(log_dt + log_A, -20, 20)))`` and
    ``B = sigmoid(input_gate)``; ``A`` therefore always lies in (0, 1).
    """

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.log_A = nn.Parameter(torch.full((dim,), -2.0))
        self.log_dt = nn.Parameter(torch.full((dim,), -2.0))
        self.input_gate = nn.Parameter(torch.zeros(dim))

    def forward(self, h: torch.Tensor, e: torch.Tensor, transformer_out: torch.Tensor) -> torch.Tensor:
        """Return the mixed state; all three inputs share the feature axis."""
        log_rate = torch.clamp(self.log_dt + self.log_A, min=-20, max=20)
        decay = torch.exp(-torch.exp(log_rate)).view(1, 1, -1)
        inject = torch.sigmoid(self.input_gate).view(1, 1, -1)
        carried = decay * h
        injected = inject * e
        return carried + injected + transformer_out
class AdaptiveHalting(nn.Module):
    """Linear head producing one halting probability per position.

    The weight starts at zero and the bias at -2.0, so a fresh head outputs
    ``sigmoid(-2)`` regardless of its input.
    """

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.halt = nn.Linear(dim, 1, bias=True)
        nn.init.zeros_(self.halt.weight)
        nn.init.constant_(self.halt.bias, -2.0)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        """Return sigmoid probabilities with the trailing size-1 axis removed."""
        logits = self.halt(h)
        probabilities = torch.sigmoid(logits)
        return probabilities.squeeze(-1)
class EngramBlock(nn.Module):
    """DeepSeek Engram: conditional memory via O(1) hashed N-gram lookup.

    For every n-gram order from 2 to ``max_ngram`` and every hash head, the
    n-gram ending at each position is hashed into a learned table. The
    retrieved rows are concatenated, passed through a depthwise causal
    convolution, and injected through a sigmoid gate computed from the
    current hidden state.

    Reference: DeepSeek-AI, "Conditional Memory via Scalable Lookup" (2025).
    """

    def __init__(
        self,
        dim: int,
        engram_dim: int,
        n_heads: int = 4,
        table_size: int = 8192,
        max_ngram: int = 3,
    ) -> None:
        super().__init__()
        self.dim = dim
        self.engram_dim = engram_dim
        self.n_heads = n_heads
        self.table_size = table_size
        self.max_ngram = max_ngram

        # One learned table and one fixed (a, b) hash pair per
        # (ngram_order, hash_head). The hash pair comes from a private
        # generator seeded by an MD5 of the slot name, so it neither depends
        # on nor disturbs the global RNG that initialises the tables.
        self.embeddings = nn.ParameterDict()
        for order in range(2, max_ngram + 1):
            for head in range(n_heads):
                slot = f"{order}_{head}"
                self.embeddings[slot] = nn.Parameter(
                    torch.randn(table_size, engram_dim) * (engram_dim**-0.5)
                )
                digest = hashlib.md5(f"engram_{slot}".encode()).hexdigest()
                rng = torch.Generator().manual_seed(int(digest[:8], 16))
                mult = torch.randint(1, 2**31, (1,), generator=rng).item()
                offset = torch.randint(0, 2**31, (1,), generator=rng).item()
                self.register_buffer(
                    f"hash_a_{slot}", torch.tensor(mult), persistent=False
                )
                self.register_buffer(
                    f"hash_b_{slot}", torch.tensor(offset), persistent=False
                )

        # Depthwise convolution over the concatenated branches
        # (kernel=4, dilation=max_ngram); zero-initialised, so a fresh
        # block contributes nothing.
        total_branch_dim = engram_dim * n_heads * (max_ngram - 1)
        self.branch_conv = nn.Conv1d(
            total_branch_dim,
            total_branch_dim,
            kernel_size=4,
            dilation=max_ngram,
            padding=0,
            groups=total_branch_dim,
            bias=True,
        )
        nn.init.zeros_(self.branch_conv.weight)
        nn.init.zeros_(self.branch_conv.bias)

        # Context-aware gating: hidden state as query, memory as key/value.
        self.gate_query = nn.Linear(dim, engram_dim, bias=False)
        self.gate_key = nn.Linear(total_branch_dim, engram_dim, bias=False)
        self.gate_value = nn.Linear(total_branch_dim, dim, bias=False)
        self.gate_scale = engram_dim**-0.5

    def _hash_ngram(self, token_ids: torch.Tensor, n: int, k: int) -> torch.Tensor:
        """Hash the n-gram ending at each position into a table index.

        Args:
            token_ids: (B, T) token IDs
            n: n-gram order (2 = bigram, 3 = trigram)
            k: hash head index
        Returns:
            indices: (B, T) integer indices in ``[0, table_size)``
        """
        mult = getattr(self, f"hash_a_{n}_{k}")
        offset = getattr(self, f"hash_b_{n}_{k}")
        B, T = token_ids.shape
        # Left-pad with zeros so early positions still have a full n-gram.
        padded = F.pad(token_ids, (n - 1, 0), value=0).long()  # (B, T+n-1)
        # Polynomial rolling hash with base 31, oldest token first.
        rolling = torch.zeros(B, T, dtype=torch.long, device=token_ids.device)
        for shift in range(n):
            rolling = rolling * 31 + padded[:, shift : shift + T]
        return ((mult * rolling) ^ offset) % self.table_size

    def forward(
        self, hidden: torch.Tensor, token_ids: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Compute the gated memory injection.

        Args:
            hidden: (B, T, dim) current hidden state
            token_ids: (B, T) input token IDs for n-gram hashing. If None,
                pseudo ids are derived as the feature-wise mean of
                ``hidden`` truncated to integer, modulo ``table_size``.
        Returns:
            output: (B, T, dim) memory injection for the residual stream
        """
        if token_ids is None:
            token_ids = hidden.mean(dim=-1).long() % self.table_size

        # One (B, T, engram_dim) lookup per (order, head), order-major.
        lookups = [
            self.embeddings[f"{order}_{head}"][self._hash_ngram(token_ids, order, head)]
            for order in range(2, self.max_ngram + 1)
            for head in range(self.n_heads)
        ]
        memory = torch.cat(lookups, dim=-1)  # (B, T, C)

        # Left-pad by (kernel_size - 1) * dilation so the convolution never
        # sees later positions and the output keeps length T.
        left_pad = (self.branch_conv.kernel_size[0] - 1) * self.branch_conv.dilation[0]
        conv_in = F.pad(memory.transpose(1, 2), (left_pad, 0))  # (B, C, T + pad)
        memory = self.branch_conv(conv_in).transpose(1, 2)  # (B, T, C)

        # Scalar gate per position from the query/key dot product.
        query = self.gate_query(hidden)  # (B, T, engram_dim)
        key = self.gate_key(memory)  # (B, T, engram_dim)
        score = (query * key).sum(dim=-1, keepdim=True) * self.gate_scale
        gate = torch.sigmoid(score)  # (B, T, 1)
        value = self.gate_value(memory)  # (B, T, dim)
        return gate * value
class SleepGate(nn.Module):
    """Ring-buffer memory of pooled hidden states, read through attention.

    ``write`` stores one pooled vector per batch row in a ring buffer of
    ``cap`` slots. ``read`` attends from the input over the stored slots,
    optionally down-weighting slots by ``beta ** age``. The output
    projection and the gate scalar start at zero.
    """

    def __init__(
        self,
        dim: int,
        cap: int = 128,
        n_heads: int = 4,
        retention_enabled: bool = True,
        retention_hidden: int = 0,
    ) -> None:
        super().__init__()
        self.dim = dim
        self.cap = cap
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        self.scale = self.head_dim ** -0.5
        self.retention_enabled = retention_enabled

        # Ring-buffer storage plus its bookkeeping scalars.
        self.register_buffer("mem_emb", torch.zeros(cap, dim, dtype=torch.bfloat16))
        self.register_buffer("mem_age", torch.zeros(cap, dtype=torch.long))
        self.register_buffer("mem_beta", torch.ones(cap, dtype=torch.float32))
        self.register_buffer("mem_count", torch.zeros((), dtype=torch.long))
        self.register_buffer("mem_head", torch.zeros((), dtype=torch.long))
        self.register_buffer("global_step", torch.zeros((), dtype=torch.long))

        self.q_proj = nn.Linear(dim, dim, bias=False)
        self.k_proj = nn.Linear(dim, dim, bias=False)
        self.v_proj = nn.Linear(dim, dim, bias=False)
        self.o_proj = nn.Linear(dim, dim, bias=False)
        nn.init.zeros_(self.o_proj.weight)
        self.gate_scale = nn.Parameter(torch.zeros(()))

        # Optional retention head; its final bias starts at 2.2.
        if not retention_enabled:
            self.retention_gate = None
        elif retention_hidden > 0:
            self.retention_gate: Optional[nn.Module] = nn.Sequential(
                nn.Linear(dim, retention_hidden, bias=False),
                nn.GELU(),
                nn.Linear(retention_hidden, 1, bias=True),
            )
            nn.init.constant_(self.retention_gate[-1].bias, 2.2)
        else:
            self.retention_gate = nn.Linear(dim, 1, bias=True)
            nn.init.constant_(self.retention_gate.bias, 2.2)
        self._last_beta: Optional[torch.Tensor] = None

    def write(self, hidden: torch.Tensor) -> None:
        """Store one pooled vector per batch row.

        Each row of ``hidden`` (B, T, dim) is averaged over its last (up to)
        16 positions. The live retention value is kept in ``_last_beta``
        only in training mode; stored copies are detached.
        """
        B, T, _ = hidden.shape
        pooled = hidden[:, max(0, T - 16):, :].float().mean(dim=1)
        if self.retention_gate is None:
            self._last_beta = None
            stored_beta = torch.ones(B, device=hidden.device, dtype=torch.float32)
        else:
            live_beta = torch.sigmoid(self.retention_gate(pooled).squeeze(-1))
            self._last_beta = live_beta if self.training else None
            stored_beta = live_beta.detach().float()
        stored_vec = pooled.to(self.mem_emb.dtype).detach()
        with torch.no_grad():
            slot = int(self.mem_head.item())
            filled = int(self.mem_count.item())
            now = int(self.global_step.item())
            for row in range(B):
                self.mem_emb[slot] = stored_vec[row]
                self.mem_age[slot] = now
                self.mem_beta[slot] = stored_beta[row]
                # Advance with wrap-around; the fill count stops at cap.
                slot = (slot + 1) % self.cap
                if filled < self.cap:
                    filled += 1
            self.mem_head.fill_(slot)
            self.mem_count.fill_(filled)

    def read(self, x: torch.Tensor) -> torch.Tensor:
        """Attend from ``x`` (B, T, dim) over the stored slots.

        Returns zeros shaped like ``x`` when nothing has been written.
        """
        filled = int(self.mem_count.item())
        if filled == 0:
            return torch.zeros_like(x)
        B, T, D = x.shape
        slots = self.mem_emb[:filled].clone().to(x.dtype)
        q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(slots).view(filled, self.n_heads, self.head_dim).transpose(0, 1)
        v = self.v_proj(slots).view(filled, self.n_heads, self.head_dim).transpose(0, 1)
        scores = torch.einsum("bhtd,hmd->bhtm", q, k) * self.scale
        weights = F.softmax(scores, dim=-1)
        if self.retention_enabled:
            # Scale each slot by beta ** (steps since it was written), then
            # renormalise over the slot axis.
            now = int(self.global_step.item())
            age = (now - self.mem_age[:filled].to(x.device)).clamp(min=0).to(x.dtype)
            beta = self.mem_beta[:filled].to(x.dtype).clamp(min=1e-6, max=1.0)
            weights = weights * beta.pow(age).view(1, 1, 1, filled)
            weights = weights / weights.sum(dim=-1, keepdim=True).clamp_min(1e-9)
        mixed = torch.einsum("bhtm,hmd->bhtd", weights, v)
        mixed = mixed.transpose(1, 2).contiguous().view(B, T, D)
        return torch.sigmoid(self.gate_scale) * self.o_proj(mixed)

    @torch.no_grad()
    def reset(self) -> None:
        """Clear all stored slots and counters."""
        for counter in (self.mem_emb, self.mem_age):
            counter.zero_()
        self.mem_beta.fill_(1.0)
        for counter in (self.mem_count, self.mem_head, self.global_step):
            counter.zero_()
        self._last_beta = None
def _sinkhorn_knopp(logits: torch.Tensor, n_iters: int = 7) -> torch.Tensor:
    """Alternately normalise rows and columns of ``exp(clamp(logits, -10, 10))``.

    Each iteration divides by the sum over the last axis and then by the
    sum over the second-to-last axis, so the returned matrix has unit sums
    along the second-to-last axis and approximately unit sums along the
    last one. Sums are floored at 1e-10 before dividing.
    """
    plan = logits.clamp(-10, 10).exp()
    for _ in range(n_iters):
        for axis in (-1, -2):
            plan = plan / plan.sum(dim=axis, keepdim=True).clamp(min=1e-10)
    return plan
class ManifoldHyperConnection(nn.Module):
    """Mixes ``expansion`` parallel residual streams around a sub-layer.

    The expanded stream is laid out as ``expansion`` chunks of ``dim``
    features on the last axis. ``_compute_mappings`` derives three
    input-dependent mixing tensors from it: a read-in row (``H_pre``), a
    write-back row (``H_post``) and a stream-to-stream matrix (``H_res``)
    normalised with ``_sinkhorn_knopp``.
    """

    def __init__(self, dim: int, expansion: int = 2) -> None:
        super().__init__()
        self.dim = dim
        self.expansion = expansion
        n = expansion
        self.expand_fn = "duplicate"
        self.collapse_fn = "mean"
        self.bias_pre = nn.Parameter(torch.zeros(1, n))
        self.bias_post = nn.Parameter(torch.zeros(1, n))
        self.bias_res = nn.Parameter(torch.zeros(n, n))
        self.theta_pre = nn.Linear(n * dim, n, bias=False)
        self.theta_post = nn.Linear(n * dim, n, bias=False)
        self.theta_res = nn.Linear(n * dim, n * n, bias=False)
        # The alphas start at zero, so initially only the biases matter.
        self.alpha_pre = nn.Parameter(torch.tensor(0.0))
        self.alpha_post = nn.Parameter(torch.tensor(0.0))
        self.alpha_res = nn.Parameter(torch.tensor(0.0))

    @staticmethod
    def _rms_normalize(x: torch.Tensor) -> torch.Tensor:
        """Weight-free RMS normalisation over the last axis.

        Uses ``F.rms_norm`` when this PyTorch build has it; otherwise falls
        back to a manual float32 computation, mirroring the guard used by
        ``RMSNorm.forward`` so both layers work on the same PyTorch
        versions.
        """
        if hasattr(F, "rms_norm"):
            return F.rms_norm(x, [x.shape[-1]])
        # F.rms_norm's documented default eps is the dtype's machine epsilon.
        eps = torch.finfo(x.dtype).eps
        promoted = x.float()
        inv_rms = torch.rsqrt(promoted.pow(2).mean(dim=-1, keepdim=True) + eps)
        return (promoted * inv_rms).to(dtype=x.dtype)

    def _compute_mappings(
        self, x_expanded: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Derive the mixing tensors from the expanded stream.

        Args:
            x_expanded: ``(B, T, expansion * dim)``.

        Returns:
            ``(H_pre, H_post, H_res)`` with shapes ``(B, T, 1, n)``,
            ``(B, T, 1, n)`` and ``(B, T, n, n)``. ``H_pre`` lies in
            (0, 1) and ``H_post`` in (0, 2).
        """
        B, T, _ = x_expanded.shape
        n = self.expansion
        x_norm = self._rms_normalize(x_expanded)
        d_pre = torch.tanh(self.theta_pre(x_norm))
        d_post = torch.tanh(self.theta_post(x_norm))
        d_res = self.theta_res(x_norm)
        H_pre_raw = torch.sigmoid(self.alpha_pre * d_pre + self.bias_pre)
        H_post_raw = 2.0 * torch.sigmoid(self.alpha_post * d_post + self.bias_post)
        H_res_raw = (self.alpha_res * d_res + self.bias_res.reshape(1, 1, -1)).reshape(
            B, T, n, n
        )
        H_res = _sinkhorn_knopp(H_res_raw)
        return H_pre_raw.unsqueeze(-2), H_post_raw.unsqueeze(-2), H_res

    def expand_stream(self, x: torch.Tensor) -> torch.Tensor:
        """Tile ``x`` ``expansion`` times along the last axis."""
        return x.repeat(1, 1, self.expansion)

    def collapse_stream(self, x_expanded: torch.Tensor) -> torch.Tensor:
        """Average the streams back to ``(B, T, dim)``."""
        B, T, _ = x_expanded.shape
        return x_expanded.view(B, T, self.expansion, self.dim).mean(dim=-2)

    def pre_mix(self, x_expanded: torch.Tensor, H_pre: torch.Tensor) -> torch.Tensor:
        """Combine the streams into one sub-layer input using ``H_pre``."""
        B, T, _ = x_expanded.shape
        x_streams = x_expanded.view(B, T, self.expansion, self.dim)
        return (H_pre @ x_streams).squeeze(-2)

    def post_res_mix(
        self,
        layer_output: torch.Tensor,
        x_expanded: torch.Tensor,
        H_post: torch.Tensor,
        H_res: torch.Tensor,
    ) -> torch.Tensor:
        """Mix streams with ``H_res`` and add the output scaled by ``H_post``.

        Returns a tensor shaped like ``x_expanded``.
        """
        B, T, _ = x_expanded.shape
        n = self.expansion
        C = self.dim
        x_streams = x_expanded.view(B, T, n, C)
        mixed = torch.matmul(H_res, x_streams)
        # (B, T, n, 1) @ (B, T, 1, C): one scaled copy of the output per stream.
        post_out = torch.matmul(H_post.transpose(-2, -1), layer_output.unsqueeze(-2))
        result = mixed + post_out
        return result.reshape(B, T, n * C)
class TransformerBlock(nn.Module):
    """Pre-norm block: attention, optional Engram memory, SwiGLU feed-forward.

    With ``mhc_expansion > 1`` the residual additions are replaced by
    ``ManifoldHyperConnection`` stream mixing; otherwise plain residual
    connections are used. The Engram branch exists only when
    ``engram_dim > 0``.
    """

    def __init__(
        self,
        dim: int,
        n_heads: int,
        n_kv_heads: int,
        head_dim: int,
        ffn_dim: int,
        dropout: float,
        sliding_window: int,
        rope_fraction: float,
        engram_dim: int = 0,
        engram_heads: int = 4,
        engram_table_size: int = 8192,
        engram_max_ngram: int = 3,
        mhc_expansion: int = 1,
    ) -> None:
        super().__init__()
        self.norm1 = RMSNorm(dim)
        self.attn = CausalSelfAttention(
            dim=dim,
            n_heads=n_heads,
            n_kv_heads=n_kv_heads,
            head_dim=head_dim,
            dropout=dropout,
            sliding_window=sliding_window,
            rope_fraction=rope_fraction,
        )
        self.norm2 = RMSNorm(dim)
        self.ffn = SwiGLU(dim, ffn_dim, dropout)

        self.use_engram = engram_dim > 0
        if self.use_engram:
            self.engram = EngramBlock(
                dim=dim,
                engram_dim=engram_dim,
                n_heads=engram_heads,
                table_size=engram_table_size,
                max_ngram=engram_max_ngram,
            )
            self.engram_norm = RMSNorm(dim)

        self.use_mhc = mhc_expansion > 1
        if self.use_mhc:
            self.mhc_attn = ManifoldHyperConnection(dim, expansion=mhc_expansion)
            self.mhc_ffn = ManifoldHyperConnection(dim, expansion=mhc_expansion)

    def _forward_hyper(
        self,
        x: torch.Tensor,
        is_global: bool,
        past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]],
        use_cache: bool,
        token_ids: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """Block computation using hyper-connection stream mixing."""
        attn_mix, ffn_mix = self.mhc_attn, self.mhc_ffn

        streams = attn_mix.expand_stream(x)
        pre_a, post_a, res_a = attn_mix._compute_mappings(streams)
        attn_out, new_kv = self.attn(
            self.norm1(attn_mix.pre_mix(streams, pre_a)), is_global, past_kv, use_cache
        )
        streams = attn_mix.post_res_mix(attn_out, streams, post_a, res_a)

        if self.use_engram:
            # The memory branch works on the collapsed stream, which is then
            # re-expanded (every stream restarts from the same values).
            merged = attn_mix.collapse_stream(streams)
            merged = merged + self.engram(self.engram_norm(merged), token_ids=token_ids)
            streams = attn_mix.expand_stream(merged)

        pre_f, post_f, res_f = ffn_mix._compute_mappings(streams)
        ffn_out = self.ffn(self.norm2(ffn_mix.pre_mix(streams, pre_f)))
        streams = ffn_mix.post_res_mix(ffn_out, streams, post_f, res_f)
        return attn_mix.collapse_stream(streams), new_kv

    def _forward_residual(
        self,
        x: torch.Tensor,
        is_global: bool,
        past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]],
        use_cache: bool,
        token_ids: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """Block computation using plain residual additions."""
        attn_out, new_kv = self.attn(self.norm1(x), is_global, past_kv, use_cache)
        x = x + attn_out
        if self.use_engram:
            x = x + self.engram(self.engram_norm(x), token_ids=token_ids)
        x = x + self.ffn(self.norm2(x))
        return x, new_kv

    def forward(
        self,
        x: torch.Tensor,
        is_global: bool,
        past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
        token_ids: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """Run the block.

        Args:
            x: Input hidden states.
            is_global: Forwarded to the attention layer.
            past_kv: Cached attention key/value pair, if any.
            use_cache: Whether the attention layer should return its cache.
            token_ids: Token ids forwarded to the Engram branch.

        Returns:
            ``(hidden, new_kv)`` as produced by the selected path.
        """
        run = self._forward_hyper if self.use_mhc else self._forward_residual
        return run(x, is_global, past_kv, use_cache, token_ids)
class RecurrentDepthBlock(nn.Module):
    """Applies one shared transformer block repeatedly with per-position halting.

    Every loop produces a candidate state and a halting probability. The
    output is a weighted sum of candidate states: a position's weights are
    its halting probabilities until their running total reaches
    ``act_threshold``, at which point the leftover mass is assigned and the
    position is frozen.
    """

    def __init__(
        self,
        dim: int,
        n_heads: int,
        n_kv_heads: int,
        head_dim: int,
        ffn_dim: int,
        dropout: float,
        sliding_window: int,
        rope_fraction: float,
        n_loops: int,
        act_threshold: float,
        lora_rank: int,
        loop_embed_dim: int,
    ) -> None:
        super().__init__()
        self.n_loops = max(1, n_loops)
        self.act_threshold = act_threshold
        self.loop_embed_dim = max(0, loop_embed_dim)
        self.norm = RMSNorm(dim)
        # The shared block has neither Engram memory nor hyper-connections.
        self.block = TransformerBlock(
            dim=dim,
            n_heads=n_heads,
            n_kv_heads=n_kv_heads,
            head_dim=head_dim,
            ffn_dim=ffn_dim,
            dropout=dropout,
            sliding_window=sliding_window,
            rope_fraction=rope_fraction,
            engram_dim=0,
            mhc_expansion=1,
        )
        self.injection = StableRecurrentInjection(dim)
        self.act = AdaptiveHalting(dim)
        self.lora = DepthLoRAAdapter(dim, lora_rank, self.n_loops)

    def forward(
        self,
        h: torch.Tensor,
        e: torch.Tensor,
        token_ids: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[Optional[Tuple[torch.Tensor, torch.Tensor]]]] = None,
        use_cache: bool = False,
        n_loops: Optional[int] = None,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], Optional[List[Tuple[torch.Tensor, torch.Tensor]]]]:
        """Run the recurrent loop.

        Args:
            h: Initial state, ``(B, T, dim)``.
            e: Tensor re-added to the state on every loop.
            token_ids: Forwarded to the shared block.
            past_key_values: One cache entry per loop iteration, if any.
            use_cache: Collect one cache entry per executed loop.
            n_loops: Overrides the configured loop count when truthy.

        Returns:
            ``(output, aux, new_past)``. ``aux["recurrent_halt_mean"]`` is
            the mean halting probability of the last executed loop.
        """
        total_loops = max(1, n_loops or self.n_loops)
        B, T, _ = h.shape
        halted = torch.zeros(B, T, device=h.device, dtype=torch.bool)
        cumulative_p = torch.zeros(B, T, device=h.device, dtype=h.dtype)
        output = torch.zeros_like(h)
        new_past: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = [] if use_cache else None
        current = h
        final_halt = None

        for step in range(total_loops):
            cached = None
            if past_key_values is not None and step < len(past_key_values):
                cached = past_key_values[step]

            stamped = loop_index_embedding(current, step, self.loop_embed_dim)
            trans_out, layer_kv = self.block(
                self.norm(stamped + e),
                is_global=True,
                past_kv=cached,
                use_cache=use_cache,
                token_ids=token_ids,
            )
            trans_out = trans_out + self.lora(trans_out, step)
            candidate = self.injection(current, e, trans_out)

            # Positions halted on an earlier loop contribute nothing further.
            active = ~halted
            p = self.act(candidate)
            p = p * active.to(p.dtype)
            final_halt = p

            crossing = active & ((cumulative_p + p) >= self.act_threshold)
            # A position crossing the threshold receives all remaining mass.
            leftover = (1.0 - cumulative_p).clamp(min=0.0)
            update_weight = torch.where(crossing, leftover, p)
            output = output + candidate * update_weight.unsqueeze(-1)
            cumulative_p = cumulative_p + update_weight

            # Freeze the state of positions that were already halted.
            current = torch.where(halted.unsqueeze(-1), current, candidate)
            halted = halted | crossing

            if new_past is not None:
                new_past.append(layer_kv)
            # Early exit only without caching, so a cached run always yields
            # one entry per loop.
            if not use_cache and bool(halted.all()):
                break

        # Whatever probability mass is left goes to the final state.
        remainder = (1.0 - cumulative_p).clamp(min=0.0)
        output = output + current * remainder.unsqueeze(-1)

        aux: Dict[str, torch.Tensor] = {}
        if final_halt is not None:
            aux["recurrent_halt_mean"] = final_halt.mean()
        return output, aux, new_past
class TinyMemoryLM(nn.Module):
def __init__(
    self,
    vocab_size: int,
    dim: int,
    n_unique_layers: int,
    n_logical_layers: int,
    n_heads: int,
    n_kv_heads: int,
    ffn_dim: int,
    dropout: float,
    mtp_horizons: Sequence[int],
    grad_checkpoint: bool,
    sliding_window: int = 512,
    rope_fraction: float = 0.5,
    embed_scale: bool = True,
    engram_dim: int = 0,
    engram_heads: int = 4,
    engram_table_size: int = 8192,
    engram_max_ngram: int = 3,
    mhc_expansion: int = 1,
    sleep_gate_cap: int = 0,
    sleep_gate_heads: int = 4,
    sleep_retention_enabled: bool = True,
    sleep_retention_hidden: int = 0,
    latent_think_layers: int = 0,
    prelude_layers: int = 0,
    coda_layers: int = 0,
    recurrent_loops: int = 0,
    recurrent_act_threshold: float = 0.99,
    recurrent_lora_rank: int = 0,
    recurrent_loop_embed_dim: int = 0,
) -> None:
    """Assemble the model.

    With ``recurrent_loops > 0`` the layer stack is an optional prelude, a
    ``RecurrentDepthBlock`` and an optional coda; otherwise it is a flat
    list of ``max(1, n_unique_layers)`` blocks. The output head shares its
    weight with the token embedding. Sub-modules are created in a fixed
    order, which determines both RNG consumption and parameter ordering.
    """
    super().__init__()
    self.dim = dim
    self.n_unique_layers = n_unique_layers
    self.n_logical_layers = n_logical_layers
    self.grad_checkpoint = grad_checkpoint
    self.embed_scale_factor = math.sqrt(dim) if embed_scale else 1.0
    head_dim = dim // n_heads

    # Arguments common to every attention/FFN stack built below.
    core_kwargs = dict(
        dim=dim,
        n_heads=n_heads,
        n_kv_heads=n_kv_heads,
        head_dim=head_dim,
        ffn_dim=ffn_dim,
        rope_fraction=rope_fraction,
    )

    self.embed_tokens = nn.Embedding(vocab_size, dim)
    self.head = nn.Linear(dim, vocab_size, bias=False)
    # Weight tying: the head reuses the embedding matrix.
    self.head.weight = self.embed_tokens.weight
    self.output_bias = nn.Parameter(torch.zeros(vocab_size))

    self.use_recurrent_depth = recurrent_loops > 0
    self.prelude_layers = max(0, prelude_layers)
    self.coda_layers = max(0, coda_layers)
    self.recurrent_loops = max(0, recurrent_loops)
    self.blocks: Optional[nn.ModuleList] = None
    self.prelude: Optional[nn.ModuleList] = None
    self.recurrent: Optional[RecurrentDepthBlock] = None
    self.coda: Optional[nn.ModuleList] = None

    def _make_blocks(n: int) -> nn.ModuleList:
        """Build ``n`` independent blocks with the main-stack settings."""
        return nn.ModuleList(
            TransformerBlock(
                **core_kwargs,
                dropout=dropout,
                sliding_window=sliding_window,
                engram_dim=engram_dim,
                engram_heads=engram_heads,
                engram_table_size=engram_table_size,
                engram_max_ngram=engram_max_ngram,
                mhc_expansion=mhc_expansion,
            )
            for _ in range(n)
        )

    if not self.use_recurrent_depth:
        self.blocks = _make_blocks(max(1, n_unique_layers))
    else:
        if self.prelude_layers > 0:
            self.prelude = _make_blocks(self.prelude_layers)
        self.recurrent = RecurrentDepthBlock(
            **core_kwargs,
            dropout=dropout,
            sliding_window=sliding_window,
            n_loops=self.recurrent_loops,
            act_threshold=recurrent_act_threshold,
            lora_rank=recurrent_lora_rank,
            # A zero embed dim falls back to max(2, dim // 8).
            loop_embed_dim=recurrent_loop_embed_dim or max(2, dim // 8),
        )
        if self.coda_layers > 0:
            self.coda = _make_blocks(self.coda_layers)

    self.norm = RMSNorm(dim)

    # Horizons: unique integers greater than 1, ascending.
    self.mtp_horizons = sorted({int(h) for h in mtp_horizons if int(h) > 1})
    self.mtp_adapters = nn.ModuleDict(
        {str(h): nn.Linear(dim, dim, bias=False) for h in self.mtp_horizons}
    )
    self.mtp_norms = nn.ModuleDict(
        {str(h): RMSNorm(dim) for h in self.mtp_horizons}
    )

    # Shrink the output projections of every main-stack block by
    # 1 / sqrt(2 * n_logical_layers).
    res_scale = (2 * max(1, n_logical_layers)) ** -0.5
    scaled = [
        block
        for group in (self.blocks, self.prelude, self.coda)
        if group is not None
        for block in group
    ]
    if self.recurrent is not None:
        scaled.append(self.recurrent.block)
    for block in scaled:
        block.attn.wo.weight.data.mul_(res_scale)
        block.ffn.down.weight.data.mul_(res_scale)

    self.sleep_gate: Optional[SleepGate] = None
    if sleep_gate_cap > 0:
        self.sleep_gate = SleepGate(
            dim=dim,
            cap=sleep_gate_cap,
            n_heads=sleep_gate_heads,
            retention_enabled=sleep_retention_enabled,
            retention_hidden=sleep_retention_hidden,
        )

    self.think_blocks: Optional[nn.ModuleList] = None
    self.think_norm: Optional[RMSNorm] = None
    if latent_think_layers > 0:
        # Think blocks use fixed settings: no dropout, window of 2048, no
        # Engram memory, no hyper-connections; they are not rescaled above.
        self.think_blocks = nn.ModuleList(
            TransformerBlock(
                **core_kwargs,
                dropout=0.0,
                sliding_window=2048,
                engram_dim=0,
                mhc_expansion=1,
            )
            for _ in range(latent_think_layers)
        )
        self.think_norm = RMSNorm(dim)
def resize_token_embeddings(self, new_vocab_size: int) -> None:
    """Resize the token embedding, the tied output head and the output bias.

    Rows shared by the old and the new vocabulary are copied across.  Extra
    embedding rows keep the default ``nn.Embedding`` initialisation and extra
    bias entries are zero.  The replacement tensors are created on the same
    device and with the same dtype as the tensors they replace, so a model
    that was cast to another precision does not end up with mixed dtypes.

    Args:
        new_vocab_size: Number of rows the embedding table should have.
    """
    old_vocab_size = self.embed_tokens.num_embeddings
    if new_vocab_size == old_vocab_size:
        return

    old_weight = self.embed_tokens.weight.data
    old_bias = self.output_bias.data
    device = old_weight.device
    embed_dim = self.embed_tokens.embedding_dim
    copy_size = min(old_vocab_size, new_vocab_size)

    new_embed = nn.Embedding(new_vocab_size, embed_dim).to(device=device, dtype=old_weight.dtype)
    new_embed.weight.data[:copy_size] = old_weight[:copy_size]

    # The head shares its weight matrix with the embedding (weight tying).
    new_head = nn.Linear(embed_dim, new_vocab_size, bias=False).to(device=device, dtype=old_weight.dtype)
    new_head.weight = new_embed.weight

    new_bias = torch.zeros(new_vocab_size, device=device, dtype=old_bias.dtype)
    new_bias[:copy_size] = old_bias[:copy_size]

    self.embed_tokens = new_embed
    self.head = new_head
    self.output_bias = nn.Parameter(new_bias)
def _build_logical_layers(self) -> List[Tuple[nn.Module, int]]:
    """Return the ``(block, logical_index)`` schedule for the shared-block stack.

    The unique blocks are laid out twice in a row and the result is cut to
    ``self.n_logical_layers`` entries, so at most ``2 * len(self.blocks)``
    logical layers are ever produced.  Returns an empty list when the model
    has no ``blocks``.
    """
    if self.blocks is None:
        return []
    unique_blocks = list(self.blocks)
    schedule = (unique_blocks * 2)[: self.n_logical_layers]
    return list(zip(schedule, range(len(schedule))))
def forward(
    self,
    ids: torch.Tensor,
    use_cache: bool = False,
    past_key_values: Optional[List[Optional[Tuple[torch.Tensor, torch.Tensor]]]] = None,
    return_hidden: bool = False,
) -> Tuple[torch.Tensor, Dict[int, torch.Tensor], Dict[str, torch.Tensor], Optional[torch.Tensor], Optional[List[Tuple[torch.Tensor, torch.Tensor]]]]:
    """Run the language model on a batch of token ids.

    Args:
        ids: 2-D tensor of token ids (unpacked as ``B, T``).
        use_cache: When true, the per-layer key/value entries returned by the
            blocks are collected and returned.
        past_key_values: Optional list of cached per-layer entries, indexed in
            the same order in which layers are executed here.
        return_hidden: When true, the final hidden states are returned too.

    Returns:
        ``(logits, mtp, aux, h_out, new_past_key_values)`` where ``mtp`` maps
        a horizon to its extra logits (training only), ``aux`` holds whatever
        the recurrent block reports, ``h_out`` is the final hidden state or
        ``None``, and ``new_past_key_values`` is ``None`` unless ``use_cache``.
    """
    B, T = ids.shape
    x = self.embed_tokens(ids) * self.embed_scale_factor
    new_past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = [] if use_cache else None
    aux: Dict[str, torch.Tensor] = {}
    if self.use_recurrent_depth:
        # Recurrent-depth layout: prelude blocks -> recurrent block -> coda blocks.
        # `offset` is the running index into `past_key_values`.
        offset = 0
        if self.prelude is not None:
            for block in self.prelude:
                past_kv = past_key_values[offset] if past_key_values is not None and offset < len(past_key_values) else None
                x, layer_kv = block(x, is_global=True, past_kv=past_kv, use_cache=use_cache, token_ids=ids)
                if new_past_key_values is not None:
                    new_past_key_values.append(layer_kv)
                offset += 1
        # The prelude output is handed to the recurrent block alongside `x`.
        encoded = x
        # The recurrent block owns `recurrent_loops` consecutive cache slots.
        recurrent_past = past_key_values[offset: offset + self.recurrent_loops] if past_key_values is not None else None
        x, recurrent_aux, recurrent_kv = self.recurrent(
            x, encoded, token_ids=ids, past_key_values=recurrent_past, use_cache=use_cache,
        )
        aux.update(recurrent_aux)
        if new_past_key_values is not None and recurrent_kv is not None:
            new_past_key_values.extend(recurrent_kv)
        offset += self.recurrent_loops
        if self.coda is not None:
            for block in self.coda:
                past_kv = past_key_values[offset] if past_key_values is not None and offset < len(past_key_values) else None
                x, layer_kv = block(x, is_global=True, past_kv=past_kv, use_cache=use_cache, token_ids=ids)
                if new_past_key_values is not None:
                    new_past_key_values.append(layer_kv)
                offset += 1
    else:
        # Shared-block layout: iterate the logical layer schedule.
        logical_layers = self._build_logical_layers()
        last_logical_idx = len(logical_layers) - 1
        for layer_idx, (block, logical_idx) in enumerate(logical_layers):
            # Even-indexed layers and the final layer are run with is_global=True.
            is_global = logical_idx % 2 == 0 or layer_idx == last_logical_idx
            past_kv = past_key_values[layer_idx] if past_key_values is not None and layer_idx < len(past_key_values) else None
            # Activation checkpointing is only used while training without a KV cache.
            if self.grad_checkpoint and self.training and not use_cache:
                x, layer_kv = checkpoint(block, x, is_global, past_kv, use_cache, ids, use_reentrant=True)
            else:
                x, layer_kv = block(x, is_global, past_kv, use_cache, ids)
            if new_past_key_values is not None:
                new_past_key_values.append(layer_kv)
    x = self.norm(x)
    if self.sleep_gate is not None:
        # The gate's read-out is added residually; it is written to only in training mode.
        x = x + self.sleep_gate.read(x)
        if self.training:
            self.sleep_gate.write(x)
    if self.think_blocks is not None:
        # Extra blocks applied after the main stack; their cache output is discarded.
        for think_block in self.think_blocks:
            x, _ = think_block(x, is_global=True)
        x = self.think_norm(x)
    h_out = x if return_hidden else None
    logits = self.head(x)
    # NOTE(review): presumably compensates for the embedding scale applied to
    # the shared embedding/head weights -- confirm against __init__.
    if self.embed_scale_factor != 1.0:
        logits = logits / self.embed_scale_factor
    logits = logits + self.output_bias
    mtp: Dict[int, torch.Tensor] = {}
    if self.mtp_horizons and self.training:
        # Multi-token-prediction logits are only produced in training mode, and
        # only for horizons that fit in the sequence length.
        for horizon in self.mtp_horizons:
            if horizon > 1 and horizon <= T - 1:
                # Drop the last `horizon` positions of the hidden states.
                shifted_h = x[:, :-horizon, :]
                adapted_h = self.mtp_adapters[str(horizon)](shifted_h)
                adapted_h = self.mtp_norms[str(horizon)](adapted_h)
                mtp_logits = self.head(adapted_h)
                if self.embed_scale_factor != 1.0:
                    mtp_logits = mtp_logits / self.embed_scale_factor
                mtp_logits = mtp_logits + self.output_bias
                mtp[horizon] = mtp_logits
    return logits, mtp, aux, h_out, new_past_key_values
# ---------------------------------------------------------------------------
# Generation
# ---------------------------------------------------------------------------
def build_stop_token_ids(tokenizer: WordTokenizer) -> set:
    """Collect the token ids that end generation.

    The set always contains ``tokenizer.eos_id`` and additionally the ids of
    the chat role markers that the tokenizer's vocabulary actually knows.
    """
    role_markers = ("<|user|>", "<|system|>", "<|assistant|>")
    looked_up = (tokenizer.token_to_id.get(marker) for marker in role_markers)
    marker_ids = {int(token_id) for token_id in looked_up if token_id is not None}
    return {tokenizer.eos_id} | marker_ids
def apply_no_repeat_ngram(
    logits: torch.Tensor,
    token_history: Sequence[int],
    ngram_size: int,
) -> torch.Tensor:
    """Mask tokens that would complete an n-gram already present in the history.

    The last ``ngram_size - 1`` tokens of ``token_history`` form the current
    prefix.  Every token that followed an earlier occurrence of that prefix
    is set to ``-inf`` in a copy of ``logits``.  The input tensor is returned
    untouched when blocking is disabled (``ngram_size <= 1``), the history is
    too short, or nothing has to be banned.
    """
    prefix_len = ngram_size - 1
    if prefix_len < 1 or len(token_history) < prefix_len:
        return logits

    current_prefix = tuple(token_history[-prefix_len:])
    blocked = {
        int(token_history[start + prefix_len])
        for start in range(len(token_history) - prefix_len)
        if tuple(token_history[start : start + prefix_len]) == current_prefix
    }
    if not blocked:
        return logits

    masked = logits.clone()
    blocked_ids = torch.tensor(sorted(blocked), device=logits.device, dtype=torch.long)
    masked[blocked_ids] = float("-inf")
    return masked
def apply_loop_penalty(
    logits: torch.Tensor,
    tokenizer: WordTokenizer,
    generated_text: str,
    penalty: float = 5.0,
) -> torch.Tensor:
    """Penalise the token that would continue a repeated substring.

    Span widths 24, 16, 12 and 8 are tried in that order.  For the first
    width whose trailing span also occurs earlier in ``generated_text``, the
    character that followed the earlier occurrence is looked up in
    ``tokenizer.token_to_id`` and, if it is a known token, its logit is
    lowered by ``penalty``.  Texts shorter than 16 characters return the
    input tensor itself; otherwise a (possibly unchanged) copy is returned.
    """
    text_len = len(generated_text)
    if text_len < 16:
        return logits

    adjusted = logits.clone()
    for width in (24, 16, 12, 8):
        if text_len < 2 * width:
            continue
        tail = generated_text[-width:]
        earlier_at = generated_text[:-width].rfind(tail)
        if earlier_at == -1:
            continue
        follow_idx = earlier_at + width
        if follow_idx < text_len:
            token_id = tokenizer.token_to_id.get(generated_text[follow_idx])
            if token_id is not None:
                adjusted[token_id] -= penalty
        # Only the longest repeated span is acted upon.
        break
    return adjusted
def apply_min_p(logits: torch.Tensor, min_p: float) -> torch.Tensor:
    """Drop tokens whose probability is below ``min_p`` times the top probability.

    Dropped tokens are set to ``-inf`` in a new tensor.  With ``min_p <= 0``
    the filter is disabled and the input tensor is returned as-is.
    """
    if min_p <= 0.0:
        return logits
    probs = torch.softmax(logits, dim=-1)
    cutoff = probs.max() * min_p
    return logits.masked_fill(probs < cutoff, float("-inf"))
def generate(
    model: TinyMemoryLM,
    tokenizer: WordTokenizer,
    prompt: str,
    max_new_tokens: int = 256,
    temperature: float = 0.8,
    top_k: int = 16,
    top_p: float = 0.95,
    repetition_penalty: float = 1.0,
    device: str = "cuda",
    sft_mode: bool = True,
    stream: bool = True,
    no_repeat_ngram_size: int = 0,
    context_window: int = 2048,
    logit_soft_cap: float = 15.0,
    min_p: float = 0.05,
    loop_penalty: float = 5.0,
) -> str:
    """Autoregressively sample a completion for ``prompt``.

    Each step re-runs the model on the last ``context_window`` tokens (no KV
    cache) and then applies, in order: logit soft-capping, repetition
    penalty, no-repeat n-gram blocking, substring-loop penalty, temperature,
    min-p, top-k and top-p filtering.  Generation stops at
    ``max_new_tokens`` or as soon as a stop token is sampled.

    Args:
        model: Callable returning a tuple whose first item is the logits.
        tokenizer: Tokenizer providing ``encode``, ``id_to_token``,
            ``token_to_id``, ``special`` and ``eos_id``.
        prompt: User text; wrapped in the chat template when ``sft_mode``.
        max_new_tokens: Upper bound on the number of sampled tokens.
        temperature: Sampling temperature; ``0`` selects greedy decoding.
        top_k: Keep only the ``top_k`` best tokens (``<= 0`` disables).
        top_p: Nucleus threshold; only active for ``0 < top_p < 1``.
        repetition_penalty: Penalty applied to already generated tokens.
        device: Device on which the id tensors are created.
        sft_mode: Whether to wrap the prompt in ``<|user|>``/``<|assistant|>``.
        stream: Print tokens to stdout as they are produced.
        no_repeat_ngram_size: N-gram size for repeat blocking (``0`` disables).
        context_window: Number of trailing tokens fed to the model.
        logit_soft_cap: ``tanh`` soft-cap value (``<= 0`` disables).
        min_p: Min-p threshold (``<= 0`` disables).
        loop_penalty: Logit penalty used by the substring-loop detector.

    Returns:
        The generated text with special tokens left out.
    """
    full_prompt = f"<|user|>\n{prompt}\n<|assistant|>\n" if sft_mode else prompt
    input_ids = tokenizer.encode(full_prompt, add_bos=True, add_eos=False)
    input_ids_t = torch.tensor([input_ids], dtype=torch.long, device=device)
    stop_token_ids = build_stop_token_ids(tokenizer)
    generated_text = ""
    generated_ids: List[int] = []

    with torch.no_grad():
        for _ in range(max_new_tokens):
            ctx_ids = (
                input_ids_t[:, -context_window:] if context_window > 0 else input_ids_t
            )
            logits, *_ = model(ctx_ids)
            next_logits = logits[0, -1, :].clone()

            # Logit soft-capping (Gemma-style) - prevents overconfident collapse
            if logit_soft_cap > 0:
                next_logits = logit_soft_cap * torch.tanh(next_logits / logit_soft_cap)
            # Kept as a fallback in case every token ends up masked below.
            raw_next_logits = next_logits.clone()

            # Repetition penalty on previously generated tokens: positive
            # logits are divided, non-positive ones multiplied.
            if repetition_penalty != 1.0 and generated_ids:
                seen = torch.tensor(
                    sorted(set(generated_ids)), dtype=torch.long, device=next_logits.device
                )
                seen_logits = next_logits[seen]
                next_logits[seen] = torch.where(
                    seen_logits > 0,
                    seen_logits / repetition_penalty,
                    seen_logits * repetition_penalty,
                )

            # No-repeat n-gram blocking on generated tokens only
            if no_repeat_ngram_size > 0 and generated_ids:
                next_logits = apply_no_repeat_ngram(next_logits, generated_ids, no_repeat_ngram_size)

            # Substring loop detection
            next_logits = apply_loop_penalty(next_logits, tokenizer, generated_text, penalty=loop_penalty)

            # Temperature scaling
            if temperature != 1.0:
                next_logits = next_logits / max(temperature, 1e-6)

            # Min-p filtering - remove tokens below min_p * max_prob
            if min_p > 0:
                next_logits = apply_min_p(next_logits, min_p)

            # Top-k filtering
            if top_k > 0:
                v, _ = torch.topk(next_logits, min(top_k, next_logits.size(0)))
                next_logits[next_logits < v[-1]] = float("-inf")

            # Top-p (nucleus) filtering; the best token is always kept.
            if 0.0 < top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(next_logits, descending=True)
                sorted_probs = torch.softmax(sorted_logits, dim=-1)
                cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
                remove_mask = cumulative_probs > top_p
                remove_mask[0] = False
                next_logits[sorted_indices[remove_mask]] = float("-inf")

            # Fallback if all tokens masked
            if not torch.isfinite(next_logits).any():
                next_logits = raw_next_logits
                if temperature != 1.0:
                    next_logits = next_logits / max(temperature, 1e-6)

            if temperature == 0:
                next_id = torch.argmax(next_logits).item()
            else:
                probs = torch.softmax(next_logits, dim=-1)
                next_id = torch.multinomial(probs, num_samples=1).item()

            if next_id in stop_token_ids:
                break

            token_str = (
                tokenizer.id_to_token[next_id]
                if next_id < len(tokenizer.id_to_token)
                else ""
            )
            generated_ids.append(next_id)
            # Special tokens stay in the model context but are never shown.
            if token_str not in tokenizer.special:
                generated_text += token_str
                if stream:
                    print(token_str, end="", flush=True)

            input_ids_t = torch.cat(
                [input_ids_t, torch.tensor([[next_id]], device=device)], dim=1
            )

    if stream:
        print()
    return generated_text
def generate_stream(
    model: TinyMemoryLM,
    tokenizer: WordTokenizer,
    prompt: str,
    max_new_tokens: int = 256,
    temperature: float = 0.8,
    top_k: int = 16,
    top_p: float = 0.95,
    repetition_penalty: float = 1.0,
    device: str = "cpu",
    sft_mode: bool = True,
    no_repeat_ngram_size: int = 0,
    context_window: int = 2048,
    logit_soft_cap: float = 15.0,
    min_p: float = 0.05,
    loop_penalty: float = 5.0,
) -> "Iterator[str]":
    """Generator variant of sampling: yields the response accumulated so far.

    A value is yielded each time a non-special token is appended.  The
    sampling pipeline per step is: soft-cap, repetition penalty, no-repeat
    n-gram blocking, loop penalty, temperature, min-p, top-k, top-p.
    """
    prompt_text = f"<|user|>\n{prompt}\n<|assistant|>\n" if sft_mode else prompt
    prompt_ids = tokenizer.encode(prompt_text, add_bos=True, add_eos=False)
    sequence = torch.tensor([prompt_ids], dtype=torch.long, device=device)
    stop_ids = build_stop_token_ids(tokenizer)
    emitted_ids: list = []
    response = ""
    neg_inf = float("-inf")

    with torch.no_grad():
        for _step in range(max_new_tokens):
            window = sequence[:, -context_window:] if context_window > 0 else sequence
            all_logits, *_rest = model(window)
            step_logits = all_logits[0, -1, :].clone()

            if logit_soft_cap > 0:
                step_logits = logit_soft_cap * torch.tanh(step_logits / logit_soft_cap)
            # Snapshot used if the filters below mask out every token.
            fallback_logits = step_logits.clone()

            if repetition_penalty != 1.0 and emitted_ids:
                for seen_id in set(emitted_ids):
                    if step_logits[seen_id] > 0:
                        step_logits[seen_id] /= repetition_penalty
                    else:
                        step_logits[seen_id] *= repetition_penalty

            if no_repeat_ngram_size > 0 and emitted_ids:
                step_logits = apply_no_repeat_ngram(step_logits, emitted_ids, no_repeat_ngram_size)

            step_logits = apply_loop_penalty(step_logits, tokenizer, response, penalty=loop_penalty)

            if temperature != 1.0:
                step_logits = step_logits / max(temperature, 1e-6)

            if min_p > 0:
                step_logits = apply_min_p(step_logits, min_p)

            if top_k > 0:
                top_values, _top_idx = torch.topk(step_logits, min(top_k, step_logits.size(0)))
                step_logits[step_logits < top_values[-1]] = neg_inf

            if 0.0 < top_p < 1.0:
                ranked_logits, ranked_idx = torch.sort(step_logits, descending=True)
                cumulative = torch.cumsum(torch.softmax(ranked_logits, dim=-1), dim=-1)
                drop = cumulative > top_p
                drop[0] = False  # always keep the best token
                step_logits[ranked_idx[drop]] = neg_inf

            if not torch.isfinite(step_logits).any():
                step_logits = fallback_logits
                if temperature != 1.0:
                    step_logits = step_logits / max(temperature, 1e-6)

            if temperature == 0:
                chosen = int(torch.argmax(step_logits).item())
            else:
                distribution = torch.softmax(step_logits, dim=-1)
                chosen = int(torch.multinomial(distribution, num_samples=1).item())

            if chosen in stop_ids:
                break

            vocab = tokenizer.id_to_token
            piece = vocab[chosen] if chosen < len(vocab) else ""
            emitted_ids.append(chosen)
            if piece not in tokenizer.special:
                response += piece
                yield response

            sequence = torch.cat(
                [sequence, torch.tensor([[chosen]], device=device)], dim=1
            )
# ---------------------------------------------------------------------------
# Local model loading
# ---------------------------------------------------------------------------
def series_from_name(name: str) -> str | None:
    """Guess the model series from a run or model name.

    The match is a case-insensitive substring test, checked in the order
    Haiku, Sonnet, Opus.  Returns ``None`` when none of them occurs (or the
    name is empty / ``None``).
    """
    lowered = (name or "").lower()
    for needle, label in (("haiku", "Haiku"), ("sonnet", "Sonnet"), ("opus", "Opus")):
        if needle in lowered:
            return label
    return None
def series_config(series: str) -> dict[str, object]:
    """Return the ``MODEL_SERIES`` entry for ``series`` (case-insensitive).

    Unknown series fall back to the ``"sonnet"`` entry.  The fallback is
    looked up lazily, so a table without a ``"sonnet"`` entry no longer
    breaks look-ups of series that do exist.

    Raises:
        KeyError: If neither ``series`` nor the ``"sonnet"`` fallback is
            present in ``MODEL_SERIES``.
    """
    key = (series or "").lower()
    for candidate in (key, "sonnet"):
        if candidate in MODEL_SERIES:
            return MODEL_SERIES[candidate]
    raise KeyError(f"unknown model series {series!r} and no 'sonnet' fallback in MODEL_SERIES")
def _checkpoint_step(path: Path) -> Optional[int]:
    """Return the step number encoded in ``checkpoint_step_<N>.pt``, or ``None``."""
    try:
        return int(path.stem.rsplit("_", 1)[-1])
    except ValueError:
        return None


def discover_models(runs_dir: Path) -> List[dict]:
    """Scan ``runs_dir`` for run directories that contain loadable checkpoints.

    A sub-directory qualifies when it holds a ``tokenizer.json``.  Every one
    of ``model.pt``, ``model_rep.pt`` and ``pretrain.pt`` that exists yields
    one entry.  If none exists, the ``checkpoint_step_<N>.pt`` file with the
    highest numeric step is used instead; files whose suffix is not a number
    are ignored rather than aborting the whole scan.

    The series is taken from the checkpoint when it can be inferred, else
    from the directory name, else it defaults to ``"Sonnet"``.

    Returns:
        A list of dicts with the keys ``name``, ``checkpoint``, ``series``,
        ``model_path`` and ``tokenizer_path``; empty when ``runs_dir`` is not
        a directory.
    """
    models: List[dict] = []
    if not runs_dir.is_dir():
        return models

    for child in sorted(runs_dir.iterdir()):
        if not child.is_dir():
            continue
        tokenizer_path = child / "tokenizer.json"
        if not tokenizer_path.exists():
            continue
        name = child.name

        # Only the first existing primary checkpoint is inspected for the series.
        series = None
        for ckpt_name in ("model.pt", "pretrain.pt"):
            ckpt_path = child / ckpt_name
            if ckpt_path.exists():
                series = _fast_series_from_checkpoint(ckpt_path)
                break
        if series is None:
            series = series_from_name(name) or "Sonnet"

        found = False
        for ckpt_name in ("model.pt", "model_rep.pt", "pretrain.pt"):
            ckpt_path = child / ckpt_name
            if ckpt_path.exists():
                models.append(
                    {
                        "name": name,
                        "checkpoint": ckpt_name,
                        "series": series,
                        "model_path": ckpt_path,
                        "tokenizer_path": tokenizer_path,
                    }
                )
                found = True
        if found:
            continue

        # Fall back to the most recent intermediate training checkpoint.
        numbered = [
            (step, path)
            for path in child.glob("checkpoint_step_*.pt")
            for step in (_checkpoint_step(path),)
            if step is not None
        ]
        if numbered:
            ckpt_path = sorted(numbered, key=lambda item: item[0])[-1][1]
            models.append(
                {
                    "name": name,
                    "checkpoint": ckpt_path.name,
                    "series": series,
                    "model_path": ckpt_path,
                    "tokenizer_path": tokenizer_path,
                }
            )
    return models
def _detect_engram(state_dict):
    """Return the last-dimension size of the first engram embedding tensor.

    A key counts as an engram embedding when it contains both ``.engram.``
    and ``.embeddings.``.  Returns ``0`` when no such key exists.
    """
    first_hit = next(
        (name for name in state_dict if ".engram." in name and ".embeddings." in name),
        None,
    )
    if first_hit is None:
        return 0
    return state_dict[first_hit].shape[-1]
def _detect_mhc(state_dict):
    """Return the mHC expansion read from the first 2-D ``.mhc_attn.bias_pre`` tensor.

    The expansion is that tensor's last dimension; ``1`` is returned when no
    matching tensor is present.
    """
    matching = (
        tensor
        for name, tensor in state_dict.items()
        if ".mhc_attn.bias_pre" in name and tensor.dim() == 2
    )
    first = next(matching, None)
    return 1 if first is None else first.shape[-1]  # stored as (1, expansion)
def _detect_sleep_gate(state_dict) -> Tuple[int, int]:
    """Return ``(cap, heads)`` for the sleep gate stored in ``state_dict``.

    ``cap`` is the row count of a 2-D ``sleep_gate.mem_emb`` tensor, or ``0``
    when there is none.  The head count is not recoverable here and is always
    reported as ``4``.
    """
    memory = state_dict.get("sleep_gate.mem_emb")
    if memory is not None and memory.dim() == 2:
        return memory.shape[0], 4
    return 0, 4
def _detect_latent_think(state_dict) -> int:
    """Count the ``think_blocks.<i>`` layers present in ``state_dict``.

    Returns the highest numeric index plus one, or ``0`` if there is none.
    """
    layer_ids = set()
    for name in state_dict:
        if not name.startswith("think_blocks."):
            continue
        index_part = name.split(".")[1]
        if index_part.isdigit():
            layer_ids.add(int(index_part))
    if not layer_ids:
        return 0
    return max(layer_ids) + 1
def _detect_prelude_layers(state_dict) -> int:
    """Count the ``prelude.<i>`` layers present in ``state_dict``.

    Returns the highest numeric index plus one, or ``0`` if there is none.
    """
    layer_ids = set()
    for name in state_dict:
        if not name.startswith("prelude."):
            continue
        index_part = name.split(".")[1]
        if index_part.isdigit():
            layer_ids.add(int(index_part))
    if not layer_ids:
        return 0
    return max(layer_ids) + 1
def _detect_coda_layers(state_dict) -> int:
    """Count the ``coda.<i>`` layers present in ``state_dict``.

    Returns the highest numeric index plus one, or ``0`` if there is none.
    """
    layer_ids = set()
    for name in state_dict:
        if not name.startswith("coda."):
            continue
        index_part = name.split(".")[1]
        if index_part.isdigit():
            layer_ids.add(int(index_part))
    if not layer_ids:
        return 0
    return max(layer_ids) + 1
def _detect_recurrent_loops(state_dict) -> int:
    """Return the number of recurrent loops stored in ``state_dict``.

    ``0`` means the checkpoint has no recurrent block.  When the block is
    present, the loop count is the first dimension of
    ``recurrent.lora.scale.weight``, or ``1`` if that tensor is missing.
    """
    has_recurrent_block = (
        "recurrent.norm.weight" in state_dict
        or "recurrent.block.attn.wq.weight" in state_dict
    )
    if not has_recurrent_block:
        return 0
    scale_key = "recurrent.lora.scale.weight"
    if scale_key not in state_dict:
        return 1
    return state_dict[scale_key].shape[0]
def _detect_recurrent_lora_rank(state_dict) -> int:
    """Return the LoRA rank of the recurrent block, or ``0`` if undetectable.

    The rank is the first dimension of the first 2-D tensor found under
    ``recurrent.lora.B`` or ``recurrent.lora.down.weight`` (checked in that
    order).
    """
    candidate_keys = ("recurrent.lora.B", "recurrent.lora.down.weight")
    for name in candidate_keys:
        if name not in state_dict:
            continue
        dims = state_dict[name].shape
        if len(dims) == 2:
            return int(dims[0])
    return 0
def _infer_series_from_lora_rank(rank: int) -> str | None:
    """Map a recurrent LoRA rank to a lower-case series name.

    ``0`` means "unknown" and yields ``None``; ranks up to 8 map to
    ``"haiku"``, up to 16 to ``"sonnet"``, anything larger to ``"opus"``.
    """
    if rank == 0:
        return None
    for ceiling, label in ((8, "haiku"), (16, "sonnet")):
        if rank <= ceiling:
            return label
    return "opus"
def _fast_series_from_checkpoint(ckpt_path: Path) -> str | None:
    """Infer the series name (``"Haiku"``/``"Sonnet"``/``"Opus"``) from a checkpoint.

    The series is derived from the recurrent LoRA rank: up to 8 is Haiku, up
    to 16 is Sonnet, larger is Opus.  The state dict is taken from the
    ``model_state`` or ``state_dict`` entry, or from the checkpoint itself
    when it is a bare state dict.

    This is a best-effort probe: an unreadable checkpoint or one without
    LoRA weights yields ``None`` instead of raising.
    """
    try:
        # NOTE(security): weights_only=False unpickles arbitrary objects, so
        # this must only be pointed at trusted checkpoint files.
        cp = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    except Exception:
        # Deliberately broad: any load failure just means "series unknown".
        return None
    if not isinstance(cp, dict):
        return None

    sd = cp.get("model_state") or cp.get("state_dict") or cp
    if not isinstance(sd, dict):
        return None

    rank = 0
    for key in ("recurrent.lora.B", "recurrent.lora.down.weight"):
        if key in sd:
            shape = getattr(sd[key], "shape", ())
            rank = int(shape[0]) if len(shape) > 0 else 0
            break

    if rank == 0:
        return None
    if rank <= 8:
        return "Haiku"
    if rank <= 16:
        return "Sonnet"
    return "Opus"
def _infer_arch_from_state_dict(state_dict, cfg):
    """Infer architecture hyper-parameters directly from checkpoint weights.

    Values that can be read off tensor shapes override the matching entries
    of ``cfg`` (the series config); everything else is taken from ``cfg``
    unchanged.  ``cfg`` itself is not modified.

    Inferred keys (each only when the relevant tensors are present):
    ``dim``, ``ffn_dim``, ``n_unique_layers``, ``n_heads``, ``n_kv_heads``,
    ``engram_table_size``, ``engram_dim`` and ``engram_heads``.

    Args:
        state_dict: Mapping of parameter names to tensors.
        cfg: Series configuration used as the fallback.

    Returns:
        A new dict: ``cfg`` updated with the inferred overrides.
    """
    overrides = {}
    has_prelude = any(k.startswith("prelude.") for k in state_dict)
    has_blocks = any(k.startswith("blocks.") for k in state_dict)
    has_recurrent = any(k.startswith("recurrent.") for k in state_dict)
    uses_recurrent_arch = has_prelude and has_recurrent and not has_blocks

    # dim from embed_tokens.weight [vocab, dim]
    if "embed_tokens.weight" in state_dict:
        overrides["dim"] = state_dict["embed_tokens.weight"].shape[1]

    if uses_recurrent_arch:
        if "prelude.0.ffn.gate.weight" in state_dict:
            overrides["ffn_dim"] = state_dict["prelude.0.ffn.gate.weight"].shape[0]
        # The recurrent layout has no shared `blocks` stack.
        overrides["n_unique_layers"] = 0
        src = "prelude.0"
    else:
        if "blocks.0.ffn.gate.weight" in state_dict:
            overrides["ffn_dim"] = state_dict["blocks.0.ffn.gate.weight"].shape[0]
        block_ids = {
            int(k.split(".")[1])
            for k in state_dict
            if k.startswith("blocks.") and k.split(".")[1].isdigit()
        }
        if block_ids:
            overrides["n_unique_layers"] = max(block_ids) + 1
        src = "blocks.0"

    # Head counts: projection rows divided by the per-head norm width.
    for proj, norm, target in (("wq", "q_norm", "n_heads"), ("wk", "k_norm", "n_kv_heads")):
        proj_key = f"{src}.attn.{proj}.weight"
        norm_key = f"{src}.attn.{norm}.weight"
        if proj_key in state_dict and norm_key in state_dict:
            head_dim = state_dict[norm_key].shape[0]
            overrides[target] = state_dict[proj_key].shape[0] // head_dim

    # engram params from the first 2-D "..._0" embedding table
    for key, val in state_dict.items():
        if ".engram.embeddings." in key and key.endswith("_0") and val.dim() == 2:
            overrides["engram_table_size"] = val.shape[0]
            overrides["engram_dim"] = val.shape[1]
            break

    engram_dim = overrides.get("engram_dim", int(cfg.get("engram_dim", 0)))
    engram_max_ngram = int(cfg.get("engram_max_ngram", 2))
    if engram_dim > 0:
        for key, val in state_dict.items():
            if ".engram.branch_conv.weight" in key and val.dim() == 3:
                total_branch_dim = val.shape[0]
                denom = engram_dim * (engram_max_ngram - 1)
                if denom > 0 and total_branch_dim % denom == 0:
                    overrides["engram_heads"] = total_branch_dim // denom
                break

    merged = dict(cfg)
    merged.update(overrides)
    return merged
def load_local_model(model_path: Path, tokenizer_path: Path, series: str) -> dict:
    """Load a checkpoint and its tokenizer and build a ready-to-use model.

    The architecture is reconstructed from the checkpoint wherever possible
    and falls back to the series configuration otherwise.  If the recurrent
    LoRA rank indicates a different series than the one passed in, that
    series is adopted *before* the architecture is inferred, so the
    checkpoint-derived overrides are applied on top of the right base config.

    Args:
        model_path: Path to the ``.pt`` checkpoint.
        tokenizer_path: Path to the tokenizer JSON file.
        series: Series name used to pick the base configuration.

    Returns:
        A dict with the keys ``model``, ``tokenizer``, ``device``, ``series``,
        ``sft_mode`` and ``phase`` (the last two come from the checkpoint and
        may be ``None``).
    """
    tokenizer = WordTokenizer.load(tokenizer_path)
    # NOTE(security): weights_only=False unpickles arbitrary objects; only
    # load checkpoints from sources you trust.
    ckpt = torch.load(str(model_path), map_location="cpu", weights_only=False)
    state_dict = ckpt.get("model_state") or ckpt.get("state_dict") or ckpt

    # Prefer the recorded vocab size, then the actual embedding rows, and
    # only then the tokenizer, so the model always matches the stored weights.
    if "vocab_size" in ckpt:
        vocab_size = int(ckpt["vocab_size"])
    elif "embed_tokens.weight" in state_dict:
        vocab_size = int(state_dict["embed_tokens.weight"].shape[0])
    else:
        vocab_size = int(tokenizer.vocab_size)

    # Settle the series first: the LoRA rank stored in the checkpoint wins
    # over the caller's guess.
    ckpt_lora_rank = _detect_recurrent_lora_rank(state_dict)
    if ckpt_lora_rank > 0:
        inferred_series = _infer_series_from_lora_rank(ckpt_lora_rank)
        if inferred_series and inferred_series != series.lower():
            series = inferred_series.capitalize()

    cfg = _infer_arch_from_state_dict(state_dict, series_config(series))

    engram_dim = int(cfg.get("engram_dim", 0))
    if _detect_engram(state_dict) == 0:
        engram_dim = 0

    mhc_expansion = _detect_mhc(state_dict)
    if mhc_expansion == 1:
        mhc_expansion = int(cfg.get("mhc_expansion", 1))

    ckpt_sleep_cap, ckpt_sleep_heads = _detect_sleep_gate(state_dict)
    sleep_gate_cap = ckpt_sleep_cap if ckpt_sleep_cap > 0 else int(cfg.get("sleep_gate_cap", 0))
    sleep_gate_heads = ckpt_sleep_heads if ckpt_sleep_cap > 0 else int(cfg.get("sleep_gate_heads", 4))
    sleep_retention_enabled = bool(cfg.get("sleep_retention_enabled", True))
    sleep_retention_hidden = int(cfg.get("sleep_retention_hidden", 0))

    latent_think_layers = _detect_latent_think(state_dict)
    if latent_think_layers == 0:
        latent_think_layers = int(cfg.get("latent_think_layers", 0))

    prelude_layers = _detect_prelude_layers(state_dict)
    coda_layers = _detect_coda_layers(state_dict)
    recurrent_loops = _detect_recurrent_loops(state_dict)
    if ckpt_lora_rank > 0:
        recurrent_lora_rank = ckpt_lora_rank
    else:
        recurrent_lora_rank = int(cfg.get("recurrent_lora_rank", 0))
    recurrent_act_threshold = float(cfg.get("recurrent_act_threshold", 0.99))
    recurrent_loop_embed_dim = int(cfg.get("recurrent_loop_embed_dim", 0))
    n_unique = int(cfg.get("n_unique_layers", model_config.n_unique_layers))

    model = TinyMemoryLM(
        vocab_size=vocab_size,
        dim=int(cfg.get("dim", model_config.dim)),
        n_unique_layers=n_unique,
        n_logical_layers=int(cfg.get("n_logical_layers", model_config.n_logical_layers)),
        n_heads=int(cfg.get("n_heads", model_config.n_heads)),
        n_kv_heads=int(cfg.get("n_kv_heads", model_config.n_kv_heads)),
        ffn_dim=int(cfg.get("ffn_dim", model_config.ffn_dim)),
        dropout=float(cfg.get("dropout", model_config.dropout)),
        mtp_horizons=tuple(int(v) for v in cfg.get("mtp_horizons", model_config.mtp_horizons)),
        grad_checkpoint=False,
        sliding_window=int(cfg.get("sliding_window_size", getattr(model_config, "sliding_window_size", 512))),
        rope_fraction=float(cfg.get("rope_fraction", getattr(model_config, "rope_fraction", 0.25))),
        embed_scale=bool(cfg.get("embed_scale", getattr(model_config, "embed_scale", True))),
        engram_dim=engram_dim,
        engram_heads=int(cfg.get("engram_heads", 4)),
        engram_table_size=int(cfg.get("engram_table_size", 8192)),
        engram_max_ngram=int(cfg.get("engram_max_ngram", 3)),
        mhc_expansion=mhc_expansion,
        sleep_gate_cap=sleep_gate_cap,
        sleep_gate_heads=sleep_gate_heads,
        sleep_retention_enabled=sleep_retention_enabled,
        sleep_retention_hidden=sleep_retention_hidden,
        latent_think_layers=latent_think_layers,
        prelude_layers=prelude_layers,
        coda_layers=coda_layers,
        recurrent_loops=recurrent_loops,
        recurrent_act_threshold=recurrent_act_threshold,
        recurrent_lora_rank=recurrent_lora_rank,
        recurrent_loop_embed_dim=recurrent_loop_embed_dim,
    )
    # strict=False: keys missing from or unknown to the model are tolerated.
    model.load_state_dict(state_dict, strict=False)
    model.eval()
    # Grow the embedding when the tokenizer knows more tokens than the checkpoint.
    if tokenizer.vocab_size > vocab_size:
        model.resize_token_embeddings(tokenizer.vocab_size)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)
    return {
        "model": model,
        "tokenizer": tokenizer,
        "device": device,
        "series": series,
        "sft_mode": ckpt.get("sft_mode", None),
        "phase": ckpt.get("phase", None),
    }
# ---------------------------------------------------------------------------
# HuggingFace Model Download & Loading
# ---------------------------------------------------------------------------
def download_huggingface_model(hf_id: str, cache_dir: Path) -> Optional[dict]:
    """Fetch a model repository from the HuggingFace Hub and locate its files.

    The snapshot is searched in ``models/``, then ``model/``, then the
    snapshot root.  ``model.pt`` is preferred over ``pretrain.pt``.

    Args:
        hf_id: Repository id on the Hub.
        cache_dir: Directory used as the download cache.

    Returns:
        A dict with ``model_path``, ``tokenizer_path`` and ``model_name``, or
        ``None`` when the download fails or required files are missing.
        Exits the process if ``huggingface_hub`` is not installed.
    """
    try:
        from huggingface_hub import snapshot_download
    except ImportError:
        print("huggingface_hub not installed. Install with: pip install huggingface_hub")
        sys.exit(1)

    print(f"Downloading {hf_id}...")
    try:
        local_dir = Path(snapshot_download(repo_id=hf_id, cache_dir=str(cache_dir)))
    except Exception as e:
        # Best effort: report the failure and let the caller skip this model.
        print(f"Failed to download {hf_id}: {e}")
        return None
    print(f"Using cached {hf_id} from {local_dir}")

    # Check common subdirectory names: "models/", "model/"
    model_dir = next(
        (local_dir / sub for sub in ("models", "model") if (local_dir / sub).exists()),
        local_dir,
    )
    model_path = model_dir / "model.pt"
    pretrain_path = model_dir / "pretrain.pt"
    tokenizer_path = model_dir / "tokenizer.json"

    ckpt_path = next((p for p in (model_path, pretrain_path) if p.exists()), None)
    if ckpt_path is None or not tokenizer_path.exists():
        print(f"Missing model files in {model_dir}")
        print(f" model.pt exists: {model_path.exists()}")
        print(f" pretrain.pt exists: {pretrain_path.exists()}")
        print(f" tokenizer.json exists: {tokenizer_path.exists()}")
        return None
    return {
        "model_path": ckpt_path,
        "tokenizer_path": tokenizer_path,
        "model_name": ckpt_path.stem,
    }
def load_huggingface_model(hf_id: str, cache_dir: Path) -> Optional[dict]:
    """Download ``hf_id`` and load it as a Haiku-series model.

    Returns:
        The bundle produced by ``load_local_model``, or ``None`` when the
        download failed or the repository lacks the required files.
    """
    files = download_huggingface_model(hf_id, cache_dir)
    if files is None:
        return None
    return load_local_model(files["model_path"], files["tokenizer_path"], "Haiku")
# ---------------------------------------------------------------------------
# Compare All Models
# ---------------------------------------------------------------------------
# Loaded HuggingFace model bundles keyed by their display name; filled by
# prefetch_huggingface_models() and read by compare_all_models().
_hf_model_cache: Dict[str, dict] = {}
def prefetch_huggingface_models() -> None:
    """Download and load every model in ``HUGGINGFACE_MODELS`` into the cache.

    Downloads go to ``cache/huggingface`` next to this file.  Models that
    fail to load are skipped; successfully loaded bundles are stored in
    ``_hf_model_cache`` under their display name.
    """
    cache_dir = Path(__file__).resolve().parent / "cache" / "huggingface"
    cache_dir.mkdir(parents=True, exist_ok=True)
    print("Downloading/preparing HuggingFace models...")
    for display_name, repo_id in HUGGINGFACE_MODELS.items():
        print(f" {display_name}...")
        loaded = load_huggingface_model(repo_id, cache_dir)
        if not loaded:
            continue
        _hf_model_cache[display_name] = loaded
    print(f"Prepared {len(_hf_model_cache)} HuggingFace models")
def _generate_for_comparison(bundle: dict, prompt: str, cfg: dict, sft_mode: bool) -> str:
    """Run ``generate`` for one loaded model bundle with the settings in ``cfg``."""
    print(f"Generating on '{prompt}'...")
    return generate(
        model=bundle["model"],
        tokenizer=bundle["tokenizer"],
        prompt=prompt,
        max_new_tokens=cfg["max_new_tokens"],
        temperature=cfg["temperature"],
        top_k=cfg["top_k"],
        top_p=cfg["top_p"],
        min_p=cfg["min_p"],
        no_repeat_ngram_size=cfg["no_repeat_ngram_size"],
        repetition_penalty=cfg["repetition_penalty"],
        logit_soft_cap=cfg["logit_soft_cap"],
        loop_penalty=cfg["loop_penalty"],
        device=str(bundle["device"]),
        sft_mode=sft_mode,
        stream=True,
        context_window=cfg["context_window"],
    )


def compare_all_models(prompt: str, cfg: dict) -> None:
    """Generate with every available model and print the outputs side by side.

    Local checkpoints are discovered under ``runs/`` next to this file and
    filtered by mode: pretrain checkpoints when ``cfg["sft_mode"]`` is false,
    all others when it is true (the default).  Prefetched HuggingFace models
    from ``_hf_model_cache`` are always included, also when no local
    checkpoint matches.

    Args:
        prompt: Prompt passed to every model.
        cfg: Sampling settings; see ``_generate_for_comparison`` for the keys
            that are required.
    """
    root = Path(__file__).resolve().parent
    runs_dir = root / "runs"
    sft_mode = cfg.get("sft_mode", True)
    is_pretrain = not sft_mode
    local_models = [
        m for m in discover_models(runs_dir)
        if ("pretrain" in m["checkpoint"]) == is_pretrain
    ]
    # Only give up when there is nothing at all to compare.
    if not local_models and not _hf_model_cache:
        print("No models found matching mode.")
        return

    results: List[dict] = []
    for m in local_models:
        print(f"\n{'='*60}")
        print(f"Loading local {m['name']}/{m['checkpoint']}...")
        try:
            bundle = load_local_model(m["model_path"], m["tokenizer_path"], m["series"])
        except Exception as e:
            # One broken checkpoint must not abort the whole comparison.
            print(f"Failed to load {m['name']}: {e}")
            continue
        output = _generate_for_comparison(bundle, prompt, cfg, sft_mode)
        results.append({
            "name": f"[LOCAL] {m['name']}/{m['checkpoint']}",
            "output": output,
            "device": bundle["device"],
        })

    for name, bundle in _hf_model_cache.items():
        print(f"\n{'='*60}")
        print(f"Loading {name} (cached)...")
        output = _generate_for_comparison(bundle, prompt, cfg, sft_mode)
        results.append({
            "name": name,
            "output": output,
            "device": bundle["device"],
        })

    print(f"\n{'='*60}")
    print("=" * 60)
    print("SIDE-BY-SIDE COMPARISON")
    print("=" * 60)
    for r in results:
        print(f"\n--- {r['name']} ---")
        print(r["output"])
    print(f"\n{'='*60}")
# ---------------------------------------------------------------------------
# Benchmark
# ---------------------------------------------------------------------------
# Registry of the available benchmarks.  Each entry carries a display label,
# a short description, the HuggingFace dataset as a (repo id, config name)
# pair, and the name of the metric that is reported.
BENCHMARKS = {
    "blimp": {
        "label": "BLiMP",
        "desc": "Grammaticality minimal pairs (67 paradigms). Accuracy = % grammatical < ungrammatical perplexity.",
        "hf_dataset": ("nyu-mll/blimp", None),
        "metric": "accuracy",
    },
    "wikitext2": {
        "label": "WikiText-2",
        "desc": "LM perplexity on Wikipedia test split. Lower is better.",
        "hf_dataset": ("Salesforce/wikitext", "wikitext-2-raw-v1"),
        "metric": "perplexity",
    },
    "arc_easy": {
        "label": "ARC-Easy",
        "desc": "Multiple-choice science QA (~2.4K). Perplexity-based answer selection.",
        "hf_dataset": ("allenai/ai2_arc", "ARC-Easy"),
        "metric": "accuracy",
    },
}
def _score_text(model: TinyMemoryLM, tokenizer: WordTokenizer, text: str, device: str) -> float:
    """Return the mean next-token negative log-likelihood of ``text``.

    The text is encoded with a BOS token and scored in a single forward
    pass.  ``nan`` is returned when fewer than two tokens are available.
    """
    token_ids = tokenizer.encode(text, add_bos=True, add_eos=False)
    if len(token_ids) < 2:
        return float("nan")
    batch = torch.tensor([token_ids], dtype=torch.long, device=device)
    with torch.no_grad():
        logits, *_ = model(batch)
    log_probs = F.log_softmax(logits[0], dim=-1)
    # Position t predicts token t + 1, so the targets are shifted by one.
    next_tokens = batch[0, 1:]
    picked = log_probs[: next_tokens.numel()].gather(1, next_tokens.unsqueeze(1)).squeeze(1)
    return -picked.mean().item()
def _score_completion(model: TinyMemoryLM, tokenizer: WordTokenizer, context: str, completion: str, device: str) -> float:
    """Return the mean negative log-likelihood of ``completion`` given ``context``.

    Context and completion are encoded together; only the tokens beyond the
    length of the separately encoded context are scored.  ``nan`` is returned
    when the completion contributes no tokens.
    """
    full_ids = tokenizer.encode(context + completion, add_bos=True, add_eos=False)
    n_ctx = len(tokenizer.encode(context, add_bos=True, add_eos=False))
    n_ref = len(full_ids) - n_ctx
    if n_ref <= 0:
        return float("nan")

    batch = torch.tensor([full_ids], dtype=torch.long, device=device)
    with torch.no_grad():
        logits, *_ = model(batch)
    log_probs = F.log_softmax(logits[0], dim=-1)

    # Position t predicts token t + 1, hence the window starts at n_ctx - 1.
    start = n_ctx - 1
    stop = min(start + n_ref, log_probs.shape[0])
    if start >= stop:
        return float("nan")
    span_targets = batch[0, 1:][start:stop]
    span_scores = log_probs[start:stop].gather(1, span_targets.unsqueeze(1)).squeeze(1)
    return -span_scores.mean().item()
# Names of the 67 BLiMP paradigms; each is passed as the dataset config name
# when loading "nyu-mll/blimp" in _run_blimp().
BLIMP_PARADIGMS = [
    "adjunct_island", "anaphor_gender_agreement", "anaphor_number_agreement",
    "animate_subject_passive", "animate_subject_trans", "causative",
    "complex_NP_island", "coordinate_structure_constraint_complex_left_branch",
    "coordinate_structure_constraint_object_extraction",
    "determiner_noun_agreement_1", "determiner_noun_agreement_2",
    "determiner_noun_agreement_irregular_1", "determiner_noun_agreement_irregular_2",
    "determiner_noun_agreement_with_adj_2", "determiner_noun_agreement_with_adj_irregular_1",
    "determiner_noun_agreement_with_adj_irregular_2", "determiner_noun_agreement_with_adjective_1",
    "distractor_agreement_relational_noun", "distractor_agreement_relative_clause",
    "drop_argument", "ellipsis_n_bar_1", "ellipsis_n_bar_2",
    "existential_there_object_raising", "existential_there_quantifiers_1",
    "existential_there_quantifiers_2", "existential_there_subject_raising",
    "expletive_it_object_raising", "inchoative", "intransitive",
    "irregular_past_participle_adjectives", "irregular_past_participle_verbs",
    "irregular_plural_subject_verb_agreement_1", "irregular_plural_subject_verb_agreement_2",
    "left_branch_island_echo_question", "left_branch_island_simple_question",
    "matrix_question_npi_licensor_present", "npi_present_1", "npi_present_2",
    "only_npi_licensor_present", "only_npi_scope", "passive_1", "passive_2",
    "principle_A_c_command", "principle_A_case_1", "principle_A_case_2",
    "principle_A_domain_1", "principle_A_domain_2", "principle_A_domain_3",
    "principle_A_reconstruction", "regular_plural_subject_verb_agreement_1",
    "regular_plural_subject_verb_agreement_2", "sentential_negation_npi_licensor_present",
    "sentential_negation_npi_scope", "sentential_subject_island",
    "superlative_quantifiers_1", "superlative_quantifiers_2",
    "tough_vs_raising_1", "tough_vs_raising_2", "transitive", "wh_island",
    "wh_questions_object_gap", "wh_questions_subject_gap",
    "wh_questions_subject_gap_long_distance", "wh_vs_that_no_gap",
    "wh_vs_that_no_gap_long_distance", "wh_vs_that_with_gap",
    "wh_vs_that_with_gap_long_distance",
]
def _run_blimp(model, tokenizer, device, n_samples: int = 200) -> Tuple[List[str], List[float]]:
    """Evaluate the model on every BLiMP paradigm.

    For each minimal pair the model is counted as correct when the
    grammatical sentence gets a lower mean NLL than the ungrammatical one.
    Pairs that cannot be scored (``nan``) are excluded from both the
    numerator and the denominator, so they no longer deflate the accuracy.

    Args:
        model: Model passed through to ``_score_text``.
        tokenizer: Tokenizer passed through to ``_score_text``.
        device: Device passed through to ``_score_text``.
        n_samples: Maximum number of pairs evaluated per paradigm.

    Returns:
        ``(BLIMP_PARADIGMS, accuracies)`` with one accuracy per paradigm;
        ``nan`` marks paradigms that failed to load or had no scorable pair.
    """
    from datasets import load_dataset  # type: ignore

    accuracies: List[float] = []
    for paradigm in BLIMP_PARADIGMS:
        try:
            ds = load_dataset("nyu-mll/blimp", paradigm, split="train")
        except Exception as e:
            # A paradigm that cannot be fetched is reported and skipped.
            print(f" {paradigm}: skip ({e})")
            accuracies.append(float("nan"))
            continue

        items = list(ds)[:n_samples]
        correct = 0
        scored = 0
        for ex in items:
            good_nll = _score_text(model, tokenizer, ex["sentence_good"], device)
            bad_nll = _score_text(model, tokenizer, ex["sentence_bad"], device)
            if math.isnan(good_nll) or math.isnan(bad_nll):
                continue
            scored += 1
            if good_nll < bad_nll:
                correct += 1

        acc = correct / scored if scored else float("nan")
        accuracies.append(acc)
        print(f" {paradigm:50s} acc={acc:.3f}")
    return BLIMP_PARADIGMS, accuracies
def _run_wikitext2(model, tokenizer, device, chunk_chars: int = 512, max_chunks: int = 100) -> Tuple[List[str], List[float]]:
    """Compute per-chunk perplexity on the WikiText-2 (raw) test split.

    Non-blank lines are joined with newlines, cut into ``chunk_chars``-sized
    character windows, windows of 20 characters or fewer are dropped, and at
    most ``max_chunks`` of the remainder are scored with ``_score_text``.
    Each perplexity is ``exp`` of that score (NaN stays NaN).

    Returns:
        ``(labels, ppls)`` where labels are ``"chunk 1"``, ``"chunk 2"``, ...
    """
    from datasets import load_dataset  # type: ignore

    dataset = load_dataset("Salesforce/wikitext", "wikitext-2-raw-v1", split="test")
    corpus = "\n".join(row["text"] for row in dataset if row["text"].strip())

    pieces: List[str] = []
    for start in range(0, len(corpus), chunk_chars):
        piece = corpus[start:start + chunk_chars]
        if len(piece) > 20:
            pieces.append(piece)
    pieces = pieces[:max_chunks]
    total = len(pieces)

    labels: List[str] = []
    ppls: List[float] = []
    for number, piece in enumerate(pieces, start=1):
        nll = _score_text(model, tokenizer, piece, device)
        if math.isnan(nll):
            ppl = float("nan")
        else:
            ppl = math.exp(nll)
        labels.append(f"chunk {number}")
        ppls.append(ppl)

        # Progress report every tenth chunk, averaged over the finite values.
        if number % 10 != 0:
            continue
        finite = [p for p in ppls if not math.isnan(p)]
        running = sum(finite) / len(finite) if finite else float("nan")
        print(f" chunk {number}/{total} running mean ppl={running:.2f}")
    return labels, ppls
def _run_arc_easy(model, tokenizer, device, max_samples: int = 200) -> Tuple[List[str], List[float]]:
    """Run multiple-choice ARC-Easy on the first ``max_samples`` test questions.

    Every answer option is scored with ``_score_completion`` as a completion
    of ``"Question: ...\\nAnswer:"``; the option with the lowest score is the
    prediction. A question scores 1.0 / 0.0 for right / wrong, or NaN when
    every option scored NaN.

    Returns:
        ``(labels, scores)`` with labels ``"Q1"``, ``"Q2"``, ...
    """
    from datasets import load_dataset  # type: ignore

    dataset = load_dataset("allenai/ai2_arc", "ARC-Easy", split="test")
    examples = list(dataset)[:max_samples]

    labels: List[str] = []
    scores: List[float] = []
    for number, example in enumerate(examples, start=1):
        question = example["question"]
        options = example["choices"]["text"]
        option_labels = example["choices"]["label"]
        gold = example["answerKey"]

        context = f"Question: {question}\nAnswer:"
        nlls = [
            _score_completion(model, tokenizer, context, f" {option}", device)
            for option in options
        ]

        if all(math.isnan(value) for value in nlls):
            scores.append(float("nan"))
        else:
            # NaN options can never win: rank them as +inf.
            ranked = [float("inf") if math.isnan(value) else value for value in nlls]
            winner = ranked.index(min(ranked))
            if option_labels[winner] == gold:
                scores.append(1.0)
            else:
                scores.append(0.0)
        labels.append(f"Q{number}")

    answered = [s for s in scores if not math.isnan(s)]
    n_valid = len(answered)
    acc = sum(answered) / n_valid if n_valid else float("nan")
    print(f" {n_valid} questions evaluated, accuracy={acc:.3f}")
    return labels, scores
def run_benchmark_mode() -> None:
    """Interactive terminal flow: pick a benchmark and models, run, plot.

    Steps:
      1. Ask which entry of ``BENCHMARKS`` to run (default: the first).
      2. List local checkpoints found under ``runs/`` next to this file plus
         every entry of ``HUGGINGFACE_MODELS``; accept a comma-separated
         selection or ``a`` for all.
      3. Load each selected model, run the benchmark, and keep the mean of
         its non-NaN values as the model's summary score.
      4. Save a bar chart to ``benchmark_<key>.png`` next to this file and
         try to open it with ``xdg-open``.

    Returns early (after printing a message) when matplotlib is missing, the
    input is invalid or interrupted, or no model produced a result.
    """
    try:
        import matplotlib
        matplotlib.use("Agg")
        import matplotlib.pyplot as plt
    except ImportError:
        print("matplotlib not installed. pip install matplotlib")
        return

    # ---- benchmark selection ----
    bench_keys = list(BENCHMARKS.keys())
    print("\nBenchmarks:")
    for i, k in enumerate(bench_keys):
        b = BENCHMARKS[k]
        print(f" [{i + 1}] {b['label']} — {b['desc']}")
    print("Select benchmark [1]:", end=" ", flush=True)
    try:
        b_choice = input().strip() or "1"
    except (EOFError, KeyboardInterrupt):
        print()
        return
    # isdecimal() (not isdigit()) so that every accepted string is parseable
    # by int(); e.g. "²" passes isdigit() but makes int() raise.
    if not (b_choice.isdecimal() and 1 <= int(b_choice) <= len(bench_keys)):
        print("Invalid selection.")
        return
    bench_key = bench_keys[int(b_choice) - 1]
    bench = BENCHMARKS[bench_key]
    print(f"Benchmark: {bench['label']}")

    # ---- model selection ----
    root = Path(__file__).resolve().parent
    runs_dir = root / "runs"
    all_models = discover_models(runs_dir)
    model_entries: List[dict] = []
    for m in all_models:
        model_entries.append({"label": f"[LOCAL] {m['name']}/{m['checkpoint']}", "type": "local", "meta": m})
    for hf_name, hf_id in HUGGINGFACE_MODELS.items():
        model_entries.append({"label": f"[HF] {hf_name}", "type": "hf", "hf_id": hf_id, "hf_name": hf_name})
    if not model_entries:
        print("No models found.")
        return

    print("\nAvailable models:")
    for i, e in enumerate(model_entries):
        print(f" [{i + 1}] {e['label']}")
    print(" [a] All models")
    print("Select models (comma-separated or 'a'):", end=" ", flush=True)
    try:
        raw = input().strip()
    except (EOFError, KeyboardInterrupt):
        print()
        return

    if raw.lower() == "a":
        selected = list(range(len(model_entries)))
    else:
        selected = []
        for tok in raw.split(","):
            tok = tok.strip()
            if tok.isdecimal() and 1 <= int(tok) <= len(model_entries):
                selected.append(int(tok) - 1)
    if not selected:
        print("No valid selection.")
        return

    # ---- run ----
    all_results: List[dict] = []
    shared_x_labels: Optional[List[str]] = None
    for idx in selected:
        entry = model_entries[idx]
        print(f"\n{'='*60}\nLoading {entry['label']}...")
        try:
            if entry["type"] == "local":
                m = entry["meta"]
                bundle = load_local_model(m["model_path"], m["tokenizer_path"], m["series"])
            else:
                bundle = load_huggingface_model(entry["hf_id"], root / ".hf_cache")
        except Exception as e:
            # A model that fails to load is skipped; the others still run.
            print(f" Failed: {e}")
            continue

        model = bundle["model"]
        tokenizer = bundle["tokenizer"]
        device = str(bundle["device"])
        model.eval()

        if bench_key == "blimp":
            x_labels, y_vals = _run_blimp(model, tokenizer, device)
        elif bench_key == "wikitext2":
            x_labels, y_vals = _run_wikitext2(model, tokenizer, device)
        else:
            x_labels, y_vals = _run_arc_easy(model, tokenizer, device)

        if shared_x_labels is None:
            shared_x_labels = x_labels
        valid = [v for v in y_vals if not math.isnan(v)]
        summary = sum(valid) / len(valid) if valid else float("nan")
        all_results.append({"label": entry["label"], "y": y_vals, "summary": summary})

    if not all_results or shared_x_labels is None:
        print("No results to plot.")
        return

    # ---- plot ----
    metric = bench["metric"]
    # Best model first: ascending for perplexity, descending otherwise.
    paired = sorted(zip([r["summary"] for r in all_results], [r["label"] for r in all_results]),
                    reverse=(metric != "perplexity"))
    summaries, model_labels = zip(*paired) if paired else ([], [])
    n = len(summaries)
    colors = [plt.cm.RdYlGn(i / max(n - 1, 1)) for i in range(n)]

    out_path = root / f"benchmark_{bench_key}.png"
    fig, ax = plt.subplots(figsize=(max(6, n * 1.4), 6))
    try:
        bars = ax.bar(range(n), summaries, color=colors, edgecolor="black")
        for bar, val in zip(bars, summaries):
            ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.005,
                    f"{val:.3f}", ha="center", va="bottom", fontsize=9, fontweight="bold")
        ylabel = "Mean Perplexity (↓ better)" if metric == "perplexity" else "Mean Accuracy (↑ better)"
        ax.set_ylabel(ylabel)
        ax.set_title(f"{bench['label']} Benchmark — Model Comparison")
        ax.set_xticks(range(n))
        ax.set_xticklabels(model_labels, rotation=20, ha="right", fontsize=9)
        if metric == "accuracy":
            ax.set_ylim(0, 1.05)
        ax.grid(True, axis="y", alpha=0.3)
        plt.tight_layout()
        plt.savefig(str(out_path), dpi=150)
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close(fig)
    print(f"\nChart saved to {out_path}")

    # Opening the image is a convenience only; xdg-open may not exist.
    try:
        import subprocess
        subprocess.Popen(["xdg-open", str(out_path)])
    except OSError:
        pass
# ---------------------------------------------------------------------------
# Interactive CLI
# ---------------------------------------------------------------------------
def _pick_series(detected: str) -> str:
    """Ask the user which model series to use and return its capitalized name.

    The keys of ``MODEL_SERIES`` are offered as a numbered menu. The entry
    equal to ``detected`` (case-insensitive) is marked and used as the default
    when the user just presses Enter; without a match the default is entry 1.
    When only one series exists it is returned without prompting.

    Exits the process with status 0 on EOF or Ctrl-C.
    """
    series_list = list(MODEL_SERIES.keys())
    detected_lower = detected.lower()
    default_idx = next(
        (i + 1 for i, s in enumerate(series_list) if s == detected_lower), 1
    )

    # Skip selection if only one series available
    if len(series_list) == 1:
        return series_list[0].capitalize()

    print("Series:")
    for i, s in enumerate(series_list):
        marker = " (detected)" if s == detected_lower else ""
        print(f" [{i + 1}] {s.capitalize()}{marker}")

    while True:
        try:
            choice = input(f"Select series [{default_idx}]: ").strip()
        except (EOFError, KeyboardInterrupt):
            print()
            sys.exit(0)
        if not choice:
            choice = str(default_idx)
        # isdecimal() guarantees int() can parse the text; isdigit() also
        # accepts characters such as "²" that make int() raise ValueError.
        if choice.isdecimal() and 1 <= int(choice) <= len(series_list):
            return series_list[int(choice) - 1].capitalize()
        print(f"Enter a number 1-{len(series_list)}")
def pick_model(runs_dir: Path) -> tuple[dict, str]:
    """Choose a local checkpoint under ``runs_dir`` and load it.

    Models come from ``discover_models(runs_dir)``. A single model is loaded
    directly; otherwise a numbered menu is shown (Enter selects entry 1).

    Returns:
        ``(bundle, checkpoint)`` where ``bundle`` is the result of
        ``load_local_model`` and ``checkpoint`` is the chosen entry's
        ``"checkpoint"`` value.

    Exits with status 1 when no models are found and with status 0 on EOF or
    Ctrl-C at the prompt.
    """
    models = discover_models(runs_dir)
    if not models:
        print(f"No models found in {runs_dir}")
        print("Expected layout: runs/<name>/model.pt (or pretrain.pt) + tokenizer.json")
        sys.exit(1)

    if len(models) == 1:
        chosen = models[0]
    else:
        print("Available models:")
        for i, m in enumerate(models):
            print(f" [{i + 1}] {m['name']}/{m['checkpoint']} ({m['series']})")
        while True:
            try:
                choice = input("Select model [1]: ").strip()
            except (EOFError, KeyboardInterrupt):
                print()
                sys.exit(0)
            if not choice:
                choice = "1"
            # isdecimal() keeps int() from raising on input such as "²",
            # which str.isdigit() would have let through.
            if choice.isdecimal() and 1 <= int(choice) <= len(models):
                chosen = models[int(choice) - 1]
                break
            print(f"Enter a number 1-{len(models)}")

    # Single load path shared by the one-model and the menu case.
    print(f"Loading {chosen['name']}/{chosen['checkpoint']} ({chosen['series']})...")
    bundle = load_local_model(chosen["model_path"], chosen["tokenizer_path"], chosen["series"])
    return bundle, chosen["checkpoint"]
# ---------------------------------------------------------------------------
# Generation mode configs
# ---------------------------------------------------------------------------
# Generation presets keyed by mode id. Keys containing "pretrain" are the ones
# offered for pretrain checkpoints (see pick_mode). "label" and "desc" are
# shown to the user; the remaining entries are passed to the generation call.
# The labels previously contained a mis-encoded em dash ("β€”").
MODES = {
    "chat-coherent": {
        "label": "Chat — Coherent",
        "desc": "structured, consistent, strong repetition control",
        "sft_mode": "chat",
        "temperature": 0.35,
        "top_k": 20,
        "top_p": 0.88,
        "min_p": 0.10,
        "no_repeat_ngram_size": 4,
        "repetition_penalty": 1.22,
        "logit_soft_cap": 20.0,
        "loop_penalty": 20.0,
        "max_new_tokens": 4096,
        "context_window": 2048,
    },
    "chat-variants": {
        "label": "Chat — Variants",
        "desc": "creative, diverse, more surprising outputs",
        "sft_mode": "chat",
        "temperature": 0.65,
        "top_k": 60,
        "top_p": 0.92,
        "min_p": 0.05,
        "no_repeat_ngram_size": 4,
        "repetition_penalty": 1.12,
        "logit_soft_cap": 20.0,
        "loop_penalty": 14.0,
        "max_new_tokens": 4096,
        "context_window": 2048,
    },
    "pretrain-coherent": {
        "label": "Pretrain — Coherent",
        "desc": "grounded continuation, low temperature, tight sampling",
        "sft_mode": False,
        "temperature": 0.3,
        "top_k": 20,
        "top_p": 0.85,
        "min_p": 0.10,
        "no_repeat_ngram_size": 4,
        "repetition_penalty": 1.2,
        "logit_soft_cap": 20.0,
        "loop_penalty": 20.0,
        "max_new_tokens": 4096,
        "context_window": 2048,
    },
    "pretrain-variants": {
        "label": "Pretrain — Variants",
        "desc": "free-form continuation, higher temperature, more exploration",
        "sft_mode": False,
        "temperature": 0.7,
        "top_k": 60,
        "top_p": 0.93,
        "min_p": 0.04,
        "no_repeat_ngram_size": 4,
        "repetition_penalty": 1.12,
        "logit_soft_cap": 20.0,
        "loop_penalty": 12.0,
        "max_new_tokens": 4096,
        "context_window": 2048,
    },
}
# Mode ids in definition order.
_MODE_LIST = list(MODES.keys())
def pick_mode(is_pretrain: bool) -> dict:
    """Prompt the user to choose a generation mode and return its config.

    Only the presets whose key contains ``"pretrain"`` are offered when
    ``is_pretrain`` is true, and only the others when it is false. Pressing
    Enter selects the first offered entry.

    Returns:
        The chosen entry of ``MODES`` (the same dict object, not a copy).

    Exits the process with status 0 on EOF or Ctrl-C.
    """
    # Filter to relevant modes based on checkpoint type
    candidates = [k for k in _MODE_LIST if ("pretrain" in k) == is_pretrain]

    print("\nGeneration mode:")
    for i, key in enumerate(candidates):
        cfg = MODES[key]
        print(f" [{i + 1}] {cfg['label']} — {cfg['desc']}")

    while True:
        try:
            choice = input("Select mode [1]: ").strip()
        except (EOFError, KeyboardInterrupt):
            print()
            sys.exit(0)
        if not choice:
            choice = "1"
        # isdecimal() rather than isdigit(): only strings int() can parse.
        if choice.isdecimal() and 1 <= int(choice) <= len(candidates):
            key = candidates[int(choice) - 1]
            cfg = MODES[key]
            print(f"Mode: {cfg['label']}")
            return cfg
        print(f"Enter a number 1-{len(candidates)}")
def _run_loop(bundle: dict, cfg: dict) -> None:
    """Run the interactive prompt/response loop in the terminal.

    Args:
        bundle: dict providing ``"model"``, ``"tokenizer"`` and ``"device"``.
        cfg: a generation preset (an entry of ``MODES`` or a dict with the
            same keys); its sampling values are forwarded to ``generate``.

    Commands: ``/quit``, ``/exit``, ``/q`` leave the loop, ``/help`` lists the
    commands, ``/mode`` shows the active preset. Empty input is ignored; any
    other text is sent to ``generate`` with streaming enabled. EOF or Ctrl-C
    at the prompt also ends the loop.
    """
    model = bundle["model"]
    tokenizer = bundle["tokenizer"]
    device = bundle["device"]
    sft = cfg["sft_mode"]
    prompt_label = "You" if sft else "Prompt"

    print(f"\nModel ready on {device}. Type your message, or /quit to exit.")
    print(f" temp={cfg['temperature']} top_k={cfg['top_k']} top_p={cfg['top_p']}")
    print(f" min_p={cfg['min_p']} ng={cfg['no_repeat_ngram_size']} rp={cfg['repetition_penalty']}")
    print(f" cap={cfg['logit_soft_cap']} loop_penalty={cfg['loop_penalty']}\n")

    while True:
        try:
            prompt = input(f"{prompt_label}: ").strip()
        except (EOFError, KeyboardInterrupt):
            print()
            break
        if not prompt:
            continue
        if prompt in ("/quit", "/exit", "/q"):
            break
        if prompt == "/help":
            print("Commands: /quit /exit /q /help /mode")
            if sft:
                print("Anything else is sent as a chat prompt.")
            else:
                print("Anything else is sent as a raw continuation prompt.")
            continue
        if prompt == "/mode":
            # The separator used to be a mis-encoded em dash ("β€”").
            print(f"Current: {cfg['label']} — {cfg['desc']}")
            continue

        print("AI: ", end="", flush=True)
        generate(
            model=model,
            tokenizer=tokenizer,
            prompt=prompt,
            max_new_tokens=cfg["max_new_tokens"],
            temperature=cfg["temperature"],
            top_k=cfg["top_k"],
            top_p=cfg["top_p"],
            min_p=cfg["min_p"],
            no_repeat_ngram_size=cfg["no_repeat_ngram_size"],
            repetition_penalty=cfg["repetition_penalty"],
            logit_soft_cap=cfg["logit_soft_cap"],
            loop_penalty=cfg["loop_penalty"],
            device=str(device),
            sft_mode=cfg["sft_mode"],
            stream=True,
            context_window=cfg["context_window"],
        )
# ---------------------------------------------------------------------------
# Dynamic collection discovery
# ---------------------------------------------------------------------------
# Hub account that owns the models, and the name fragment used to find them.
_AUTHOR = "CompactAI-O"
_SEARCH = "TMLM-Haiku"
# Slug of the Hugging Face collection this tool is built around.
_COLLECTION_SLUG = f"{_AUTHOR}/tmlm-haiku-series"
# Repos probed in addition to the search results.
_EXTRA_REPOS = [f"{_AUTHOR}/Glint-1"]
# Static list used when the Hub cannot be queried.
_FALLBACK_COLLECTION = [
    {"version": _name, "hf_id": f"{_AUTHOR}/{_name}"}
    for _name in (
        "TMLM-Haiku-2.3",
        "TMLM-Haiku-2",
        "TMLM-Haiku-1.3",
        "TMLM-Haiku-1",
        "Glint-1",
    )
]
def _probe_repo(hf_id: str) -> dict | None:
    """Describe the ``.pt`` checkpoints of one Hub repo.

    Lists the repo's files and looks for a ``models/`` or ``model/`` folder
    (in that order); when neither exists the whole repo is searched. Every
    ``.pt`` path under that location becomes a
    ``(label, filename, is_pretrain)`` tuple.

    Returns:
        A dict with ``version``, ``hf_id``, ``subdir``, ``checkpoints`` and
        ``desc``, or ``None`` when the file listing fails or no ``.pt`` file
        is found.
    """
    from huggingface_hub import list_repo_files

    try:
        repo_files = set(list_repo_files(hf_id))
    except Exception:
        return None

    # First candidate folder that actually contains something.
    subdir: str | None = next(
        (
            folder
            for folder in ("models", "model")
            if any(path.startswith(f"{folder}/") for path in repo_files)
        ),
        None,
    )
    prefix = "" if not subdir else f"{subdir}/"

    # Friendly labels for the well-known checkpoint names.
    known = {
        "model.pt": ("Chat (SFT)", False),
        "model_rep.pt": ("Chat (anti-repetition)", False),
        "pretrain.pt": ("Pretrain (base)", True),
    }

    checkpoints = []
    for path in sorted(repo_files):
        if not (path.startswith(prefix) and path.endswith(".pt")):
            continue
        fname = path[len(prefix):]
        if fname in known:
            label, is_pretrain = known[fname]
        else:
            # Unknown file: label is the stem, type is guessed from the name.
            label = fname.removesuffix(".pt")
            is_pretrain = "pretrain" in fname
        checkpoints.append((label, fname, is_pretrain))

    if not checkpoints:
        return None

    return {
        "version": hf_id.split("/")[-1],
        "hf_id": hf_id,
        "subdir": subdir,
        "checkpoints": checkpoints,
        "desc": "",
    }
def fetch_collection() -> list[dict]:
    """Query HF for all CompactAI-O TMLM-Haiku models, newest first.

    Models of ``_AUTHOR`` matching ``_SEARCH`` are listed through ``HfApi``;
    each repo whose id contains ``_SEARCH`` is probed with ``_probe_repo``.
    Repos in ``_EXTRA_REPOS`` are probed as well, and if nothing usable was
    found the ids in ``_FALLBACK_COLLECTION`` are probed as a last resort.

    Returns:
        The list of entries produced by ``_probe_repo`` (possibly empty).
    """
    from types import SimpleNamespace

    from huggingface_hub import HfApi

    def _modified_stamp(info) -> str:
        # The timestamp attribute may be absent or None depending on the
        # huggingface_hub version and listing options, so normalize to a
        # string: sorting then never compares None with a datetime.
        stamp = getattr(info, "last_modified", None) or getattr(info, "lastModified", None)
        return "" if stamp is None else str(stamp)

    print("Checking HuggingFace collection for available models...")
    try:
        infos = list(
            HfApi().list_models(
                author=_AUTHOR,
                search=_SEARCH,
                sort="lastModified",
            )
        )
    except Exception as exc:
        print(f" Could not reach HuggingFace ({exc}); using fallback list.")
        infos = [SimpleNamespace(id=e["hf_id"]) for e in _FALLBACK_COLLECTION]
    else:
        # Kept outside the try so a sorting problem can no longer be
        # misreported as a network failure. The sort is stable, so entries
        # without a timestamp keep the order returned by the API.
        infos.sort(key=_modified_stamp, reverse=True)

    entries = []
    seen_ids: set = set()
    for info in infos:
        repo_id = info.id
        if _SEARCH.lower() not in repo_id.lower():
            continue
        entry = _probe_repo(repo_id)
        if entry:
            entries.append(entry)
            seen_ids.add(repo_id)

    # Always include extra repos (e.g. Glint-1) not caught by TMLM-Haiku search
    for repo_id in _EXTRA_REPOS:
        if repo_id in seen_ids:
            continue
        entry = _probe_repo(repo_id)
        if entry:
            entries.append(entry)
            seen_ids.add(repo_id)

    if not entries:
        print(" No models found; using fallback list.")
        for fb in _FALLBACK_COLLECTION:
            e = _probe_repo(fb["hf_id"])
            if e:
                entries.append(e)
    return entries
# ---------------------------------------------------------------------------
# Download helper
# ---------------------------------------------------------------------------
def _download_version(entry: dict, cache_dir: Path) -> Path:
    """Fetch the repo snapshot for ``entry`` and return its model directory.

    The directory is ``<snapshot>/<entry["subdir"]>`` when a subdir is set and
    exists on disk, otherwise the snapshot root.

    Prints a message and exits with status 1 when ``huggingface_hub`` is not
    installed or the download fails.
    """
    try:
        from huggingface_hub import snapshot_download
    except ImportError:
        print("huggingface_hub not installed. Run: pip install huggingface_hub")
        sys.exit(1)

    repo = entry["hf_id"]
    print(f"Fetching {repo} ...")
    try:
        snapshot_root = Path(snapshot_download(repo_id=repo, cache_dir=str(cache_dir)))
    except Exception as exc:
        print(f"Download failed: {exc}")
        sys.exit(1)

    wanted = entry.get("subdir")
    if wanted:
        nested = snapshot_root / wanted
        if nested.exists():
            return nested
    # No subdir configured, or it is missing from the snapshot: use the root.
    return snapshot_root
# ---------------------------------------------------------------------------
# Selection prompts
# ---------------------------------------------------------------------------
def _prompt_int(prompt: str, lo: int, hi: int, default: int = 1) -> int:
    """Ask for an integer in ``[lo, hi]`` until a valid one is entered.

    Args:
        prompt: text shown before the ``[default]`` hint.
        lo, hi: inclusive bounds for typed values.
        default: value returned as-is when the user just presses Enter.

    Exits the process with status 0 on EOF or Ctrl-C.
    """
    while True:
        try:
            raw = input(f"{prompt} [{default}]: ").strip()
        except (EOFError, KeyboardInterrupt):
            print()
            sys.exit(0)
        if not raw:
            return default
        # str.isdigit() also accepts characters such as "²" that int()
        # rejects with ValueError; isdecimal() matches what int() can parse.
        if raw.isdecimal() and lo <= int(raw) <= hi:
            return int(raw)
        print(f" Enter a number {lo}–{hi}.")
def pick_version(collection: list[dict]) -> dict:
    """Show the available versions and return the entry the user selects.

    Args:
        collection: entries with at least ``"version"`` and ``"desc"`` keys.

    Returns:
        The selected element of ``collection``.

    Exits with status 1 when ``collection`` is empty: there is nothing to
    choose from, and the prompt could otherwise only fail or loop forever.
    """
    if not collection:
        print("No model versions available.")
        sys.exit(1)

    print("\nTMLM-Haiku series (CompactAI-O)\n")
    for i, entry in enumerate(collection):
        # The separator used to be a mis-encoded em dash ("β€”").
        desc = f" — {entry['desc']}" if entry["desc"] else ""
        print(f" [{i + 1}] {entry['version']}{desc}")
    idx = _prompt_int("Select version", 1, len(collection))
    return collection[idx - 1]
def pick_checkpoint(entry: dict) -> tuple[str, bool]:
    """Return ``(filename, is_pretrain)`` for the checkpoint to use.

    A repo with exactly one checkpoint is used without asking; otherwise the
    user picks from a numbered menu via ``_prompt_int``.
    """
    options = entry["checkpoints"]

    if len(options) != 1:
        print(f"\nCheckpoints for {entry['version']}:")
        for number, option in enumerate(options, start=1):
            print(f" [{number}] {option[0]} ({option[1]})")
        picked = _prompt_int("Select checkpoint", 1, len(options))
        _, fname, is_pretrain = options[picked - 1]
        return fname, is_pretrain

    label, fname, is_pretrain = options[0]
    print(f" Using: {label} ({fname})")
    return fname, is_pretrain
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
# ---------------------------------------------------------------------------
# Gradio Space
# ---------------------------------------------------------------------------
import gradio as gr
# Directory for downloaded Hugging Face snapshots, located next to this file.
_CACHE_DIR = Path(__file__).parent / ".hf_cache"
# Created at import time; no error if it already exists.
_CACHE_DIR.mkdir(parents=True, exist_ok=True)
# Collection entries, filled on first use by _get_collection().
_collection_cache: list = []
# Loaded model bundles keyed by "<version>/<checkpoint label>" (see _load_bundle).
_model_cache: dict = {}
def _get_collection() -> list:
    """Return the list of collection entries, fetching it on first use.

    The result of ``fetch_collection()`` is stored in ``_collection_cache``
    and reused while it is non-empty. If fetching raises, an entry is built
    for every item of ``_FALLBACK_COLLECTION``: the probed entry when
    ``_probe_repo`` succeeds, otherwise a default entry that assumes a single
    ``model.pt`` chat checkpoint at the repo root.
    """
    global _collection_cache
    if _collection_cache:
        return _collection_cache

    try:
        _collection_cache = fetch_collection()
    except Exception as exc:
        print(f"Warning: fetch_collection failed ({exc}); using fallback.")
        fallback = []
        for fb in _FALLBACK_COLLECTION:
            # Probing can fail for the same reason the fetch did (e.g. the
            # hub client cannot be imported); the fallback path must not
            # raise again, so fall through to the default entry.
            try:
                probed = _probe_repo(fb["hf_id"])
            except Exception:
                probed = None
            fallback.append(probed or {
                "version": fb["version"],
                "hf_id": fb["hf_id"],
                "subdir": None,
                "checkpoints": [("Chat (SFT)", "model.pt", False)],
                "desc": "",
            })
        _collection_cache = fallback
    return _collection_cache
def _collection_versions() -> list[str]:
    """List the ``version`` of every collection entry, in collection order."""
    names: list[str] = []
    for entry in _get_collection():
        names.append(entry["version"])
    return names
def _checkpoints_for(version: str) -> list[tuple[str, str, bool]]:
    """Return the checkpoint tuples of the first entry named ``version``.

    An empty list is returned when no entry has that version.
    """
    matches = (
        entry["checkpoints"]
        for entry in _get_collection()
        if entry["version"] == version
    )
    return next(matches, [])
def _ckpt_labels(version: str) -> list[str]:
    """Return the display label of each checkpoint of ``version``."""
    labels: list[str] = []
    for label, _fname, _is_pretrain in _checkpoints_for(version):
        labels.append(label)
    return labels
def _ckpt_is_pretrain(version: str, label: str) -> bool:
    """Return the pretrain flag of the checkpoint labelled ``label``.

    ``False`` is returned when the version has no checkpoint with that label.
    """
    flags = (
        is_pretrain
        for candidate, _fname, is_pretrain in _checkpoints_for(version)
        if candidate == label
    )
    return next(flags, False)
def _ckpt_fname(version: str, label: str) -> str:
    """Return the file name of the checkpoint labelled ``label``.

    Falls back to ``"model.pt"`` when the version has no such label.
    """
    names = (
        fname
        for candidate, fname, _is_pretrain in _checkpoints_for(version)
        if candidate == label
    )
    return next(names, "model.pt")
def _load_bundle(version: str, ckpt_label: str) -> dict:
    """Return the loaded model bundle for a version/checkpoint pair.

    Bundles are cached in ``_model_cache`` under ``"<version>/<label>"``. On a
    cache miss the repo is downloaded with ``_download_version`` and the
    checkpoint plus ``tokenizer.json`` from the model directory are loaded
    through ``load_local_model`` (series ``"Haiku"``).

    Raises:
        KeyError: ``version`` is not part of the collection. The exception
            type is the same as before, but the message now names the
            version instead of echoing the internal cache key.
    """
    key = f"{version}/{ckpt_label}"
    if key in _model_cache:
        return _model_cache[key]

    entry = next((e for e in _get_collection() if e["version"] == version), None)
    if entry is None:
        raise KeyError(f"unknown model version: {version!r}")

    fname = _ckpt_fname(version, ckpt_label)
    model_dir = _download_version(entry, _CACHE_DIR)
    model_path = model_dir / fname
    tokenizer_path = model_dir / "tokenizer.json"
    _model_cache[key] = load_local_model(model_path, tokenizer_path, "Haiku")
    return _model_cache[key]
def _build_conversation_prompt(history: list[dict], new_message: str) -> str:
    """Flatten Gradio messages-format history + new turn into a raw prompt.

    Each adjacent ``user`` -> ``assistant`` pair in ``history`` is rendered as
    ``<|user|>\\n...\\n<|assistant|>\\n...``; the new message is appended as a
    final user turn with an open assistant tag. Pieces are concatenated
    without any extra separator.

    Messages that are not part of such a pair (for example a leading
    assistant greeting or two user messages in a row) are skipped one at a
    time, so a single stray message no longer shifts the pairing and drops
    every turn that follows it.

    Args:
        history: ``[{"role": ..., "content": ...}, ...]`` in chronological
            order, excluding the turn being generated.
        new_message: the user's current message.
    """
    parts = []
    i = 0
    while i < len(history) - 1:
        u = history[i]
        a = history[i + 1]
        if u["role"] == "user" and a["role"] == "assistant":
            parts.append(f"<|user|>\n{u['content']}\n<|assistant|>\n{a['content']}")
            i += 2
        else:
            # Re-align on the next message instead of jumping a whole pair.
            i += 1
    parts.append(f"<|user|>\n{new_message}\n<|assistant|>\n")
    return "".join(parts)
# ---- chat ----
def _on_version_change(version):
    """Build the checkpoint-selector update for a newly chosen version.

    The choices become that version's checkpoint labels and the first label
    (or ``None`` when there are none) is preselected.
    """
    options = _ckpt_labels(version)
    if options:
        preselected = options[0]
    else:
        preselected = None
    return gr.update(choices=options, value=preselected)
def _chat_submit(message, history):
    """Queue a user message plus an empty assistant reply on the history.

    A non-empty ``history`` list is extended in place; ``None`` or an empty
    list is replaced by a fresh list.

    Returns:
        ``("", history)`` — the empty string clears the input box.
    """
    if not history:
        history = []
    history.extend(
        [
            {"role": "user", "content": message},
            {"role": "assistant", "content": ""},
        ]
    )
    return "", history
def _chat_stream(history, version, ckpt_label, mode_key, use_custom,
                 temperature, top_k, top_p, min_p, rep_penalty,
                 ngram_size, soft_cap, loop_pen, max_tokens, ctx_win, raw_mode):
    """Stream the assistant reply for the last turn of ``history``.

    Generator used as a Gradio event handler. ``history`` must end with the
    user message followed by an empty assistant message (as produced by
    ``_chat_submit``); anything else is yielded back unchanged. On every
    partial result the last message's ``content`` is overwritten and the
    (mutated) history is yielded.

    Sampling settings come from ``MODES[mode_key]`` unless ``use_custom`` is
    set, in which case the slider values are used. ``max_tokens`` and
    ``ctx_win`` always come from the sliders. With earlier turns present the
    prompt is pre-formatted by ``_build_conversation_prompt`` and passed with
    ``sft_mode=False``.
    """
    # Need at least the pending user message and the assistant placeholder;
    # a single-item history would make history[-2] raise IndexError.
    if not history or len(history) < 2 or history[-1]["role"] != "assistant":
        yield history
        return

    try:
        bundle = _load_bundle(version, ckpt_label)
    except Exception as e:
        history[-1]["content"] = f"[Error loading model: {e}]"
        yield history
        return

    prior_msgs = history[:-2]  # exclude the current user+empty-assistant pair
    new_msg = history[-2]["content"]

    if use_custom:
        cfg = {
            "sft_mode": not raw_mode,
            "temperature": temperature,
            # Count-like slider values may arrive as floats; coerce them the
            # same way the preset branch coerces max_tokens / ctx_win.
            "top_k": int(top_k),
            "top_p": top_p,
            "min_p": min_p,
            "repetition_penalty": rep_penalty,
            "no_repeat_ngram_size": int(ngram_size),
            "logit_soft_cap": soft_cap,
            "loop_penalty": loop_pen,
        }
    else:
        cfg = dict(MODES[mode_key])
    # Max new tokens slider always applies (independent of preset override)
    cfg["max_new_tokens"] = int(max_tokens)
    cfg["context_window"] = int(ctx_win)

    if prior_msgs:
        prompt = _build_conversation_prompt(prior_msgs, new_msg)
        sft = False
    else:
        prompt = new_msg
        sft = cfg["sft_mode"]

    for partial in generate_stream(
        model=bundle["model"], tokenizer=bundle["tokenizer"],
        prompt=prompt, device=str(bundle["device"]),
        sft_mode=sft,
        temperature=cfg["temperature"], top_k=cfg["top_k"],
        top_p=cfg["top_p"], min_p=cfg["min_p"],
        repetition_penalty=cfg["repetition_penalty"],
        no_repeat_ngram_size=cfg["no_repeat_ngram_size"],
        logit_soft_cap=cfg["logit_soft_cap"],
        loop_penalty=cfg["loop_penalty"],
        max_new_tokens=cfg["max_new_tokens"],
        context_window=cfg["context_window"],
    ):
        history[-1]["content"] = partial
        yield history
# ---- compare ----
def _compare_fn(prompt, selected_versions, mode_key, use_custom,
                temperature, top_k, top_p, min_p, rep_penalty,
                ngram_size, soft_cap, loop_pen, max_tokens, ctx_win, raw_mode):
    """Stream the same prompt through every selected version, one by one.

    Generator used as a Gradio event handler. Each yielded value is a list
    with one string per known version (ordered by ``_sort_oldest_to_newest``):
    status text or the partial output for selected versions, an empty string
    for the rest. The first checkpoint of each version is used. Load and
    generation errors are reported in that version's slot and do not stop
    the remaining versions.

    Sampling settings come from ``MODES[mode_key]`` unless ``use_custom`` is
    set; ``max_tokens`` and ``ctx_win`` always come from the sliders.
    """
    if use_custom:
        cfg = {
            "sft_mode": not raw_mode,
            "temperature": temperature,
            # Count-like slider values may arrive as floats; coerce them the
            # same way the preset branch coerces max_tokens / ctx_win.
            "top_k": int(top_k),
            "top_p": top_p,
            "min_p": min_p,
            "repetition_penalty": rep_penalty,
            "no_repeat_ngram_size": int(ngram_size),
            "logit_soft_cap": soft_cap,
            "loop_penalty": loop_pen,
        }
    else:
        cfg = dict(MODES[mode_key])
    # Max new tokens slider always applies (independent of preset override)
    cfg["max_new_tokens"] = int(max_tokens)
    cfg["context_window"] = int(ctx_win)

    # Iterate + emit oldest → newest (Haiku-1 first, Glint-1 last) so the order
    # matches the output-box layout in the UI.
    all_versions = _sort_oldest_to_newest(_collection_versions())
    selected = set(selected_versions or [])
    state = {v: ("⏳ Queued…" if v in selected else "") for v in all_versions}

    def _emit():
        return [state[v] for v in all_versions]

    yield _emit()

    for version in all_versions:
        if version not in selected:
            continue

        labels = _ckpt_labels(version)
        ckpt_label = labels[0] if labels else None
        if not ckpt_label:
            state[version] = "[No checkpoint found]"
            yield _emit()
            continue

        state[version] = "⏳ Loading…"
        yield _emit()
        try:
            bundle = _load_bundle(version, ckpt_label)
        except Exception as e:
            state[version] = f"[Load error: {e}]"
            yield _emit()
            continue

        # Clear the status text before the first tokens arrive.
        state[version] = ""
        yield _emit()
        try:
            for partial in generate_stream(
                model=bundle["model"], tokenizer=bundle["tokenizer"],
                prompt=prompt, device=str(bundle["device"]),
                sft_mode=cfg["sft_mode"],
                temperature=cfg["temperature"], top_k=cfg["top_k"],
                top_p=cfg["top_p"], min_p=cfg["min_p"],
                repetition_penalty=cfg["repetition_penalty"],
                no_repeat_ngram_size=cfg["no_repeat_ngram_size"],
                logit_soft_cap=cfg["logit_soft_cap"],
                loop_penalty=cfg["loop_penalty"],
                max_new_tokens=cfg["max_new_tokens"],
                context_window=cfg["context_window"],
            ):
                state[version] = partial
                yield _emit()
        except Exception as e:
            state[version] = f"[Generation error: {e}]"
            yield _emit()
# ---- benchmark ----
def _benchmark_fn(bench_key, selected_versions, max_samples,
                  progress=gr.Progress(track_tqdm=True)):
    """Run one benchmark over the selected versions and chart the results.

    The first checkpoint of each version is loaded and evaluated; a model's
    summary is the mean of its non-NaN per-item values. Errors for one
    version are written to the log and do not stop the others.

    Args:
        bench_key: key into ``BENCHMARKS`` (``"blimp"``, ``"wikitext2"``;
            any other key runs ARC-Easy).
        selected_versions: version names to evaluate.
        max_samples: per-benchmark sample/chunk limit (coerced to ``int``).
        progress: Gradio progress tracker.

    Returns:
        ``(log_text, image_path)``; ``image_path`` is ``None`` when nothing
        was selected or no model produced a result.
    """
    import tempfile

    import matplotlib
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt

    if not selected_versions:
        return "No models selected.", None

    # UI sliders may deliver floats; the runners slice with this value.
    sample_cap = int(max_samples)

    bench = BENCHMARKS[bench_key]
    all_results = []
    log_lines = [f"Benchmark: {bench['label']}", ""]

    for version in progress.tqdm(selected_versions, desc="Benchmarking"):
        log_lines.append(f"--- {version} ---")
        labels = _ckpt_labels(version)
        ckpt_label = labels[0] if labels else None
        if not ckpt_label:
            log_lines.append(" (no checkpoint)")
            continue
        try:
            bundle = _load_bundle(version, ckpt_label)
            model, tokenizer, device = bundle["model"], bundle["tokenizer"], str(bundle["device"])
            model.eval()
            if bench_key == "blimp":
                _, y = _run_blimp(model, tokenizer, device, n_samples=sample_cap)
            elif bench_key == "wikitext2":
                _, y = _run_wikitext2(model, tokenizer, device, max_chunks=sample_cap)
            else:
                _, y = _run_arc_easy(model, tokenizer, device, max_samples=sample_cap)
            valid = [v for v in y if not math.isnan(v)]
            summary = sum(valid) / len(valid) if valid else float("nan")
            all_results.append({"label": version, "summary": summary})
            log_lines.append(f" score: {summary:.4f}")
        except Exception as e:
            log_lines.append(f" error: {e}")

    if not all_results:
        return "\n".join(log_lines), None

    metric = bench["metric"]
    # Best model first: ascending for perplexity, descending otherwise.
    paired = sorted(
        zip([r["summary"] for r in all_results], [r["label"] for r in all_results]),
        reverse=(metric != "perplexity"),
    )
    summaries, labels_ = zip(*paired)
    n = len(summaries)
    colors = [plt.cm.RdYlGn(i / max(n - 1, 1)) for i in range(n)]

    fig, ax = plt.subplots(figsize=(max(6, n * 1.6), 5))
    try:
        bars = ax.bar(range(n), summaries, color=colors, edgecolor="black")
        for bar, val in zip(bars, summaries):
            ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.005,
                    f"{val:.3f}", ha="center", va="bottom", fontsize=9, fontweight="bold")
        ylabel = "Mean Perplexity (↓ better)" if metric == "perplexity" else "Mean Accuracy (↑ better)"
        ax.set_ylabel(ylabel)
        ax.set_title(f"{bench['label']} — Model Comparison")
        ax.set_xticks(range(n))
        ax.set_xticklabels(labels_, rotation=20, ha="right", fontsize=9)
        if metric == "accuracy":
            ax.set_ylim(0, 1.05)
        ax.grid(True, axis="y", alpha=0.3)
        fig.tight_layout()

        # A unique file in the platform temp directory: portable, and two
        # sessions finishing at the same time cannot overwrite each other.
        fd, out_path = tempfile.mkstemp(prefix="benchmark_", suffix=".png")
        os.close(fd)
        fig.savefig(out_path, dpi=150)
    finally:
        # Always release the figure, even when drawing or saving fails.
        plt.close(fig)

    log_lines += ["", "Done."]
    return "\n".join(log_lines), out_path
# ---- shared advanced params ----
def _advanced_block():
    """Create the collapsed "Advanced parameters" accordion.

    Two checkboxes are followed by five rows of two sliders each. Must be
    called inside an active Gradio layout context.

    Returns:
        The created components as a tuple, in this order: override checkbox,
        temperature, top-k, top-p, min-p, repetition penalty, no-repeat
        n-gram size, logit soft cap, loop penalty, max new tokens, context
        window, raw-mode checkbox.
    """
    with gr.Accordion("Advanced parameters", open=False):
        override_box = gr.Checkbox(label="Override preset with custom values below", value=False)
        raw_box = gr.Checkbox(label="Raw / pretrain mode (no <|user|> wrapping)", value=False)

        with gr.Row():
            temp_slider = gr.Slider(0.0, 2.0, value=0.5, step=0.01, label="Temperature")
            top_k_slider = gr.Slider(0, 200, value=20, step=1, label="Top-k")

        with gr.Row():
            top_p_slider = gr.Slider(0.0, 1.0, value=0.9, step=0.01, label="Top-p")
            min_p_slider = gr.Slider(0.0, 0.5, value=0.05, step=0.01, label="Min-p")

        with gr.Row():
            rep_slider = gr.Slider(1.0, 2.0, value=1.15, step=0.01, label="Repetition penalty")
            ngram_slider = gr.Slider(0, 8, value=4, step=1, label="No-repeat n-gram size")

        with gr.Row():
            cap_slider = gr.Slider(0.0, 50.0, value=20.0, step=0.5, label="Logit soft cap")
            loop_slider = gr.Slider(0.0, 50.0, value=15.0, step=0.5, label="Loop penalty")

        with gr.Row():
            max_tokens_slider = gr.Slider(16, 4096, value=512, step=16, label="Max new tokens")
            ctx_slider = gr.Slider(128, 4096, value=2048, step=128, label="Context window")

    sampling = (temp_slider, top_k_slider, top_p_slider, min_p_slider)
    penalties = (rep_slider, ngram_slider, cap_slider, loop_slider)
    limits = (max_tokens_slider, ctx_slider)
    return (override_box, *sampling, *penalties, *limits, raw_box)
# ---- build UI ----
def _sort_oldest_to_newest(versions: list[str]) -> list[str]:
    """Return *versions* ordered by their position in ``HUGGINGFACE_MODELS``.

    Names that appear in ``HUGGINGFACE_MODELS`` come first, in that mapping's
    key order (treated here as oldest→newest). Names that are not listed are
    placed after every known name and keep their relative input order.

    Args:
        versions: Model version names to order. The input is not modified.

    Returns:
        A new list containing the same elements in display order.
    """
    rank = {name: position for position, name in enumerate(HUGGINGFACE_MODELS)}
    unknown_rank = len(rank)
    # sorted() is stable, so elements with equal rank (unknown names or
    # duplicates) already keep their input order; an explicit
    # ``versions.index(v)`` tiebreak would only add an O(n) scan per element.
    return sorted(versions, key=lambda name: rank.get(name, unknown_rank))
# Startup state used to pre-populate the dropdowns and checkbox groups below.
_initial_versions = _sort_oldest_to_newest(_collection_versions())
# Preselect the first entry of the sorted list; None when nothing is available.
_initial_version = next(iter(_initial_versions), None)
# Checkpoint labels are only looked up when there is a version to query.
_initial_ckpt_labels = [] if not _initial_version else _ckpt_labels(_initial_version)
# Preset names offered by the "Mode preset" selectors.
_mode_keys = list(MODES)
# Hugging Face style theme — yellow primary + warm slate neutrals.
# The pieces are named separately so the palette, the font stacks and the
# per-variable overrides can each be read (and tweaked) on their own.
_hf_primary_palette = gr.themes.Color(
    c50="#FFFBEA",
    c100="#FFF3C4",
    c200="#FCE588",
    c300="#FADB5F",
    c400="#F7C948",
    c500="#FFD21E",  # HF brand yellow
    c600="#F0B429",
    c700="#CB6E17",
    c800="#B44D12",
    c900="#8D2B0B",
    c950="#5C1A04",
)
_hf_sans_stack = [gr.themes.GoogleFont("IBM Plex Sans"), "ui-sans-serif", "system-ui", "sans-serif"]
_hf_mono_stack = [gr.themes.GoogleFont("IBM Plex Mono"), "ui-monospace", "monospace"]
# Explicit light/dark values applied on top of the Default theme.
_hf_theme_overrides = {
    # Page and panel backgrounds.
    "body_background_fill": "#FFFFFF",
    "body_background_fill_dark": "#0B0F19",
    "background_fill_primary": "#FFFFFF",
    "background_fill_primary_dark": "#0B0F19",
    "background_fill_secondary": "#F5F6F8",
    "background_fill_secondary_dark": "#1B1F2A",
    # Component blocks, their labels and titles.
    "block_background_fill": "#FFFFFF",
    "block_background_fill_dark": "#11151F",
    "block_border_color": "#E5E7EB",
    "block_border_color_dark": "#22273A",
    "block_label_background_fill": "#FFFBEA",
    "block_label_background_fill_dark": "#1B1F2A",
    "block_label_text_color": "#5C1A04",
    "block_label_text_color_dark": "#FFD21E",
    "block_title_text_color": "#0B0F19",
    "block_title_text_color_dark": "#F5F6F8",
    # Buttons.
    "button_primary_background_fill": "#FFD21E",
    "button_primary_background_fill_hover": "#F0B429",
    "button_primary_text_color": "#0B0F19",
    "button_primary_text_color_hover": "#0B0F19",
    "button_secondary_background_fill": "#F5F6F8",
    "button_secondary_background_fill_dark": "#1B1F2A",
    "button_secondary_text_color": "#0B0F19",
    "button_secondary_text_color_dark": "#F5F6F8",
    # Borders and accents.
    "border_color_accent": "#FFD21E",
    "border_color_primary": "#E5E7EB",
    "border_color_primary_dark": "#22273A",
    "color_accent_soft": "#FFFBEA",
    "color_accent_soft_dark": "#1B1F2A",
}
_HF_THEME = gr.themes.Default(
    primary_hue=_hf_primary_palette,
    secondary_hue="orange",
    neutral_hue="slate",
    font=_hf_sans_stack,
    font_mono=_hf_mono_stack,
).set(**_hf_theme_overrides)
# Top-level Gradio app with three tabs: interactive chat, side-by-side model
# comparison, and benchmark evaluation. User-facing strings use real
# "—" / "▶" characters (the previous text contained mis-decoded UTF-8 bytes).
with gr.Blocks(title="CompactAI Models", theme=_HF_THEME) as demo:
    gr.Markdown(
        "# CompactAI — TinyMemoryLM\n"
        "Tiny recurrent-depth language models from [CompactAI-O](https://huggingface.co/CompactAI-O)."
    )

    # ── Chat ──────────────────────────────────────────────────────────────────
    with gr.Tab("Chat"):
        with gr.Row():
            # Left column: model selection and sampling controls.
            with gr.Column(scale=1, min_width=240):
                chat_version = gr.Dropdown(
                    choices=_initial_versions,
                    value=_initial_version,
                    label="Model version",
                )
                chat_ckpt = gr.Dropdown(
                    choices=_initial_ckpt_labels,
                    value=_initial_ckpt_labels[0] if _initial_ckpt_labels else None,
                    label="Checkpoint",
                )
                chat_mode = gr.Radio(
                    choices=_mode_keys,
                    value="chat-coherent",
                    label="Mode preset",
                    info="Ignored when 'Override preset' is checked.",
                )
                (c_use_custom, c_temp, c_topk, c_topp, c_minp, c_rep, c_ng,
                 c_cap, c_lp, c_maxt, c_ctx, c_raw) = _advanced_block()
            # Right column: conversation view and message entry.
            with gr.Column(scale=3):
                chatbot = gr.Chatbot(label="Conversation", height=500)
                with gr.Row():
                    msg_box = gr.Textbox(placeholder="Type a message…", show_label=False, scale=5)
                    send_btn = gr.Button("Send", variant="primary", scale=1)
                clear_btn = gr.Button("Clear")

        # Refresh the checkpoint dropdown whenever another version is picked.
        chat_version.change(_on_version_change, chat_version, chat_ckpt)

        _chat_adv = [chat_version, chat_ckpt, chat_mode,
                     c_use_custom, c_temp, c_topk, c_topp, c_minp,
                     c_rep, c_ng, c_cap, c_lp, c_maxt, c_ctx, c_raw]
        # Pressing Enter and clicking "Send" share one pipeline: first record
        # the user message (unqueued), then stream the model reply.
        for _chat_trigger in (msg_box.submit, send_btn.click):
            _chat_trigger(
                _chat_submit, [msg_box, chatbot], [msg_box, chatbot], queue=False
            ).then(
                _chat_stream, [chatbot] + _chat_adv, chatbot
            )
        # Clearing simply replaces the conversation with an empty history.
        clear_btn.click(lambda: [], None, chatbot, queue=False)

    # ── Compare ───────────────────────────────────────────────────────────────
    with gr.Tab("Compare All Models"):
        gr.Markdown("Run the same prompt on every selected model. Outputs stream live one model at a time.")
        with gr.Row():
            cmp_prompt = gr.Textbox(label="Prompt", placeholder="Enter a prompt…", lines=4, scale=3)
            with gr.Column(scale=1):
                cmp_models = gr.CheckboxGroup(
                    choices=_initial_versions, value=_initial_versions, label="Models to run"
                )
                cmp_mode = gr.Dropdown(
                    choices=_mode_keys, value="chat-coherent", label="Mode preset"
                )
        (cmp_use_custom, cmp_temp, cmp_topk, cmp_topp, cmp_minp, cmp_rep, cmp_ng,
         cmp_cap, cmp_lp, cmp_maxt, cmp_ctx, cmp_raw) = _advanced_block()
        cmp_run = gr.Button("▶ Run comparison", variant="primary")

        # 2-column grid of output boxes: one read-only textbox per known
        # version, in the same order as ``_initial_versions``.
        cmp_outputs = []
        for row_start in range(0, len(_initial_versions), 2):
            with gr.Row():
                for v in _initial_versions[row_start:row_start + 2]:
                    cmp_outputs.append(gr.Textbox(label=v, lines=10, interactive=False))

        cmp_run.click(
            _compare_fn,
            inputs=[cmp_prompt, cmp_models, cmp_mode,
                    cmp_use_custom, cmp_temp, cmp_topk, cmp_topp, cmp_minp,
                    cmp_rep, cmp_ng, cmp_cap, cmp_lp, cmp_maxt, cmp_ctx, cmp_raw],
            outputs=cmp_outputs,
        )

    # ── Benchmark ─────────────────────────────────────────────────────────────
    with gr.Tab("Benchmark"):
        gr.Markdown(
            "Evaluate models on standard benchmarks.\n\n"
            "- **BLiMP** — grammaticality minimal pairs (accuracy)\n"
            "- **WikiText-2** — LM perplexity (lower = better)\n"
            "- **ARC-Easy** — multiple-choice science QA (accuracy)"
        )
        with gr.Row():
            bench_type = gr.Radio(
                choices=list(BENCHMARKS.keys()), value="arc_easy", label="Benchmark"
            )
            bench_models = gr.CheckboxGroup(
                choices=_initial_versions,
                # Only the first version is preselected by default.
                value=[_initial_versions[0]] if _initial_versions else [],
                label="Models",
            )
            bench_samples = gr.Slider(10, 500, value=100, step=10, label="Max samples (fewer = faster)")
        bench_run = gr.Button("Run benchmark", variant="primary")
        with gr.Row():
            bench_log = gr.Textbox(label="Progress log", lines=12, interactive=False)
            # type="filepath": the image component exchanges file paths.
            bench_plot = gr.Image(label="Results chart", type="filepath")
        bench_run.click(
            _benchmark_fn,
            inputs=[bench_type, bench_models, bench_samples],
            outputs=[bench_log, bench_plot],
        )
if __name__ == "__main__":
    # Start the Gradio app only when this file is executed as a script.
    # NOTE(review): the theme is passed both here and to gr.Blocks(); which of
    # the two call sites accepts `theme` depends on the installed Gradio
    # version — confirm against the pinned release.
    demo.launch(theme=_HF_THEME)