""" Calculator LLM - A tiny transformer for solving English math problems. Built from scratch following: https://sid.sh/learn/build-your-first-llm """ import torch import torch.nn as nn import torch.nn.functional as F import math import json class PositionalEncoding(nn.Module): """Adds positional information to embeddings using sine/cosine waves.""" def __init__(self, embed_dim, max_seq_len=512, dropout=0.1): super().__init__() self.dropout = nn.Dropout(p=dropout) pe = torch.zeros(max_seq_len, embed_dim) position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1) div_term = torch.exp( torch.arange(0, embed_dim, 2).float() * (-math.log(10000.0) / embed_dim) ) pe[:, 0::2] = torch.sin(position * div_term) pe[:, 1::2] = torch.cos(position * div_term) pe = pe.unsqueeze(0) self.register_buffer("pe", pe) def forward(self, x): x = x + self.pe[:, : x.size(1), :] return self.dropout(x) class TokenEmbedding(nn.Module): """Converts token IDs to embedding vectors with positional encoding.""" def __init__(self, vocab_size, embed_dim, max_seq_len, dropout=0.1): super().__init__() self.embed_dim = embed_dim self.token_embedding = nn.Embedding(vocab_size, embed_dim) self.pos_encoding = PositionalEncoding(embed_dim, max_seq_len, dropout) self.scale = math.sqrt(embed_dim) def forward(self, x): x = self.token_embedding(x) * self.scale x = self.pos_encoding(x) return x class MultiHeadAttention(nn.Module): """Multi-head self-attention mechanism.""" def __init__(self, embed_dim, num_heads, dropout=0.1): super().__init__() self.embed_dim = embed_dim self.num_heads = num_heads self.head_dim = embed_dim // num_heads self.q_proj = nn.Linear(embed_dim, embed_dim) self.k_proj = nn.Linear(embed_dim, embed_dim) self.v_proj = nn.Linear(embed_dim, embed_dim) self.out_proj = nn.Linear(embed_dim, embed_dim) self.dropout = nn.Dropout(dropout) self.scale = math.sqrt(self.head_dim) def forward(self, x, mask=None): batch_size, seq_len, _ = x.shape Q = ( self.q_proj(x) .view(batch_size, seq_len, self.num_heads, self.head_dim) .transpose(1, 2) ) K = ( self.k_proj(x) .view(batch_size, seq_len, self.num_heads, self.head_dim) .transpose(1, 2) ) V = ( self.v_proj(x) .view(batch_size, seq_len, self.num_heads, self.head_dim) .transpose(1, 2) ) scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale if mask is not None: scores = scores.masked_fill(mask == 0, float("-inf")) attn_weights = F.softmax(scores, dim=-1) attn_weights = self.dropout(attn_weights) attn_output = torch.matmul(attn_weights, V) attn_output = ( attn_output.transpose(1, 2) .contiguous() .view(batch_size, seq_len, self.embed_dim) ) return self.out_proj(attn_output), attn_weights class FeedForward(nn.Module): """Position-wise feed-forward network.""" def __init__(self, embed_dim, ff_dim, dropout=0.1): super().__init__() self.linear1 = nn.Linear(embed_dim, ff_dim) self.linear2 = nn.Linear(ff_dim, embed_dim) self.dropout = nn.Dropout(dropout) def forward(self, x): return self.linear2(self.dropout(F.relu(self.linear1(x)))) class TransformerBlock(nn.Module): """A single transformer decoder block.""" def __init__(self, embed_dim, num_heads, ff_dim, dropout=0.1): super().__init__() self.attention = MultiHeadAttention(embed_dim, num_heads, dropout) self.norm1 = nn.LayerNorm(embed_dim) self.feed_forward = FeedForward(embed_dim, ff_dim, dropout) self.norm2 = nn.LayerNorm(embed_dim) self.dropout = nn.Dropout(dropout) def forward(self, x, mask=None): attn_output, attn_weights = self.attention(x, mask) x = self.norm1(x + self.dropout(attn_output)) ff_output = self.feed_forward(x) x = self.norm2(x + self.dropout(ff_output)) return x, attn_weights def create_causal_mask(seq_len): """Create a causal mask to prevent attending to future tokens.""" mask = torch.tril(torch.ones(seq_len, seq_len)) return mask.unsqueeze(0).unsqueeze(0) class CalculatorLLM(nn.Module): """A tiny transformer LLM for solving English math problems.""" def __init__( self, vocab_size, embed_dim, num_heads, num_layers, ff_dim, max_seq_len, dropout=0.1 ): super().__init__() self.embed_dim = embed_dim self.max_seq_len = max_seq_len self.embedding = TokenEmbedding(vocab_size, embed_dim, max_seq_len, dropout) self.layers = nn.ModuleList( [ TransformerBlock(embed_dim, num_heads, ff_dim, dropout) for _ in range(num_layers) ] ) self.norm = nn.LayerNorm(embed_dim) self.output_proj = nn.Linear(embed_dim, vocab_size) def forward(self, x, mask=None): if mask is None: seq_len = x.size(1) mask = create_causal_mask(seq_len).to(x.device) x = self.embedding(x) for layer in self.layers: x, _ = layer(x, mask) x = self.norm(x) return self.output_proj(x) class Tokenizer: """Converts text to token IDs and back.""" def __init__(self, vocab): self.vocab = vocab self.id_to_word = {v: k for k, v in vocab.items()} def normalize(self, text): text = text.lower().strip() text = text.replace("+", " plus ").replace("-", " minus ") text = ( text.replace("*", " times ").replace("x", " times ").replace("=", " equals ") ) tens = [ "twenty", "thirty", "forty", "fifty", "sixty", "seventy", "eighty", "ninety" ] ones = ["one", "two", "three", "four", "five", "six", "seven", "eight", "nine"] for ten in tens: for one in ones: text = text.replace(f"{ten}{one}", f"{ten} {one}") return " ".join(text.split()) def encode(self, text, add_special_tokens=True): text = self.normalize(text) ids = [self.vocab["[START]"]] if add_special_tokens else [] for word in text.split(): ids.append(self.vocab.get(word, self.vocab["[UNK]"])) if add_special_tokens: ids.append(self.vocab["[END]"]) return ids def decode(self, ids, skip_special_tokens=True): special = {"[PAD]", "[START]", "[END]", "[UNK]"} words = [ self.id_to_word.get(id, "[UNK]") for id in ids if not (skip_special_tokens and self.id_to_word.get(id, "[UNK]") in special) ] return " ".join(words) def load_model(model_dir, device="cpu"): """Load a saved Calculator LLM model.""" with open(f"{model_dir}/config.json") as f: config = json.load(f) with open(f"{model_dir}/vocab.json") as f: vocab = json.load(f) model = CalculatorLLM( vocab_size=config["vocab_size"], embed_dim=config["embed_dim"], num_heads=config["num_heads"], num_layers=config["num_layers"], ff_dim=config["ff_dim"], max_seq_len=config["max_seq_len"], dropout=config["dropout"], ) model.load_state_dict( torch.load(f"{model_dir}/model.pt", map_location=device, weights_only=True) ) model.to(device) model.eval() tokenizer = Tokenizer(vocab) return model, tokenizer, vocab def generate(model, tokenizer, vocab, prompt, device="cpu", max_new_tokens=10): """Generate text from a prompt.""" model.eval() tokens = tokenizer.encode(prompt, add_special_tokens=True)[:-1] input_ids = torch.tensor([tokens]).to(device) with torch.no_grad(): for _ in range(max_new_tokens): logits = model(input_ids) next_token = logits[0, -1, :].argmax().item() if next_token == vocab["[END]"]: break input_ids = torch.cat( [input_ids, torch.tensor([[next_token]]).to(device)], dim=1 ) return tokenizer.decode(input_ids[0].tolist()) if __name__ == "__main__": # Example usage model, tokenizer, vocab = load_model(".") result = generate(model, tokenizer, vocab, "two plus three equals") print(f"Result: {result}")