# ============================================================
# NanoCalc 1M (Mini Math Model) - T5 Seq2Seq
# ============================================================
# pip install transformers torch datasets accelerate
import random
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
from transformers import T5Config, T5ForConditionalGeneration
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
import time
# ============================================================
# 1. CONFIG
# ============================================================
TRAIN_SAMPLES = 2_000_000
VAL_SAMPLES = 10_000
MAX_DIGITS = 3
BATCH_SIZE = 512
EPOCHS = 10
LR = 3e-4
MAX_INPUT_LEN = 20
MAX_TARGET_LEN = 12
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SAVE_PATH = "model.pt"
print(f"Device: {DEVICE}")
print(f"GPU: {torch.cuda.get_device_name(0) if DEVICE == 'cuda' else 'None'}")
# ============================================================
# 2. TOKENIZER (Character-Level)
# ============================================================
CHARS = list("0123456789+-*/=") + ["<pad>", "<bos>", "<eos>"]
char2id = {c: i for i, c in enumerate(CHARS)}
id2char = {i: c for c, i in char2id.items()}
PAD_ID = char2id["<pad>"]
BOS_ID = char2id["<bos>"]
EOS_ID = char2id["<eos>"]
VOCAB_SIZE = len(CHARS)
def encode(text, max_len, add_bos=False, add_eos=True):
    tokens = []
    if add_bos:
        tokens.append(BOS_ID)
    for c in text:
        tokens.append(char2id.get(c, PAD_ID))
    if add_eos:
        tokens.append(EOS_ID)
    # Truncate, then pad to a fixed length
    tokens = tokens[:max_len]
    tokens += [PAD_ID] * (max_len - len(tokens))
    return tokens
def decode(token_ids):
    result = []
    for tid in token_ids:
        if tid == EOS_ID:
            break
        if tid in (PAD_ID, BOS_ID):
            continue
        result.append(id2char.get(tid, "?"))
    return "".join(result)
# ============================================================
# 3. DATA GENERATION
# ============================================================
def generate_sample(max_digits=3):
    op = random.choice(["+", "-", "*", "/"])
    if op == "+":
        a = random.randint(0, 10**max_digits - 1)
        b = random.randint(0, 10**max_digits - 1)
        result = a + b
    elif op == "-":
        a = random.randint(0, 10**max_digits - 1)
        b = random.randint(0, 10**max_digits - 1)
        result = a - b
    elif op == "*":
        # Smaller operands for multiplication to keep products short
        a = random.randint(0, 10**(max_digits-1) - 1)
        b = random.randint(0, 10**(max_digits-1) - 1)
        result = a * b
    elif op == "/":
        # Build the dividend as divisor * quotient so division is always exact
        b = random.randint(1, 10**(max_digits-1) - 1)
        result = random.randint(0, 10**(max_digits-1) - 1)
        a = b * result
    input_str = f"{a}{op}{b}="
    target_str = str(result)
    return input_str, target_str
def generate_dataset(n_samples, max_digits=3):
    inputs, targets = [], []
    for _ in range(n_samples):
        inp, tgt = generate_sample(max_digits)
        inputs.append(inp)
        targets.append(tgt)
    return inputs, targets
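# A few illustrative draws from the generator (random; shown only to make the
# "<a><op><b>=" / "<result>" format explicit before the full generation run below),
# e.g. ("123+456=", "579"), ("72*9=", "648"), ("640/8=", "80").
for _ in range(3):
    print("  sample:", generate_sample(MAX_DIGITS))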
print("Generating training data...")
t0 = time.time()
train_inputs, train_targets = generate_dataset(TRAIN_SAMPLES, MAX_DIGITS)
val_inputs, val_targets = generate_dataset(VAL_SAMPLES, MAX_DIGITS)
print(f"Done in {time.time()-t0:.1f}s")
print(f"Sample: '{train_inputs[0]}' → '{train_targets[0]}'")
# ============================================================
# 4. DATASET
# ============================================================
class MathDataset(Dataset):
    def __init__(self, inputs, targets):
        self.inputs = inputs
        self.targets = targets

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        inp = self.inputs[idx]
        tgt = self.targets[idx]
        input_ids = encode(inp, MAX_INPUT_LEN, add_bos=False, add_eos=True)
        attention_mask = [1 if t != PAD_ID else 0 for t in input_ids]
        # Padded label positions become -100 so the loss ignores them
        labels = encode(tgt, MAX_TARGET_LEN, add_bos=False, add_eos=True)
        labels = [t if t != PAD_ID else -100 for t in labels]
        # Decoder input: the target shifted right by one, starting with <bos>
        decoder_input = [BOS_ID] + encode(tgt, MAX_TARGET_LEN-1, add_bos=False, add_eos=False)
        decoder_input = decoder_input[:MAX_TARGET_LEN]
        decoder_input += [PAD_ID] * (MAX_TARGET_LEN - len(decoder_input))
        return {
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
            "decoder_input_ids": torch.tensor(decoder_input, dtype=torch.long),
            "labels": torch.tensor(labels, dtype=torch.long),
        }
train_dataset = MathDataset(train_inputs, train_targets)
val_dataset = MathDataset(val_inputs, val_targets)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=2, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False,
                        num_workers=2, pin_memory=True)
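# Optional shape check on one batch (illustrative): input_ids/attention_mask should come
# out as (BATCH_SIZE, MAX_INPUT_LEN) and decoder_input_ids/labels as (BATCH_SIZE, MAX_TARGET_LEN).
# sample_batch = next(iter(train_loader))
# print({k: tuple(v.shape) for k, v in sample_batch.items()})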
# ============================================================
# 5. MODEL (~1M parameters)
# ============================================================
config = T5Config(
    vocab_size=VOCAB_SIZE,
    d_model=128,
    d_ff=256,
    num_heads=4,
    num_layers=3,           # Encoder layers
    num_decoder_layers=3,   # Decoder layers
    d_kv=32,
    dropout_rate=0.1,
    feed_forward_proj="relu",
    is_encoder_decoder=True,
    pad_token_id=PAD_ID,
    eos_token_id=EOS_ID,
    decoder_start_token_id=BOS_ID,
)
model = T5ForConditionalGeneration(config).to(DEVICE)
scaler = torch.cuda.amp.GradScaler(enabled=(DEVICE == "cuda"))
total_params = sum(p.numel() for p in model.parameters())
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"\nTotal parameters: {total_params/1e6:.2f}M")
print(f"Trainable: {trainable/1e6:.2f}M")
# ============================================================
# 6. OPTIMIZER & SCHEDULER
# ============================================================
optimizer = AdamW(model.parameters(), lr=LR, weight_decay=0.01)
total_steps = len(train_loader) * EPOCHS
scheduler = CosineAnnealingLR(optimizer, T_max=total_steps, eta_min=LR/10)
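# The cosine schedule is stepped once per batch (see the training loop), so the learning
# rate decays from LR = 3e-4 down to eta_min = 3e-5 over len(train_loader) * EPOCHS
# optimizer steps (~39k steps with the settings above).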
# ============================================================
# 7. EVALUATION
# ============================================================
def evaluate(model, loader, n_examples=5):
    model.eval()
    correct = 0
    total = 0
    examples = []
    with torch.no_grad():
        for batch in loader:
            input_ids = batch["input_ids"].to(DEVICE)
            attention_mask = batch["attention_mask"].to(DEVICE)
            # Greedy generation
            generated = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=MAX_TARGET_LEN,
                eos_token_id=EOS_ID,
                pad_token_id=PAD_ID,
            )
            labels = batch["labels"]
            for i in range(len(input_ids)):
                pred_ids = generated[i].cpu().tolist()
                pred_str = decode(pred_ids)
                lbl = labels[i].tolist()
                lbl = [t for t in lbl if t != -100]
                true_str = decode(lbl)
                # Exact string match: a single wrong digit counts as incorrect
                is_correct = (pred_str == true_str)
                correct += int(is_correct)
                total += 1
                if len(examples) < n_examples:
                    inp_str = decode(input_ids[i].cpu().tolist())
                    examples.append((inp_str, true_str, pred_str, is_correct))
    accuracy = correct / total * 100
    return accuracy, examples
# ============================================================
# 8. TRAINING LOOP
# ============================================================
print("\n" + "="*60)
print("TRAINING START")
print("="*60)
best_accuracy = 0.0
for epoch in range(1, EPOCHS + 1):
    model.train()
    total_loss = 0.0
    steps = 0
    t_start = time.time()
    for batch in train_loader:
        input_ids = batch["input_ids"].to(DEVICE)
        attention_mask = batch["attention_mask"].to(DEVICE)
        decoder_input_ids = batch["decoder_input_ids"].to(DEVICE)
        labels = batch["labels"].to(DEVICE)
        optimizer.zero_grad()
        # Mixed precision
        with torch.cuda.amp.autocast(dtype=torch.float16):
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                decoder_input_ids=decoder_input_ids,
                labels=labels,
            )
            loss = outputs.loss
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()
        total_loss += loss.item()
        steps += 1
        if steps % 500 == 0:
            avg_loss = total_loss / steps
            elapsed = time.time() - t_start
            print(f"  Epoch {epoch} | Step {steps}/{len(train_loader)} "
                  f"| Loss: {avg_loss:.4f} | {elapsed:.0f}s")
    avg_loss = total_loss / steps
    # Validation
    print(f"\nEpoch {epoch} done. Evaluating...")
    accuracy, examples = evaluate(model, val_loader)
    print(f"\n{'='*60}")
    print(f"Epoch {epoch}/{EPOCHS}")
    print(f"  Train loss:   {avg_loss:.4f}")
    print(f"  Val accuracy: {accuracy:.2f}%")
    print("\n  Samples:")
    for inp, true, pred, ok in examples:
        status = "✅" if ok else "❌"
        print(f"    {status} '{inp}' → expected: '{true}', got: '{pred}'")
    print("="*60)
    # Save the best model so far
    if accuracy > best_accuracy:
        best_accuracy = accuracy
        torch.save({
            "model_state_dict": model.state_dict(),
            "config": config,
            "char2id": char2id,
            "id2char": id2char,
            "epoch": epoch,
            "accuracy": accuracy,
        }, SAVE_PATH)
        print(f"  💾 New best model saved! ({accuracy:.2f}%)")
print(f"\nTraining done! Best accuracy: {best_accuracy:.2f}%")
# ============================================================
# 9. INFERENCE - TEST
# ============================================================
def predict(model, expression):
    model.eval()
    inp = expression + "="
    input_ids = torch.tensor(
        [encode(inp, MAX_INPUT_LEN, add_bos=False, add_eos=True)],
        dtype=torch.long
    ).to(DEVICE)
    attention_mask = (input_ids != PAD_ID).long()
    with torch.no_grad():
        generated = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=MAX_TARGET_LEN,
            eos_token_id=EOS_ID,
            pad_token_id=PAD_ID,
        )
    return decode(generated[0].cpu().tolist())
print("\n" + "="*60)
print("INFERENCE TEST")
print("="*60)
test_cases = [
    "123+456",
    "999-123",
    "12*34",
    "100/5",
    "500+500",
    "77*8",
]
for expr in test_cases:
    pred = predict(model, expr)
    try:
        # Integer division mirrors how the training data was generated
        true = str(eval(expr.replace("/", "//")))
    except Exception:
        true = "?"
    status = "✅" if pred == true else "❌"
    print(f"  {status} {expr} = {pred} (correct: {true})")