# slahiri's picture
# Upload folder using huggingface_hub
# e0cecae verified
"""
Calculator LLM - A tiny transformer that solves English math problems.
https://sid.sh/learn/build-your-first-llm
"""
import json
import math
import re

import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
# Model Architecture
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to a batch of embeddings.

    A table of shape ``(1, max_seq_len, embed_dim)`` is computed once at
    construction and registered as the buffer ``pe``. Even feature columns
    hold sines and odd feature columns hold cosines of
    ``position * exp(-log(10000) * i / embed_dim)``.

    Args:
        embed_dim: Size of the last (feature) dimension of the input.
        max_seq_len: Number of positions precomputed in the table.
        dropout: Dropout probability applied after the table is added.
    """

    def __init__(self, embed_dim, max_seq_len=512, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_seq_len, embed_dim)
        position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, embed_dim, 2).float() * (-math.log(10000.0) / embed_dim)
        )
        angles = position * div_term

        pe[:, 0::2] = torch.sin(angles)
        # With an odd embed_dim there is one fewer odd column than even
        # column, so the cosine half must be trimmed to fit; for an even
        # embed_dim the slice is a no-op.
        pe[:, 1::2] = torch.cos(angles[:, : embed_dim // 2])

        # Stored as a buffer (not a parameter) under the name "pe".
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x):
        """Return ``dropout(x + pe[:, :x.size(1)])``.

        ``x`` is indexed as ``(batch, seq_len, embed_dim)``; the first
        ``seq_len`` rows of the table are broadcast over the batch.
        """
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)
class TokenEmbedding(nn.Module):
    """Map token ids to scaled embeddings with positional encoding added.

    Args:
        vocab_size: Number of rows in the embedding table.
        embed_dim: Width of each embedding vector.
        max_seq_len: Forwarded to ``PositionalEncoding``.
        dropout: Forwarded to ``PositionalEncoding``.
    """

    def __init__(self, vocab_size, embed_dim, max_seq_len, dropout=0.1):
        super().__init__()
        # Embeddings are multiplied by sqrt(embed_dim) before positions are added.
        self.scale = math.sqrt(embed_dim)
        self.token_embedding = nn.Embedding(vocab_size, embed_dim)
        self.pos_encoding = PositionalEncoding(embed_dim, max_seq_len, dropout)

    def forward(self, x):
        """Embed the ids in ``x``, scale them, and apply the positional encoding."""
        scaled = self.token_embedding(x) * self.scale
        return self.pos_encoding(scaled)
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Queries, keys and values are all projected from the same input ``x``.

    Args:
        embed_dim: Width of the input and output features.
        num_heads: Number of attention heads; must divide ``embed_dim``.
        dropout: Dropout probability applied to the attention weights.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``embed_dim`` (the per-head reshape in ``forward`` would
            otherwise fail).
    """

    def __init__(self, embed_dim, num_heads, dropout=0.1):
        super().__init__()
        if num_heads <= 0 or embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by a positive "
                f"num_heads ({num_heads})"
            )
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)
        # Scores are divided by sqrt(head_dim) before the softmax.
        self.scale = math.sqrt(self.head_dim)

    def _split_heads(self, t, batch, seq_len):
        """Reshape (batch, seq, embed_dim) to (batch, heads, seq, head_dim)."""
        return t.view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x, mask=None):
        """Attend over ``x`` of shape ``(batch, seq_len, embed_dim)``.

        Args:
            x: Input features.
            mask: Optional tensor broadcastable to the score tensor
                ``(batch, heads, seq_len, seq_len)``; positions where it
                equals 0 are set to ``-inf`` before the softmax.

        Returns:
            Tensor of shape ``(batch, seq_len, embed_dim)``.
        """
        B, S, _ = x.shape
        Q = self._split_heads(self.q_proj(x), B, S)
        K = self._split_heads(self.k_proj(x), B, S)
        V = self._split_heads(self.v_proj(x), B, S)

        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float("-inf"))
        attn = self.dropout(F.softmax(scores, dim=-1))

        # Merge the heads back into a single feature dimension.
        out = torch.matmul(attn, V).transpose(1, 2).contiguous().view(B, S, self.embed_dim)
        return self.out_proj(out)
