"""Train a small character-level T5 model to solve simple arithmetic expressions."""

import random
import time

import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import Dataset, DataLoader
from transformers import T5Config, T5ForConditionalGeneration

# ======================================================================
# Configuration
# ======================================================================

TRAIN_SAMPLES = 2_000_000
VAL_SAMPLES = 10_000
MAX_DIGITS = 3
BATCH_SIZE = 512
EPOCHS = 10
LR = 3e-4
MAX_INPUT_LEN = 20
MAX_TARGET_LEN = 12
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SAVE_PATH = "model.pt"

print(f"Device: {DEVICE}")
print(f"GPU: {torch.cuda.get_device_name(0) if DEVICE == 'cuda' else 'None'}")

# ======================================================================
# Character-level tokenizer
# ======================================================================

CHARS = list("0123456789+-*/=") + ["<pad>", "<bos>", "<eos>"]
char2id = {c: i for i, c in enumerate(CHARS)}
id2char = {i: c for c, i in char2id.items()}

PAD_ID = char2id["<pad>"]
BOS_ID = char2id["<bos>"]
EOS_ID = char2id["<eos>"]
VOCAB_SIZE = len(CHARS)

def encode(text, max_len, add_bos=False, add_eos=True):
    """Map a string to a fixed-length list of token ids (truncated/padded to max_len)."""
    tokens = []
    if add_bos:
        tokens.append(BOS_ID)
    for c in text:
        tokens.append(char2id.get(c, PAD_ID))  # unknown characters fall back to <pad>
    if add_eos:
        tokens.append(EOS_ID)

    tokens = tokens[:max_len]
    tokens += [PAD_ID] * (max_len - len(tokens))
    return tokens

def decode(token_ids):
    """Turn token ids back into a string, stopping at <eos> and skipping <pad>/<bos>."""
    result = []
    for tid in token_ids:
        if tid == EOS_ID:
            break
        if tid in (PAD_ID, BOS_ID):
            continue
        result.append(id2char.get(tid, "?"))
    return "".join(result)

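# Quick round-trip sanity check (an added illustration, not part of the original
# pipeline): encode() pads to a fixed length and decode() strips the padding and
# stops at <eos>, so a short expression should come back unchanged.
assert decode(encode("12+7=", MAX_INPUT_LEN)) == "12+7="
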
# ======================================================================
# Synthetic data generation
# ======================================================================

def generate_sample(max_digits=3):
    """Generate one (input, target) pair such as '123+456=' -> '579'."""
    op = random.choice(["+", "-", "*", "/"])

    if op == "+":
        a = random.randint(0, 10**max_digits - 1)
        b = random.randint(0, 10**max_digits - 1)
        result = a + b
    elif op == "-":
        a = random.randint(0, 10**max_digits - 1)
        b = random.randint(0, 10**max_digits - 1)
        result = a - b
    elif op == "*":
        # Use one fewer digit for multiplication to keep results short.
        a = random.randint(0, 10**(max_digits - 1) - 1)
        b = random.randint(0, 10**(max_digits - 1) - 1)
        result = a * b
    elif op == "/":
        # Build division samples backwards so they always divide evenly.
        b = random.randint(1, 10**(max_digits - 1) - 1)
        result = random.randint(0, 10**(max_digits - 1) - 1)
        a = b * result

    input_str = f"{a}{op}{b}="
    target_str = str(result)
    return input_str, target_str

def generate_dataset(n_samples, max_digits=3):
    inputs, targets = [], []
    for _ in range(n_samples):
        inp, tgt = generate_sample(max_digits)
        inputs.append(inp)
        targets.append(tgt)
    return inputs, targets

| print("Generating training data...") |
| t0 = time.time() |
| train_inputs, train_targets = generate_dataset(TRAIN_SAMPLES, MAX_DIGITS) |
| val_inputs, val_targets = generate_dataset(VAL_SAMPLES, MAX_DIGITS) |
| print(f"Done in {time.time()-t0:.1f}s") |
| print(f"Sample: '{train_inputs[0]}' → '{train_targets[0]}'") |
|
|
# ======================================================================
# Dataset and dataloaders
# ======================================================================

class MathDataset(Dataset):
    def __init__(self, inputs, targets):
        self.inputs = inputs
        self.targets = targets

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        inp = self.inputs[idx]
        tgt = self.targets[idx]

        input_ids = encode(inp, MAX_INPUT_LEN, add_bos=False, add_eos=True)
        attention_mask = [1 if t != PAD_ID else 0 for t in input_ids]

        # Padding positions in the labels become -100 so the loss ignores them.
        labels = encode(tgt, MAX_TARGET_LEN, add_bos=False, add_eos=True)
        labels = [t if t != PAD_ID else -100 for t in labels]

        # Teacher-forcing decoder input: the target shifted right, starting with <bos>.
        decoder_input = [BOS_ID] + encode(tgt, MAX_TARGET_LEN - 1, add_bos=False, add_eos=False)
        decoder_input = decoder_input[:MAX_TARGET_LEN]
        decoder_input += [PAD_ID] * (MAX_TARGET_LEN - len(decoder_input))

        return {
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
            "decoder_input_ids": torch.tensor(decoder_input, dtype=torch.long),
            "labels": torch.tensor(labels, dtype=torch.long),
        }

train_dataset = MathDataset(train_inputs, train_targets)
val_dataset = MathDataset(val_inputs, val_targets)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=2, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False,
                        num_workers=2, pin_memory=True)

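# Added illustrative check (not in the original flow): every example is already a
# fixed-length tensor, so the default collate works and batches stack cleanly.
_sample = train_dataset[0]
print("Shapes:", _sample["input_ids"].shape, _sample["decoder_input_ids"].shape,
      _sample["labels"].shape)
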
# ======================================================================
# Model
# ======================================================================

config = T5Config(
    vocab_size=VOCAB_SIZE,
    d_model=128,
    d_ff=256,
    num_heads=4,
    num_layers=3,
    num_decoder_layers=3,
    d_kv=32,
    dropout_rate=0.1,
    feed_forward_proj="relu",
    is_encoder_decoder=True,
    pad_token_id=PAD_ID,
    eos_token_id=EOS_ID,
    decoder_start_token_id=BOS_ID,
)

model = T5ForConditionalGeneration(config).to(DEVICE)

# Mixed-precision gradient scaler; disabled automatically when running on CPU.
scaler = torch.cuda.amp.GradScaler(enabled=(DEVICE == "cuda"))

total_params = sum(p.numel() for p in model.parameters())
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"\nTotal parameters: {total_params/1e6:.2f}M")
print(f"Trainable: {trainable/1e6:.2f}M")

# ======================================================================
# Optimizer and scheduler
# ======================================================================

optimizer = AdamW(model.parameters(), lr=LR, weight_decay=0.01)
total_steps = len(train_loader) * EPOCHS
scheduler = CosineAnnealingLR(optimizer, T_max=total_steps, eta_min=LR / 10)

# ======================================================================
# Evaluation
# ======================================================================

def evaluate(model, loader, n_examples=5):
    """Generate on the whole loader and return exact-match accuracy plus a few examples."""
    model.eval()
    correct = 0
    total = 0
    examples = []

    with torch.no_grad():
        for batch in loader:
            input_ids = batch["input_ids"].to(DEVICE)
            attention_mask = batch["attention_mask"].to(DEVICE)

            generated = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=MAX_TARGET_LEN,
                eos_token_id=EOS_ID,
                pad_token_id=PAD_ID,
            )

            labels = batch["labels"]

            for i in range(len(input_ids)):
                pred_ids = generated[i].cpu().tolist()
                pred_str = decode(pred_ids)

                lbl = labels[i].tolist()
                lbl = [t for t in lbl if t != -100]
                true_str = decode(lbl)

                is_correct = (pred_str == true_str)
                correct += int(is_correct)
                total += 1

                if len(examples) < n_examples:
                    inp_str = decode(input_ids[i].cpu().tolist())
                    examples.append((inp_str, true_str, pred_str, is_correct))

    accuracy = correct / total * 100
    return accuracy, examples

# ======================================================================
# Training loop
# ======================================================================

| print("\n" + "="*60) |
| print("TRAINING START") |
| print("="*60) |
|
|
| best_accuracy = 0.0 |
|
|
for epoch in range(1, EPOCHS + 1):
    model.train()
    total_loss = 0.0
    steps = 0
    t_start = time.time()

    for batch in train_loader:
        input_ids = batch["input_ids"].to(DEVICE)
        attention_mask = batch["attention_mask"].to(DEVICE)
        decoder_input_ids = batch["decoder_input_ids"].to(DEVICE)
        labels = batch["labels"].to(DEVICE)

        optimizer.zero_grad()

        # Forward pass under float16 autocast (no-op on CPU).
        with torch.cuda.amp.autocast(dtype=torch.float16, enabled=(DEVICE == "cuda")):
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                decoder_input_ids=decoder_input_ids,
                labels=labels,
            )
            loss = outputs.loss

        # Scaled backward pass with gradient clipping.
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()

        total_loss += loss.item()
        steps += 1

        if steps % 500 == 0:
            avg_loss = total_loss / steps
            elapsed = time.time() - t_start
            print(f"  Epoch {epoch} | Step {steps}/{len(train_loader)} "
                  f"| Loss: {avg_loss:.4f} | {elapsed:.0f}s")

    avg_loss = total_loss / steps

    print(f"\nEpoch {epoch} done. Evaluating...")
    accuracy, examples = evaluate(model, val_loader)

    print(f"\n{'='*60}")
    print(f"Epoch {epoch}/{EPOCHS}")
    print(f"  Train loss: {avg_loss:.4f}")
    print(f"  Val accuracy: {accuracy:.2f}%")
    print("\n  Samples:")
    for inp, true, pred, ok in examples:
        status = "✅" if ok else "❌"
        print(f"    {status} '{inp}' → expected: '{true}', got: '{pred}'")
    print("="*60)

    # Keep only the best checkpoint by validation accuracy.
    if accuracy > best_accuracy:
        best_accuracy = accuracy
        torch.save({
            "model_state_dict": model.state_dict(),
            "config": config,
            "char2id": char2id,
            "id2char": id2char,
            "epoch": epoch,
            "accuracy": accuracy,
        }, SAVE_PATH)
        print(f"  💾 New best model saved! ({accuracy:.2f}%)")

print(f"\nTraining done! Best accuracy: {best_accuracy:.2f}%")

# ======================================================================
# Inference test
# ======================================================================

def predict(model, expression):
    """Run generation on a single expression like '123+456' (without the trailing '=')."""
    model.eval()
    inp = expression + "="
    input_ids = torch.tensor(
        [encode(inp, MAX_INPUT_LEN, add_bos=False, add_eos=True)],
        dtype=torch.long,
    ).to(DEVICE)
    attention_mask = (input_ids != PAD_ID).long()

    with torch.no_grad():
        generated = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=MAX_TARGET_LEN,
            eos_token_id=EOS_ID,
            pad_token_id=PAD_ID,
        )

    return decode(generated[0].cpu().tolist())

| print("\n" + "="*60) |
| print("INFERENCE TEST") |
| print("="*60) |
|
|
| test_cases = [ |
| "123+456", |
| "999-123", |
| "12*34", |
| "100/5", |
| "500+500", |
| "77*8", |
| ] |
|
|
| for expr in test_cases: |
| pred = predict(model, expr) |
| try: |
| true = str(eval(expr.replace("/", "//"))) |
| except: |
| true = "?" |
| status = "✅" if pred == true else "❌" |
| print(f" {status} {expr} = {pred} (correct: {true})") |