"""
Trains a GPT to add n-digit numbers.
"""
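
# Example invocation (illustrative only; assumes the mingpt package is importable
# and uses the --key=value override syntax that CfgNode.merge_from_args parses below):
#
#   python adder.py --data.ndigit=3 --trainer.learning_rate=1e-4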

import os
import sys
import json

import torch
from torch.utils.data import Dataset
from torch.utils.data.dataloader import DataLoader

from mingpt.model import GPT
from mingpt.trainer import Trainer
from mingpt.utils import set_seed, setup_logging, CfgNode as CN


def get_config():

    C = CN()

    # system
    C.system = CN()
    C.system.seed = 3407
    C.system.work_dir = './out/adder'

    # data
    C.data = AdditionDataset.get_default_config()

    # model
    C.model = GPT.get_default_config()
    C.model.model_type = 'gpt-nano'

    # trainer
    C.trainer = Trainer.get_default_config()
    C.trainer.learning_rate = 5e-4 # the model is so small that we can use a slightly higher learning rate

    return C


class AdditionDataset(Dataset):
    """
    Creates n-digit addition problems. For example, if n=2, then an example
    addition problem would be to add 85 + 50 = 135. This problem would be
    represented as the following string for the GPT:

    "8550531"

    This is because:
    - we are discarding the + and =, which are not necessary. We just encode the digits
      of the input numbers concatenated together.
    - the result 135 is encoded backwards as "531" to make the addition easier for the
      GPT model to learn, since the carry propagates from the least significant digit.

    As one more example, the problem 6 + 39 = 45 would be encoded as:

    "0639540"

    where you will notice that we are padding with zeros to make sure that we always
    produce strings of the exact same size: n + n + (n + 1). When n=2, this is 7.
    At test time, we will feed in an addition problem by giving the first 2n digits,
    and hoping that the GPT model completes the sequence with the next (n+1) digits
    correctly.
    """
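
    # Worked example, tracing the encoding through __getitem__ below for ndigit=2
    # and the problem 85 + 50 = 135:
    #
    #   render = "85" + "50" + "531"     # answer digits reversed: "135" -> "531"
    #   x = [8, 5, 5, 0, 5, 3]           # every token except the last
    #   y = [-1, -1, -1, 5, 3, 1]        # next-token targets; -1 masks the loss on
    #                                    # the input digits, leaving only the answer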

    @staticmethod
    def get_default_config():
        C = CN()
        C.ndigit = 2
        return C

    def __init__(self, config, split):
        self.config = config
        self.split = split # train/test

        # split all possible addition problems into either training data or test data
        ndigit = self.config.ndigit
        assert ndigit <= 3, "the lines below would be very memory inefficient, in future maybe refactor to support"
        num = (10**ndigit)**2 # total number of possible addition problems with ndigit numbers
        rng = torch.Generator()
        rng.manual_seed(1337)
        perm = torch.randperm(num, generator=rng)
        num_test = min(int(num*0.2), 500) # 20% of the whole dataset, or up to 500 examples
        self.ixes = perm[:num_test] if split == 'test' else perm[num_test:]

    def get_vocab_size(self):
        return 10 # the digits 0..9

    def get_block_size(self):
        # a, b, a+b, plus 1 for the potential carry overflow,
        # minus 1 because the very last digit never feeds back into the model
        # (there is no explicit <EOS> token to predict, it is implied)
        return 3*self.config.ndigit + 1 - 1

    def __len__(self):
        return self.ixes.nelement()

    def __getitem__(self, idx):
        ndigit = self.config.ndigit

        # given a problem index idx, first recover the associated a + b
        idx = self.ixes[idx].item()
        nd = 10**ndigit
        a = idx // nd
        b = idx % nd

        # the "label" of the addition problem is the sum
        c = a + b

        # encode the digits of a, b, c into strings
        astr = f'%0{ndigit}d' % a
        bstr = f'%0{ndigit}d' % b
        cstr = (f'%0{ndigit+1}d' % c)[::-1] # reverse c to make the addition easier to learn
        render = astr + bstr + cstr
        dix = [int(s) for s in render] # convert each character to its token index

        # x will be the input to the GPT and y will be the associated expected outputs
        x = torch.tensor(dix[:-1], dtype=torch.long)
        y = torch.tensor(dix[1:], dtype=torch.long) # predict the next token in the sequence
        y[:ndigit*2-1] = -1 # only train on the output locations; -1 masks the loss to zero
        return x, y
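
# A minimal sketch of how the dataset can be inspected by hand (not part of the
# training script; the names below are illustrative):
#
#   dataset = AdditionDataset(AdditionDataset.get_default_config(), split='train')
#   x, y = dataset[0]
#   # x holds the 2*ndigit input digits plus all but the last answer digit, and y is
#   # x shifted left by one with the first 2*ndigit-1 targets masked out with -1.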

if __name__ == '__main__':

    # get the default config and apply any overrides from the command line
    config = get_config()
    config.merge_from_args(sys.argv[1:])
    print(config)
    setup_logging(config)
    set_seed(config.system.seed)

    # construct train and test datasets
    train_dataset = AdditionDataset(config.data, split='train')
    test_dataset  = AdditionDataset(config.data, split='test')

    # construct the model
    config.model.vocab_size = train_dataset.get_vocab_size()
    config.model.block_size = train_dataset.get_block_size()
    model = GPT(config.model)

    # construct the trainer object
    trainer = Trainer(config.trainer, model, train_dataset)

    # helper function for evaluating the model on a given split
    def eval_split(trainer, split, max_batches=None):
        dataset = {'train':train_dataset, 'test':test_dataset}[split]
        ndigit = config.data.ndigit
        results = []
        mistakes_printed_already = 0
        factors = torch.tensor([[10**i for i in range(ndigit+1)][::-1]]).to(trainer.device)
        loader = DataLoader(dataset, batch_size=100, num_workers=0, drop_last=False)
        for b, (x, y) in enumerate(loader):
            x = x.to(trainer.device)
            # isolate the first 2*ndigit digits of the input sequence alone
            d1d2 = x[:, :ndigit*2]
            # let the model complete the rest of the sequence with greedy argmax decoding
            d1d2d3 = model.generate(d1d2, ndigit+1, do_sample=False)
            # isolate the last (ndigit+1) digits, i.e. the predicted answer
            d3 = d1d2d3[:, -(ndigit+1):]
            d3 = d3.flip(1) # reverse the digits back to their "normal" order
            # decode the integers from their individual digits
            d1i = (d1d2[:,:ndigit] * factors[:,1:]).sum(1)
            d2i = (d1d2[:,ndigit:ndigit*2] * factors[:,1:]).sum(1)
            d3i_pred = (d3 * factors).sum(1)
            d3i_gt = d1i + d2i # manually calculate the ground truth
            # evaluate the correctness of the results in this batch
            correct = (d3i_pred == d3i_gt).cpu()
            for i in range(x.size(0)):
                results.append(int(correct[i]))
                if not correct[i] and mistakes_printed_already < 5: # only print up to 5 mistakes to get a sense
                    mistakes_printed_already += 1
                    print("GPT claims that %d + %d = %d but gt is %d" % (d1i[i], d2i[i], d3i_pred[i], d3i_gt[i]))
            if max_batches is not None and b+1 >= max_batches:
                break
        rt = torch.tensor(results, dtype=torch.float)
        print("%s final score: %d/%d = %.2f%% correct" % (split, rt.sum(), len(results), 100*rt.mean()))
        return rt.sum()

    # iteration callback
    top_score = 0
    def batch_end_callback(trainer):
        global top_score

        if trainer.iter_num % 10 == 0:
            print(f"iter_dt {trainer.iter_dt * 1000:.2f}ms; iter {trainer.iter_num}: train loss {trainer.loss.item():.5f}")

        if trainer.iter_num % 500 == 0:
            # evaluate both the train and test scores
            train_max_batches = {1: None, 2: None, 3: 5}[config.data.ndigit] # if ndigit<=2 we can afford the whole train set, otherwise subsample
            model.eval()
            with torch.no_grad():
                train_score = eval_split(trainer, 'train', max_batches=train_max_batches)
                test_score  = eval_split(trainer, 'test',  max_batches=None)
            score = train_score + test_score
            # save the model if this is the best score we've seen so far
            if score > top_score:
                top_score = score
                print(f"saving model with new top score of {score}")
                ckpt_path = os.path.join(config.system.work_dir, "model.pt")
                torch.save(model.state_dict(), ckpt_path)
            model.train()

    trainer.set_callback('on_batch_end', batch_end_callback)

    # run the optimization
    trainer.run()