class FeedForward(nn.Module):
    """Two-layer position-wise network: Linear -> ReLU -> Dropout -> Linear.

    Args:
        embed_dim: Input and output feature width.
        ff_dim: Width of the hidden layer.
        dropout: Dropout probability applied to the hidden activations.
    """

    def __init__(self, embed_dim, ff_dim, dropout=0.1):
        super().__init__()
        self.linear1 = nn.Linear(embed_dim, ff_dim)
        self.linear2 = nn.Linear(ff_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Expand to ``ff_dim``, apply ReLU and dropout, project back."""
        hidden = F.relu(self.linear1(x))
        hidden = self.dropout(hidden)
        return self.linear2(hidden)
class TransformerBlock(nn.Module):
    """One transformer layer: self-attention then feed-forward.

    Each sub-layer's output passes through dropout, is added back to its
    input, and the sum is layer-normalised (normalisation comes after the
    residual addition).

    Args:
        embed_dim: Feature width of the block.
        num_heads: Number of attention heads.
        ff_dim: Hidden width of the feed-forward sub-layer.
        dropout: Dropout probability used throughout the block.
    """

    def __init__(self, embed_dim, num_heads, ff_dim, dropout=0.1):
        super().__init__()
        self.attention = MultiHeadAttention(embed_dim, num_heads, dropout)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.feed_forward = FeedForward(embed_dim, ff_dim, dropout)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Run both sub-layers on ``x``; ``mask`` is passed to the attention."""
        # Attention sub-layer with residual connection.
        attended = self.dropout(self.attention(x, mask))
        x = self.norm1(x + attended)

        # Feed-forward sub-layer with residual connection.
        transformed = self.dropout(self.feed_forward(x))
        return self.norm2(x + transformed)
class CalculatorLLM(nn.Module):
    """Decoder-style transformer producing next-token logits.

    The model embeds token ids, runs them through ``num_layers``
    ``TransformerBlock`` layers under a lower-triangular (causal) mask,
    applies a final layer norm, and projects to ``vocab_size`` logits.

    Args:
        vocab_size: Number of tokens; size of the embedding table and of
            the output projection.
        embed_dim: Feature width used throughout the model.
        num_heads: Attention heads per block.
        num_layers: Number of stacked transformer blocks.
        ff_dim: Hidden width of each block's feed-forward sub-layer.
        max_seq_len: Forwarded to the token embedding.
        dropout: Dropout probability forwarded to the sub-modules.
    """

    def __init__(self, vocab_size, embed_dim, num_heads, num_layers, ff_dim, max_seq_len, dropout=0.1):
        super().__init__()
        self.embedding = TokenEmbedding(vocab_size, embed_dim, max_seq_len, dropout)
        self.layers = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads, ff_dim, dropout)
            for _ in range(num_layers)
        ])
        self.norm = nn.LayerNorm(embed_dim)
        self.output_proj = nn.Linear(embed_dim, vocab_size)

    def forward(self, x):
        """Return logits for every position of ``x``.

        Args:
            x: Token ids indexed as ``(batch, seq_len)``.

        Returns:
            Tensor whose last dimension has size ``vocab_size``.
        """
        seq_len = x.size(1)
        # Build the causal mask directly on the input's device rather than
        # allocating it on CPU and copying it over on every call. Shape is
        # (1, 1, seq_len, seq_len) so it broadcasts over batch and heads.
        mask = torch.tril(
            torch.ones(seq_len, seq_len, device=x.device)
        ).unsqueeze(0).unsqueeze(0)

        x = self.embedding(x)
        for layer in self.layers:
            x = layer(x, mask)
        return self.output_proj(self.norm(x))
# Tokenizer
class Tokenizer:
    """Word-level tokenizer backed by a ``{word: id}`` vocabulary.

    The vocabulary must contain the special tokens ``[START]``, ``[END]``
    and ``[UNK]``, which ``encode`` looks up directly.

    Args:
        vocab: Mapping from token string to integer id.
    """

    # Tokens that ``decode`` never emits.
    _SPECIAL_TOKENS = frozenset({"[PAD]", "[START]", "[END]", "[UNK]"})
    # A hyphen directly between two letters joins a compound word
    # ("forty-five"); it is not a subtraction sign.
    _COMPOUND_HYPHEN = re.compile(r"(?<=[a-z])-(?=[a-z])")
    # Operator symbols and the words they are rewritten to.
    _SYMBOL_WORDS = (
        ("+", " plus "),
        ("-", " minus "),
        ("*", " times "),
        ("=", " equals "),
    )

    def __init__(self, vocab):
        self.vocab = vocab
        self.id_to_word = {v: k for k, v in vocab.items()}

    def encode(self, text):
        """Convert ``text`` to a list of ids wrapped in ``[START]``/``[END]``.

        The text is lower-cased and stripped, hyphens inside compound
        words become spaces, operator symbols are replaced by their word
        form, and the result is split on whitespace. Words missing from
        the vocabulary map to the ``[UNK]`` id.
        """
        text = text.lower().strip()
        # Must run before "-" is rewritten, otherwise "forty-five" would
        # be encoded as "forty minus five".
        text = self._COMPOUND_HYPHEN.sub(" ", text)
        for symbol, word in self._SYMBOL_WORDS:
            text = text.replace(symbol, word)

        unk_id = self.vocab["[UNK]"]
        ids = [self.vocab["[START]"]]
        ids.extend(self.vocab.get(word, unk_id) for word in text.split())
        ids.append(self.vocab["[END]"])
        return ids

    def decode(self, ids):
        """Join the words for ``ids`` with spaces, omitting special tokens.

        Ids that are not in the vocabulary are treated as ``[UNK]`` and are
        therefore omitted as well.
        """
        words = (self.id_to_word.get(i, "[UNK]") for i in ids)
        return " ".join(w for w in words if w not in self._SPECIAL_TOKENS)
