# MinGPT/projects/chargpt/chargpt.py
# Uploaded by Katiyar48 using huggingface_hub (commit 4673b21, verified); 4.08 kB.
"""
Trains a character-level language model.
"""
import os
import sys
import torch
from torch.utils.data import Dataset
from torch.utils.data.dataloader import DataLoader
from mingpt.model import GPT
from mingpt.trainer import Trainer
from mingpt.utils import set_seed, setup_logging, CfgNode as CN
# -----------------------------------------------------------------------------
def get_config():
    """
    Assemble the default configuration tree for this script.

    The returned node has four sub-nodes: `system` (seed and work_dir),
    `data` (CharDataset defaults), `model` (GPT defaults, with model_type
    set to 'gpt-mini') and `trainer` (Trainer defaults, with a raised
    learning rate).
    """
    cfg = CN()

    # system
    system_cfg = CN()
    cfg.system = system_cfg
    system_cfg.seed = 3407
    system_cfg.work_dir = './out/chargpt'

    # data
    data_cfg = CharDataset.get_default_config()
    cfg.data = data_cfg

    # model
    model_cfg = GPT.get_default_config()
    cfg.model = model_cfg
    model_cfg.model_type = 'gpt-mini'

    # trainer
    trainer_cfg = Trainer.get_default_config()
    cfg.trainer = trainer_cfg
    # the model we're using is so small that we can go a bit faster
    trainer_cfg.learning_rate = 5e-4

    return cfg
# -----------------------------------------------------------------------------
class CharDataset(Dataset):
    """
    Emits (input, target) pairs of character ids taken from one text.

    The vocabulary is the sorted set of unique characters in `data`. Item
    `i` is the window of `block_size + 1` characters starting at offset
    `i`; the input is the first `block_size` ids and the target is the
    same window shifted one position to the right.
    """

    @staticmethod
    def get_default_config():
        """Return the default dataset config (block_size = 128)."""
        C = CN()
        C.block_size = 128
        return C

    def __init__(self, config, data):
        """
        Args:
            config: object exposing an integer `block_size` attribute.
            data: the full text, as a sequence of characters (e.g. a str).
        """
        self.config = config

        chars = sorted(set(data))
        data_size, vocab_size = len(data), len(chars)
        print('data has %d characters, %d unique.' % (data_size, vocab_size))

        # character <-> integer id lookup tables
        self.stoi = {ch: i for i, ch in enumerate(chars)}
        self.itos = dict(enumerate(chars))
        self.vocab_size = vocab_size
        self.data = data

    def get_vocab_size(self):
        """Return the number of unique characters in the text."""
        return self.vocab_size

    def get_block_size(self):
        """Return the configured context length."""
        return self.config.block_size

    def __len__(self):
        """
        Return the number of full (block_size + 1)-character windows.

        Clamped at zero: a text shorter than block_size would otherwise
        yield a negative value, which makes len() raise a ValueError.
        """
        return max(0, len(self.data) - self.config.block_size)

    def __getitem__(self, idx):
        """
        Return the (x, y) pair of long tensors for window `idx`.

        Negative indices count from the end, as for built-in sequences.

        Raises:
            IndexError: if `idx` is outside [-len(self), len(self)). Besides
                rejecting truncated windows, this is what terminates plain
                iteration (`for x, y in dataset`), since the class defines
                no __iter__ and iteration falls back to __getitem__.
        """
        n = len(self)
        if idx < 0:
            # rebinding (not +=) so a tensor index passed by the caller is
            # never mutated in place
            idx = idx + n
        if not 0 <= idx < n:
            raise IndexError('CharDataset index out of range')

        # grab a chunk of (block_size + 1) characters from the data
        chunk = self.data[idx:idx + self.config.block_size + 1]
        # encode every character to an integer
        dix = [self.stoi[s] for s in chunk]
        # return as tensors: y is x shifted by one character
        x = torch.tensor(dix[:-1], dtype=torch.long)
        y = torch.tensor(dix[1:], dtype=torch.long)
        return x, y
# -----------------------------------------------------------------------------
if __name__ == '__main__':
    # Script entry point: build the config, dataset, model and trainer, then
    # train, periodically printing a sample and saving a checkpoint.

    # get default config and overrides from the command line, if any
    config = get_config()
    config.merge_from_args(sys.argv[1:])
    print(config)
    setup_logging(config)
    set_seed(config.system.seed)

    # construct the training dataset; the context manager closes the file
    # handle as soon as the text has been read
    with open('input.txt', 'r') as f:
        text = f.read()
    train_dataset = CharDataset(config.data, text)

    # construct the model, sized to the dataset's vocabulary and context
    config.model.vocab_size = train_dataset.get_vocab_size()
    config.model.block_size = train_dataset.get_block_size()
    model = GPT(config.model)

    # construct the trainer object
    trainer = Trainer(config.trainer, model, train_dataset)

    # iteration callback
    def batch_end_callback(trainer):
        """Log the loss every 10 iterations; sample and checkpoint every 500."""

        if trainer.iter_num % 10 == 0:
            print(f"iter_dt {trainer.iter_dt * 1000:.2f}ms; iter {trainer.iter_num}: train loss {trainer.loss.item():.5f}")

        if trainer.iter_num % 500 == 0:
            # switch to eval mode to sample a completion from the model
            model.eval()
            with torch.no_grad():
                # sample from the model...
                # NOTE(review): every character of the prompt must occur in
                # input.txt, otherwise the stoi lookup raises KeyError
                context = "O God, O God!"
                x = torch.tensor([train_dataset.stoi[s] for s in context], dtype=torch.long)[None,...].to(trainer.device)
                y = model.generate(x, 500, temperature=1.0, do_sample=True, top_k=10)[0]
                completion = ''.join([train_dataset.itos[int(i)] for i in y])
                print(completion)
            # save the latest model
            print("saving model")
            # make sure the target directory exists; a no-op when it was
            # already created earlier
            os.makedirs(config.system.work_dir, exist_ok=True)
            ckpt_path = os.path.join(config.system.work_dir, "model.pt")
            torch.save(model.state_dict(), ckpt_path)
            # revert model to training mode
            model.train()

    trainer.set_callback('on_batch_end', batch_end_callback)

    # run the optimization
    trainer.run()