# ksimplex-llm-prototype / inference.py
# Author: AbstractPhil
# Commit: 3c6b358 "Create inference.py" (verified)
#!/usr/bin/env python3
"""
K-Simplex Language Model - Inference Script
Loads a trained k-simplex LLM checkpoint and generates text using
geometrically-validated autoregressive sampling.
Usage:
python inference.py --checkpoint checkpoint_epoch_008.pt --prompt "ROMEO: "
python inference.py --repo AbstractPhil/ksimplex-llm-prototype --prompt "To be or not"
"""
import argparse
import json
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import tiktoken
from pathlib import Path
from huggingface_hub import hf_hub_download
# =============================================================================
# GEOMETRIC CORE
# =============================================================================
def factorial(n: int) -> int:
    """Return ``n!``.

    Thin wrapper around :func:`math.factorial`; a negative ``n`` raises
    ``ValueError`` exactly as ``math.factorial`` does.
    """
    result = math.factorial(n)
    return result
def cayley_menger_volume_squared(vertices: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Squared pairwise distances and squared simplex volume via Cayley-Menger.

    For ``nv`` vertices (a ``k = nv - 1`` simplex) the bordered
    ``(nv + 1) x (nv + 1)`` Cayley-Menger matrix is built from the squared
    distances and

        Vol^2 = (-1)^(k+1) * det(CM) / (2^k * (k!)^2)

    Args:
        vertices: [*, nv, edim] vertex coordinates (any leading batch dims)
    Returns:
        d2: [*, n_pairs] squared distances for the pairs (i, j), i < j, in
            ``torch.triu_indices`` (row-major upper-triangle) order
        vol2: [*] squared volume; can come out slightly negative from
            floating-point error for (near-)degenerate simplices
    """
    n_vert = vertices.shape[-2]
    k = n_vert - 1

    # Full [*, nv, nv] matrix of squared Euclidean distances (zero diagonal).
    sq_dist = (vertices.unsqueeze(-2) - vertices.unsqueeze(-3)).pow(2).sum(-1)

    # Flatten the strict upper triangle into the per-pair vector.
    rows, cols = torch.triu_indices(n_vert, n_vert, offset=1)
    pair_d2 = sq_dist[..., rows, cols]

    # Bordered matrix: [[0, 1 ... 1], [1, D], ..., [1, D]].
    side = n_vert + 1
    cm = vertices.new_zeros((*vertices.shape[:-2], side, side))
    cm[..., 0, 1:] = 1.0
    cm[..., 1:, 0] = 1.0
    cm[..., 1:, 1:] = sq_dist

    sign = (-1) ** (k + 1)
    denom = (2 ** k) * (math.factorial(k) ** 2)
    vol2 = sign * torch.linalg.det(cm) / denom
    return pair_d2, vol2
# =============================================================================
# MODEL COMPONENTS
# =============================================================================
class SimplexTemplate(nn.Module):
    """Fixed reference vertex layout for a k-simplex channel.

    Note: despite earlier "regular simplex" wording, this layout is not
    equilateral in general. The ``k + 1`` vertices sit at evenly spaced
    angles ``a`` on a circle of radius ``scale`` in coordinates 0 and 1.
    Coordinate 2 gets ``0.3 * scale * cos(2a)``, and every coordinate
    ``d >= 3`` gets ``0.1 * scale * sin((d + 1) a)``.

    The layout is registered as the buffer ``template``, so it is part of
    the module's ``state_dict`` and a loaded checkpoint overwrites it.
    """

    def __init__(self, k: int, edim: int, scale: float = 1.0):
        super().__init__()
        self.k = k
        self.nv = k + 1
        self.edim = edim

        rows = []
        for vi in range(self.nv):
            a = 2 * math.pi * vi / self.nv
            row = [0.0] * edim
            row[0] = scale * math.cos(a)
            if edim > 1:
                row[1] = scale * math.sin(a)
            if edim > 2:
                row[2] = scale * 0.3 * math.cos(a * 2)
            for axis in range(3, edim):
                row[axis] = scale * 0.1 * math.sin(a * (axis + 1))
            rows.append(row)
        self.register_buffer('template', torch.tensor(rows))

    def forward(self) -> torch.Tensor:
        """Return the [nv, edim] template vertices."""
        return self.template
class KSimplexChannel(nn.Module):
    """One k-simplex channel: predicts a deformed simplex and gates features by its geometry.

    The hidden state is mapped to per-vertex coordinate offsets, which are
    scaled by ``base_deform`` and added to a fixed ``SimplexTemplate``. The
    hidden state is also mapped to per-vertex features. The simplex's squared
    pairwise distances and squared volume (Cayley-Menger) form a geometry
    vector. That vector drives a sigmoid gate on the mean vertex feature,
    which is further multiplied by a soft validity factor
    ``sigmoid(vol2 * 1e6)``.
    """

    def __init__(self, k: int, edim: int, hidden: int, feat_dim: int, base_deform: float = 0.05):
        super().__init__()
        self.k = k
        self.nv = k + 1
        self.edim = edim
        self.feat_dim = feat_dim
        self.base_deform = base_deform

        # Submodule creation order matters: it fixes both the state_dict keys
        # and the order in which default initialisers consume the RNG.
        self.template = SimplexTemplate(k, edim)
        self._to_coords = nn.Linear(hidden, self.nv * edim)
        self._to_feats = nn.Linear(hidden, self.nv * feat_dim)

        # Geometry vector = one squared distance per vertex pair + vol^2.
        self.geo_dim = self.nv * (self.nv - 1) // 2 + 1
        self._geo_gate = nn.Sequential(
            nn.Linear(self.geo_dim, feat_dim),
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
            x: [*, hidden]
        Returns:
            out: [*, feat_dim + geo_dim] gated features followed by the geometry vector
            vol2: [*] squared simplex volume
            mean_d2: [*] mean squared pairwise distance
        """
        nv = self.nv
        offsets = self._to_coords(x).unflatten(-1, (nv, self.edim))
        verts = self.template() + self.base_deform * offsets
        per_vertex = self._to_feats(x).unflatten(-1, (nv, self.feat_dim))

        d2, vol2 = cayley_menger_volume_squared(verts)
        geo = torch.cat((d2, vol2[..., None]), dim=-1)

        # ~1 once vol2 is clearly positive, ~0 when it is negative (degenerate).
        soft_valid = torch.sigmoid(vol2 * 1e6)[..., None]
        gated = per_vertex.mean(dim=-2) * self._geo_gate(geo) * soft_valid

        return torch.cat((gated, geo), dim=-1), vol2, d2.mean(dim=-1)
class TokenToKChannels(nn.Module):
    """Project token embeddings into ``depth`` k-simplex channels (k = 1 .. depth).

    Each channel emits ``feat_dim + geo_dim`` values, and ``geo_dim`` grows
    with k. Every narrower output is mapped up to the widest one
    (``max_dim``) with a learned ``nn.Linear``. This is a projection, not
    zero padding. The widest channel passes through ``nn.Identity``. The
    per-channel outputs are then stacked along a new K axis.
    """

    def __init__(self, embed_dim: int, hidden: int, depth: int, edim: int, feat_dim: int):
        super().__init__()
        self.depth = depth
        self._proj = nn.Linear(embed_dim, hidden)
        self._channels = nn.ModuleList(
            KSimplexChannel(k=level, edim=edim, hidden=hidden, feat_dim=feat_dim)
            for level in range(1, depth + 1)
        )

        self.out_dims = [channel.feat_dim + channel.geo_dim for channel in self._channels]
        self.max_dim = max(self.out_dims)
        self._pads = nn.ModuleList(
            nn.Identity() if width == self.max_dim else nn.Linear(width, self.max_dim)
            for width in self.out_dims
        )

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, list[torch.Tensor], list[torch.Tensor]]:
        """
        Args:
            x: [B, T, embed_dim]
        Returns:
            out: [B, T, K, max_dim] with K == depth
            vol2_list: K tensors of shape [B, T], one per k-level
            d2_list: K tensors of shape [B, T] (mean squared distance), one per k-level
        """
        shared = self._proj(x)
        results = [channel(shared) for channel in self._channels]

        widened = [pad(res[0]) for pad, res in zip(self._pads, results)]
        vol2_list = [res[1] for res in results]
        d2_list = [res[2] for res in results]
        return torch.stack(widened, dim=-2), vol2_list, d2_list