# Load model
# Module-level setup: read the hyper-parameters and vocabulary, rebuild the
# network, load its weights on CPU, and create the tokenizer used by solve().
print("Loading model...")

# JSON is read explicitly as UTF-8 so loading does not depend on the
# platform's default text encoding.
with open("config.json", encoding="utf-8") as f:
    config = json.load(f)
with open("vocab.json", encoding="utf-8") as f:
    vocab = json.load(f)

model = CalculatorLLM(
    config["vocab_size"], config["embed_dim"], config["num_heads"],
    config["num_layers"], config["ff_dim"], config["max_seq_len"], config.get("dropout", 0.1)
)
# weights_only=True restricts unpickling to tensor data.
model.load_state_dict(torch.load("model.pt", map_location="cpu", weights_only=True))
# Inference only: eval() disables dropout.
model.eval()

tokenizer = Tokenizer(vocab)
print("Ready!")
# Inference
def solve(problem, max_new_tokens=10):
    """Greedily generate the answer to an English math problem.

    The problem is lower-cased, terminated with "equals" if needed,
    tokenized (without the trailing ``[END]``), and extended one token at a
    time with the arg-max prediction until ``[END]`` is produced, the token
    budget is used up, or the model's maximum sequence length is reached.

    Args:
        problem: Text such as "two plus three"; may end in "equals" or "=".
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The decoded text after the last "equals", or "" for blank input.

    Raises:
        ValueError: If the tokenized problem is longer than the model's
            ``max_seq_len``.
    """
    if not problem or not problem.strip():
        return ""

    problem = problem.lower().strip()
    # The tokenizer rewrites "=" to "equals", so a trailing "=" already
    # terminates the prompt; appending another "equals" would duplicate it.
    if not problem.endswith(("equals", "=")):
        problem += " equals"

    # Drop the trailing [END] so the model continues the sequence.
    tokens = tokenizer.encode(problem)[:-1]

    max_len = config["max_seq_len"]
    if len(tokens) > max_len:
        raise ValueError(
            f"Problem is too long: {len(tokens)} tokens exceeds the model "
            f"limit of {max_len}."
        )

    end_id = vocab["[END]"]
    input_ids = torch.tensor([tokens])
    with torch.no_grad():
        for _ in range(max_new_tokens):
            # The positional table only covers max_len positions.
            if input_ids.size(1) > max_len:
                break
            logits = model(input_ids)
            next_token = logits[0, -1].argmax().item()
            if next_token == end_id:
                break
            input_ids = torch.cat([input_ids, torch.tensor([[next_token]])], dim=1)

    result = tokenizer.decode(input_ids[0].tolist())
    return result.split("equals")[-1].strip() if "equals" in result else result
# Gradio UI
# Build the Gradio interface: a header, an input box with a Run button next
# to a read-only answer box, and a list of clickable example problems.
with gr.Blocks(title="Calculator LLM") as demo:
    # Header text with links and known limitations.
    gr.Markdown(
        """
# Calculator LLM
A 105K parameter transformer that solves English math problems.
[[model]](https://github.com/slahiri/small_calculator_model) [[tutorial]](https://sid.sh/learn/build-your-first-llm)
**Limitations:**
- Trained on numbers 0-99 only. Inputs or results >99 may produce errors.
- Test accuracy: ~98% (trained on a small corpus).
"""
    )
    with gr.Row():
        # First column (scale=1): problem input and the Run button.
        with gr.Column(scale=1):
            problem_input = gr.Textbox(
                label="",
                placeholder="Enter your problem",
                lines=1,
                show_label=False,
            )
            run_btn = gr.Button("Run", variant="primary")
        # Second column (scale=1): non-interactive box that shows the answer.
        with gr.Column(scale=1):
            answer_output = gr.Textbox(
                label="",
                placeholder="Answer will appear here",
                lines=1,
                show_label=False,
                interactive=False,
            )
    # Example problems wired to solve(), with example caching enabled.
    gr.Examples(
        examples=[
            ["two plus three"],
            ["seven times eight"],
            ["ninety minus forty five"],
            ["nine times nine"],
            ["twenty plus thirty"],
            ["eighty one minus forty"],
        ],
        inputs=problem_input,
        outputs=answer_output,
        fn=solve,
        cache_examples=True,
    )
    # Both clicking Run and submitting the textbox call solve().
    run_btn.click(fn=solve, inputs=problem_input, outputs=answer_output)
    problem_input.submit(fn=solve, inputs=problem_input, outputs=answer_output)

# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()