Spaces:
Sleeping
Sleeping
| """ | |
| Calculator LLM - A tiny transformer that solves English math problems. | |
| https://sid.sh/learn/build-your-first-llm | |
| """ | |
| import json | |
| import math | |
| import gradio as gr | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| # Model Architecture | |
| class PositionalEncoding(nn.Module): | |
| def __init__(self, embed_dim, max_seq_len=512, dropout=0.1): | |
| super().__init__() | |
| self.dropout = nn.Dropout(p=dropout) | |
| pe = torch.zeros(max_seq_len, embed_dim) | |
| position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1) | |
| div_term = torch.exp(torch.arange(0, embed_dim, 2).float() * (-math.log(10000.0) / embed_dim)) | |
| pe[:, 0::2] = torch.sin(position * div_term) | |
| pe[:, 1::2] = torch.cos(position * div_term) | |
| pe = pe.unsqueeze(0) | |
| self.register_buffer("pe", pe) | |
| def forward(self, x): | |
| x = x + self.pe[:, :x.size(1), :] | |
| return self.dropout(x) | |
| class TokenEmbedding(nn.Module): | |
| def __init__(self, vocab_size, embed_dim, max_seq_len, dropout=0.1): | |
| super().__init__() | |
| self.token_embedding = nn.Embedding(vocab_size, embed_dim) | |
| self.pos_encoding = PositionalEncoding(embed_dim, max_seq_len, dropout) | |
| self.scale = math.sqrt(embed_dim) | |
| def forward(self, x): | |
| return self.pos_encoding(self.token_embedding(x) * self.scale) | |
| class MultiHeadAttention(nn.Module): | |
| def __init__(self, embed_dim, num_heads, dropout=0.1): | |
| super().__init__() | |
| self.embed_dim = embed_dim | |
| self.num_heads = num_heads | |
| self.head_dim = embed_dim // num_heads | |
| self.q_proj = nn.Linear(embed_dim, embed_dim) | |
| self.k_proj = nn.Linear(embed_dim, embed_dim) | |
| self.v_proj = nn.Linear(embed_dim, embed_dim) | |
| self.out_proj = nn.Linear(embed_dim, embed_dim) | |
| self.dropout = nn.Dropout(dropout) | |
| self.scale = math.sqrt(self.head_dim) | |
| def forward(self, x, mask=None): | |
| B, S, _ = x.shape | |
| Q = self.q_proj(x).view(B, S, self.num_heads, self.head_dim).transpose(1, 2) | |
| K = self.k_proj(x).view(B, S, self.num_heads, self.head_dim).transpose(1, 2) | |
| V = self.v_proj(x).view(B, S, self.num_heads, self.head_dim).transpose(1, 2) | |
| scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale | |
| if mask is not None: | |
| scores = scores.masked_fill(mask == 0, float("-inf")) | |
| attn = self.dropout(F.softmax(scores, dim=-1)) | |
| out = torch.matmul(attn, V).transpose(1, 2).contiguous().view(B, S, self.embed_dim) | |
| return self.out_proj(out) | |
| class FeedForward(nn.Module): | |
| def __init__(self, embed_dim, ff_dim, dropout=0.1): | |
| super().__init__() | |
| self.linear1 = nn.Linear(embed_dim, ff_dim) | |
| self.linear2 = nn.Linear(ff_dim, embed_dim) | |
| self.dropout = nn.Dropout(dropout) | |
| def forward(self, x): | |
| return self.linear2(self.dropout(F.relu(self.linear1(x)))) | |
| class TransformerBlock(nn.Module): | |
| def __init__(self, embed_dim, num_heads, ff_dim, dropout=0.1): | |
| super().__init__() | |
| self.attention = MultiHeadAttention(embed_dim, num_heads, dropout) | |
| self.norm1 = nn.LayerNorm(embed_dim) | |
| self.feed_forward = FeedForward(embed_dim, ff_dim, dropout) | |
| self.norm2 = nn.LayerNorm(embed_dim) | |
| self.dropout = nn.Dropout(dropout) | |
| def forward(self, x, mask=None): | |
| attn_output = self.attention(x, mask) | |
| x = self.norm1(x + self.dropout(attn_output)) | |
| ff_output = self.feed_forward(x) | |
| x = self.norm2(x + self.dropout(ff_output)) | |
| return x | |
| class CalculatorLLM(nn.Module): | |
| def __init__(self, vocab_size, embed_dim, num_heads, num_layers, ff_dim, max_seq_len, dropout=0.1): | |
| super().__init__() | |
| self.embedding = TokenEmbedding(vocab_size, embed_dim, max_seq_len, dropout) | |
| self.layers = nn.ModuleList([ | |
| TransformerBlock(embed_dim, num_heads, ff_dim, dropout) | |
| for _ in range(num_layers) | |
| ]) | |
| self.norm = nn.LayerNorm(embed_dim) | |
| self.output_proj = nn.Linear(embed_dim, vocab_size) | |
| def forward(self, x): | |
| mask = torch.tril(torch.ones(x.size(1), x.size(1))).unsqueeze(0).unsqueeze(0).to(x.device) | |
| x = self.embedding(x) | |
| for layer in self.layers: | |
| x = layer(x, mask) | |
| return self.output_proj(self.norm(x)) | |
| # Tokenizer | |
| class Tokenizer: | |
| def __init__(self, vocab): | |
| self.vocab = vocab | |
| self.id_to_word = {v: k for k, v in vocab.items()} | |
| def encode(self, text): | |
| text = text.lower().strip() | |
| text = text.replace("+", " plus ").replace("-", " minus ").replace("*", " times ").replace("=", " equals ") | |
| ids = [self.vocab["[START]"]] | |
| for word in text.split(): | |
| ids.append(self.vocab.get(word, self.vocab["[UNK]"])) | |
| ids.append(self.vocab["[END]"]) | |
| return ids | |
| def decode(self, ids): | |
| special = {"[PAD]", "[START]", "[END]", "[UNK]"} | |
| return " ".join(self.id_to_word.get(i, "[UNK]") for i in ids if self.id_to_word.get(i) not in special) | |
| # Load model | |
| print("Loading model...") | |
| with open("config.json") as f: | |
| config = json.load(f) | |
| with open("vocab.json") as f: | |
| vocab = json.load(f) | |
| model = CalculatorLLM( | |
| config["vocab_size"], config["embed_dim"], config["num_heads"], | |
| config["num_layers"], config["ff_dim"], config["max_seq_len"], config.get("dropout", 0.1) | |
| ) | |
| model.load_state_dict(torch.load("model.pt", map_location="cpu", weights_only=True)) | |
| model.eval() | |
| tokenizer = Tokenizer(vocab) | |
| print("Ready!") | |
| # Inference | |
| def solve(problem): | |
| if not problem or not problem.strip(): | |
| return "" | |
| problem = problem.lower().strip() | |
| if not problem.endswith("equals"): | |
| problem += " equals" | |
| tokens = tokenizer.encode(problem)[:-1] | |
| input_ids = torch.tensor([tokens]) | |
| with torch.no_grad(): | |
| for _ in range(10): | |
| logits = model(input_ids) | |
| next_token = logits[0, -1].argmax().item() | |
| if next_token == vocab["[END]"]: | |
| break | |
| input_ids = torch.cat([input_ids, torch.tensor([[next_token]])], dim=1) | |
| result = tokenizer.decode(input_ids[0].tolist()) | |
| return result.split("equals")[-1].strip() if "equals" in result else result | |
| # Gradio UI | |
| with gr.Blocks(title="Calculator LLM") as demo: | |
| gr.Markdown( | |
| """ | |
| # Calculator LLM | |
| A 105K parameter transformer that solves English math problems. | |
| [[model]](https://github.com/slahiri/small_calculator_model) [[tutorial]](https://sid.sh/learn/build-your-first-llm) | |
| **Limitations:** | |
| - Trained on numbers 0-99 only. Inputs or results >99 may produce errors. | |
| - Test accuracy: ~98% (trained on a small corpus). | |
| """ | |
| ) | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| problem_input = gr.Textbox( | |
| label="", | |
| placeholder="Enter your problem", | |
| lines=1, | |
| show_label=False, | |
| ) | |
| run_btn = gr.Button("Run", variant="primary") | |
| with gr.Column(scale=1): | |
| answer_output = gr.Textbox( | |
| label="", | |
| placeholder="Answer will appear here", | |
| lines=1, | |
| show_label=False, | |
| interactive=False, | |
| ) | |
| gr.Examples( | |
| examples=[ | |
| ["two plus three"], | |
| ["seven times eight"], | |
| ["ninety minus forty five"], | |
| ["nine times nine"], | |
| ["twenty plus thirty"], | |
| ["eighty one minus forty"], | |
| ], | |
| inputs=problem_input, | |
| outputs=answer_output, | |
| fn=solve, | |
| cache_examples=True, | |
| ) | |
| run_btn.click(fn=solve, inputs=problem_input, outputs=answer_output) | |
| problem_input.submit(fn=solve, inputs=problem_input, outputs=answer_output) | |
| if __name__ == "__main__": | |
| demo.launch() | |