class KChannelCrossAttention(nn.Module):
    """Self-attention across the K k-levels, independently at every (batch, position).

    Post-norm residual: ``LayerNorm(x + MHA(x, x, x))`` where each sequence
    position contributes one length-K "sequence" of D-dim tokens. ``dim``
    must be divisible by ``num_heads`` (an ``nn.MultiheadAttention``
    requirement).
    """

    def __init__(self, dim: int, num_heads: int = 4, dropout: float = 0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, num_heads, dropout=dropout, batch_first=True)
        self.norm = nn.LayerNorm(dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: [B, T, K, D]
        Returns:
            [B, T, K, D]
        """
        b, t, k, d = x.shape
        per_position = x.view(b * t, k, d)
        mixed = self.attn(per_position, per_position, per_position)[0]
        return self.norm(per_position + mixed).view(b, t, k, d)
class CausalSequenceAttention(nn.Module):
    """Causal self-attention across sequence positions on flattened k-channels.

    The input ``[B, T, K, D]`` is viewed as ``[B, T, K*D]`` before attention,
    so ``dim`` must equal ``K * D`` of the tensors passed in (and be
    divisible by ``num_heads``). Post-norm residual:
    ``LayerNorm(x + MHA(x, x, x, causal mask))``. Sequences longer than
    ``max_seq_len`` are not supported, because the precomputed mask is
    sliced to ``[:T, :T]``.
    """

    def __init__(self, dim: int, num_heads: int, max_seq_len: int, dropout: float = 0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, num_heads, dropout=dropout, batch_first=True)
        self.norm = nn.LayerNorm(dim)
        # Lower-triangular "may attend" mask (True on and below the diagonal).
        allowed = torch.ones(max_seq_len, max_seq_len).tril().bool()
        self.register_buffer('_causal_mask', allowed)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: [B, T, K, D]
        Returns:
            [B, T, K, D]
        """
        b, t, k, d = x.shape
        seq = x.view(b, t, k * d)

        # Additive float mask: 0 where attention is allowed, -inf on future positions.
        blocked = ~self._causal_mask[:t, :t]
        additive = torch.zeros_like(blocked, dtype=torch.float).masked_fill(blocked, float('-inf'))

        attended, _ = self.attn(seq, seq, seq, attn_mask=additive)
        return self.norm(seq + attended).view(b, t, k, d)
class GeoBlock(nn.Module):
    """Geometric block: k-channel attention + causal sequence attention + MLP.

    Steps, all on ``[B, T, K, D]`` tensors with ``K == depth`` and ``D == dim``:
      1. attention across the K k-levels at each position;
      2. causal attention across positions on the flattened ``K * D`` features;
      3. a post-norm residual MLP (4x expansion, GELU) on the flattened features.

    ``dim`` must be divisible by 4 (k-channel attention heads) and
    ``dim * depth`` must be divisible by ``num_heads``; both are
    ``nn.MultiheadAttention`` requirements.
    """

    def __init__(self, dim: int, num_heads: int, max_seq_len: int, depth: int, dropout: float = 0.1):
        super().__init__()
        flat_dim = dim * depth
        self.k_attn = KChannelCrossAttention(dim, num_heads=4, dropout=dropout)
        # CausalSequenceAttention views its input as [B, T, K*D], so it must be
        # built with the flattened width (as the MLP below already is).
        # Building it with `dim` made every depth > 1 model fail inside
        # MultiheadAttention with an embedding-size mismatch; for depth == 1
        # the two widths are identical, so existing depth-1 weights still load.
        self.seq_attn = CausalSequenceAttention(flat_dim, num_heads, max_seq_len, dropout)
        self.mlp = nn.Sequential(
            nn.Linear(flat_dim, flat_dim * 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(flat_dim * 4, flat_dim),
            nn.Dropout(dropout),
        )
        self.mlp_norm = nn.LayerNorm(flat_dim)
        self.depth = depth

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: [B, T, K, D]
        Returns:
            [B, T, K, D]
        """
        x = self.k_attn(x)
        x = self.seq_attn(x)

        # MLP mixes all k-channels of a position jointly.
        B, T, K, D = x.shape
        x_flat = x.view(B, T, K * D)
        x_flat = self.mlp_norm(x_flat + self.mlp(x_flat))
        return x_flat.view(B, T, K, D)
class KSimplexLM(nn.Module):
    """K-Simplex Language Model.

    Pipeline: token + learned positional embeddings -> ``TokenToKChannels``
    (one k-simplex channel per k in 1..depth, all mapped to a shared width
    ``k_dim``) -> ``num_blocks`` ``GeoBlock``s -> LayerNorm -> bias-free
    linear LM head over the flattened ``k_dim * depth`` features.
    """

    def __init__(
        self,
        vocab_size: int = 50257,
        max_seq_len: int = 256,
        embed_dim: int = 384,
        depth: int = 4,
        edim: int = 16,
        feat_dim: int = 96,
        hidden: int = 384,
        num_heads: int = 8,
        num_blocks: int = 8,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.max_seq_len = max_seq_len
        self.depth = depth

        # Token + learned absolute position embeddings (max_seq_len positions).
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.pos_embed = nn.Embedding(max_seq_len, embed_dim)
        self.embed_drop = nn.Dropout(dropout)

        self.to_k_channels = TokenToKChannels(embed_dim, hidden, depth, edim, feat_dim)

        k_dim = self.to_k_channels.max_dim
        self.blocks = nn.ModuleList([
            GeoBlock(k_dim, num_heads, max_seq_len, depth, dropout)
            for _ in range(num_blocks)
        ])

        self.ln_f = nn.LayerNorm(k_dim * depth)
        # No weight tying with self.embed: the head consumes k_dim * depth
        # features while the embedding width is embed_dim, so the shapes only
        # match by coincidence.
        self.lm_head = nn.Linear(k_dim * depth, vocab_size, bias=False)

        self._init_weights()

    def _init_weights(self):
        """N(0, 0.02) for every Linear/Embedding weight, zeros for Linear biases."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Embedding):
                nn.init.normal_(m.weight, std=0.02)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, dict]:
        """
        Args:
            x: [B, T] token indices, T <= max_seq_len
        Returns:
            logits: [B, T, vocab_size]
            geo_info: dict with 'vol2' and 'd2' lists (one [B, T] tensor per k-level)
        Raises:
            ValueError: if T exceeds max_seq_len (no position embedding exists).
        """
        B, T = x.shape
        if T > self.max_seq_len:
            raise ValueError(
                f"sequence length {T} exceeds max_seq_len={self.max_seq_len}; "
                "truncate the input to the most recent max_seq_len tokens"
            )

        pos = torch.arange(T, device=x.device).unsqueeze(0)
        h = self.embed(x) + self.pos_embed(pos)
        h = self.embed_drop(h)

        h, vol2_list, d2_list = self.to_k_channels(h)

        for block in self.blocks:
            h = block(h)

        h_flat = h.view(B, T, -1)
        h_flat = self.ln_f(h_flat)
        logits = self.lm_head(h_flat)

        geo_info = {
            'vol2': vol2_list,
            'd2': d2_list,
        }
        return logits, geo_info
# =============================================================================
# INFERENCE UTILITIES
# =============================================================================
def load_model(
    checkpoint_path: str = None,
    repo_id: str = None,
    device: str = None,
) -> tuple[KSimplexLM, tiktoken.Encoding]:
    """
    Load model from checkpoint or HuggingFace Hub.

    Args:
        checkpoint_path: Local path to checkpoint. Its model hyper-parameters
            are read from ``checkpoint['config']['model']``.
        repo_id: HuggingFace repo ID (e.g., "AbstractPhil/ksimplex-llm-prototype").
            Takes precedence over ``checkpoint_path``; downloads
            ``checkpoint_latest.pt`` and ``config.json``. If config.json has a
            ``model`` section (same layout as the checkpoint's config) that
            section is used, otherwise the file is treated as flat model kwargs.
        device: Device to load to (defaults to 'cuda' when available, else 'cpu')
    Returns:
        model: KSimplexLM in eval mode on ``device`` (dropout disabled)
        tokenizer: tiktoken GPT-2 encoding
    Raises:
        ValueError: if neither checkpoint_path nor repo_id is given.

    Missing hyper-parameters fall back to the KSimplexLM defaults.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code; only load checkpoints from sources you trust.
    if repo_id:
        checkpoint_path = hf_hub_download(repo_id, "checkpoint_latest.pt")
        config_path = hf_hub_download(repo_id, "config.json")
        with open(config_path, encoding='utf-8') as f:
            config = json.load(f)
        # Accept the nested training layout ({"model": {...}, ...}) as well as
        # a flat dict; previously a nested file was silently ignored and every
        # hyper-parameter fell back to its default.
        if isinstance(config.get('model'), dict):
            config = config['model']
        checkpoint = torch.load(checkpoint_path, map_location=device)
    elif checkpoint_path:
        checkpoint = torch.load(checkpoint_path, map_location=device)
        config = checkpoint.get('config', {}).get('model', {})
    else:
        raise ValueError("Must provide checkpoint_path or repo_id")

    model = KSimplexLM(
        vocab_size=config.get('vocab_size', 50257),
        max_seq_len=config.get('max_seq_len', 256),
        embed_dim=config.get('embed_dim', 384),
        depth=config.get('depth', 4),
        edim=config.get('edim', 16),
        feat_dim=config.get('feat_dim', 96),
        hidden=config.get('hidden', 384),
        num_heads=config.get('num_heads', 8),
        num_blocks=config.get('num_blocks', 8),
        dropout=0.0,  # No dropout at inference
    )

    # Training checkpoints wrap the weights; a bare state_dict is used as-is.
    state_dict = checkpoint.get('model_state_dict', checkpoint)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()

    tokenizer = tiktoken.get_encoding("gpt2")
    return model, tokenizer
@torch.no_grad()
def generate(
    model: "KSimplexLM",
    tokenizer: "tiktoken.Encoding",
    prompt: str,
    max_tokens: int = 100,
    temperature: float = 0.8,
    top_k: int = 50,
    top_p: float = 0.9,
    device: str = None,
) -> str:
    """
    Generate text from prompt.

    Args:
        model: KSimplexLM model (anything with ``max_seq_len``, parameters and
            a ``forward(tokens) -> (logits, info)``)
        tokenizer: tiktoken encoding (needs ``encode``, ``decode``, ``eot_token``)
        prompt: Input text prompt. An empty prompt is seeded with the
            end-of-text token, which is not included in the returned text.
        max_tokens: Maximum tokens to generate
        temperature: Sampling temperature; values <= 0 mean greedy decoding
            (argmax) instead of a division by zero
        top_k: Top-k sampling (<= 0 disables)
        top_p: Nucleus sampling threshold (>= 1.0 disables)
        device: Device (defaults to the model's parameter device)
    Returns:
        Generated text including the full prompt. Only the last
        ``model.max_seq_len`` tokens are fed to the model at each step, but
        no tokens are dropped from the output. If generation stops on the
        end-of-text token, that token is part of the decoded text.
    """
    if device is None:
        device = next(model.parameters()).device

    prompt_ids = tokenizer.encode(prompt)
    # Nothing to condition on otherwise (logits[:, -1] would be out of range).
    seeded = not prompt_ids
    if seeded:
        prompt_ids = [tokenizer.eot_token]
    tokens = torch.tensor([prompt_ids], dtype=torch.long, device=device)

    for _ in range(max_tokens):
        # Feed only the most recent window; `tokens` keeps the full history so
        # the returned text still contains the prompt.
        context = tokens[:, -model.max_seq_len:]
        logits, _ = model(context)
        logits = logits[:, -1, :]  # Last position

        if temperature <= 0:
            next_token = logits.argmax(dim=-1, keepdim=True)
        else:
            logits = logits / temperature

            if top_k > 0:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = float('-inf')

            if top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                # Drop tokens once the cumulative mass exceeds top_p, shifted by
                # one so the token that crosses the threshold is kept.
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = 0
                indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                logits[indices_to_remove] = float('-inf')

            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)

        tokens = torch.cat([tokens, next_token], dim=1)

        if next_token.item() == tokenizer.eot_token:
            break

    out_ids = tokens[0].tolist()
    if seeded:
        out_ids = out_ids[1:]
    return tokenizer.decode(out_ids)
@torch.no_grad()
def analyze_geometry(
    model: "KSimplexLM",
    tokenizer: "tiktoken.Encoding",
    text: str,
    device: str = None,
) -> dict:
    """
    Analyze geometric properties of text encoding.

    Args:
        model: KSimplexLM model
        tokenizer: tiktoken encoding
        text: Input text. Only the last ``model.max_seq_len`` tokens are
            analyzed, because position embeddings exist only for that many
            positions. This is the same window generate() feeds the model.
        device: Device (defaults to the model's parameter device)
    Returns:
        Dict keyed 'k1', 'k2', ... (one per k-level) with vol2 mean/std
        (population std)/min/max, 'validity_rate' (fraction of positions with
        vol2 > 0) and 'd2_mean', all as Python floats.
    Raises:
        ValueError: if ``text`` encodes to zero tokens.
    """
    if device is None:
        device = next(model.parameters()).device

    token_ids = tokenizer.encode(text)
    if not token_ids:
        raise ValueError("text encodes to zero tokens; nothing to analyze")
    token_ids = token_ids[-model.max_seq_len:]
    tokens = torch.tensor([token_ids], dtype=torch.long, device=device)

    _, geo_info = model(tokens)

    stats = {}
    for k, (vol2, d2) in enumerate(zip(geo_info['vol2'], geo_info['d2']), 1):
        vol2_np = vol2.cpu().numpy()
        d2_np = d2.cpu().numpy()
        stats[f'k{k}'] = {
            'vol2_mean': float(vol2_np.mean()),
            'vol2_std': float(vol2_np.std()),
            'vol2_min': float(vol2_np.min()),
            'vol2_max': float(vol2_np.max()),
            'validity_rate': float((vol2_np > 0).mean()),
            'd2_mean': float(d2_np.mean()),
        }
    return stats
# =============================================================================
# CLI
# =============================================================================
def main():
    """CLI entry point: load a model, then either generate text or print geometry stats.

    A local ``--checkpoint`` takes priority; otherwise the model is pulled
    from the ``--repo`` Hub repository.
    """
    parser = argparse.ArgumentParser(description='K-Simplex LLM Inference')
    add = parser.add_argument
    add('--checkpoint', type=str, help='Path to checkpoint file')
    add('--repo', type=str, default='AbstractPhil/ksimplex-llm-prototype',
        help='HuggingFace repo ID')
    add('--prompt', type=str, default='ROMEO: ',
        help='Text prompt')
    add('--max_tokens', type=int, default=100,
        help='Maximum tokens to generate')
    add('--temperature', type=float, default=0.8,
        help='Sampling temperature')
    add('--top_k', type=int, default=50,
        help='Top-k sampling')
    add('--top_p', type=float, default=0.9,
        help='Nucleus sampling threshold')
    add('--analyze', action='store_true',
        help='Analyze geometric properties instead of generating')
    args = parser.parse_args()

    print("Loading model...")
    model, tokenizer = load_model(
        checkpoint_path=args.checkpoint,
        repo_id=None if args.checkpoint else args.repo,
    )
    print(f"Model loaded on {next(model.parameters()).device}")

    if not args.analyze:
        print(f"\nGenerating from: {args.prompt}")
        text = generate(
            model, tokenizer, args.prompt,
            max_tokens=args.max_tokens,
            temperature=args.temperature,
            top_k=args.top_k,
            top_p=args.top_p,
        )
        rule = "=" * 60
        print("\n" + rule)
        print(text)
        print(rule)
        return

    print(f"\nAnalyzing: {args.prompt}")
    stats = analyze_geometry(model, tokenizer, args.prompt)
    for level, level_stats in stats.items():
        print(f"\n{level}:")
        for metric, value in level_stats.items():
            print(f" {metric}: {value:.6f}")


if __name__ == '__main__':
    main()