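# Train a small character-level Transformer language model on the text in
# input.txt, log train/validation loss every 1000 steps, save the weights to
# transformer.pth, and print a sample of generated text at the end.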
import torch
import argparse
from torch.nn import functional as F
import time
from attention_head import MultiHeadAttention, TransFormerBlock
torch.manual_seed(1337)
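# fixed seed so batch sampling and weight initialization are reproducible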
def get_batch(batch_size, dataset, block_size):
    # pick random starting offsets, then stack block_size-long input/target windows
    sample = torch.randint(high=len(dataset) - (block_size + 1), size=(batch_size,))
    xb = torch.stack([dataset[i:i + block_size] for i in sample])
    yb = torch.stack([dataset[i + 1:i + block_size + 1] for i in sample])
    return xb, yb
@torch.no_grad()
def eval(model, batch_size, block_size, dataset):
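    # estimate the loss on a single randomly sampled batch, with gradients disabled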
xb, yb = get_batch(batch_size, dataset, block_size)
logits, loss = model(xb, yb)
return loss.item()
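# Training loop: sample a random batch each step, backprop the loss, and every
# 1000 steps report the averaged train loss alongside a single-batch validation loss.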
def train(model, optimizer, batch_size, block_size, train_ds, val_ds, steps):
    sumloss = 0
    for step in range(1, steps + 1):
        xb, yb = get_batch(batch_size, train_ds, block_size)
        logits, loss = model(xb, yb)
        sumloss += loss.item()
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
        if step % 1000 == 0:
            val_loss = eval(model, 30, block_size, val_ds)
            print(f"step {step} || train loss: {sumloss / 1000:.4f}, val loss: {val_loss:.4f}")
            sumloss = 0
class Transformer(torch.nn.Module):
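    # Decoder-only transformer: token + positional embeddings feed a stack of
    # TransFormerBlocks, then a linear head maps back to vocabulary logits.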
    def __init__(self, vocab_size, n_tf=3, block_size=8, token_embed_dim=16) -> None:
        super().__init__()
        self.block_size = block_size
        self.token_embedding_table = torch.nn.Embedding(vocab_size, token_embed_dim)
        self.positional_embedding = torch.nn.Embedding(block_size, token_embed_dim)
        self.tf_blocks = torch.nn.Sequential(
            # each block uses head_size=16 and n_heads=8, so token_embed_dim should be 16 * 8 = 128
            *[TransFormerBlock(token_embed_dim, block_size, 16, 8) for _ in range(n_tf)]
        )
        self.lm_head = torch.nn.Linear(token_embed_dim, vocab_size)
    def forward(self, idx, targets=None):
        B, T = idx.shape
        token_embed = self.token_embedding_table(idx)                  # (B, T, C)
        positional_embed = self.positional_embedding(torch.arange(T))  # (T, C)
        x = token_embed + positional_embed                             # (B, T, C)
        x = self.tf_blocks(x)                                          # (B, T, C)
        logits = self.lm_head(x)                                       # (B, T, vocab_size)
if targets is None:
loss = None
else:
B, T, C = logits.shape
logits = logits.view(B*T, C)
targets = targets.view(B*T)
loss = F.cross_entropy(logits, targets)
return logits, loss
def generate(self, idx, max_new_tokens):
# idx is (B, T) array of indices in the current context
for _ in range(max_new_tokens):
# get the predictions
logits, loss = self(idx[:, -self.block_size:])
# focus only on the last time step
logits = logits[:, -1, :] # becomes (B, C)
# apply softmax to get probabilities
probs = F.softmax(logits, dim=-1) # (B, C)
# sample from the distribution
idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
# append sampled index to the running sequence
idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
return idx
class BigramLanguageModel(torch.nn.Module):
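    # Simpler baseline model: a single multi-head self-attention layer followed
    # by a linear head, without the stacked TransFormerBlocks used above.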
    def __init__(self, vocab_size, block_size=8, token_embed_dim=16):
        super().__init__()
        self.token_embedding_table = torch.nn.Embedding(vocab_size, token_embed_dim)
        self.positional_embedding = torch.nn.Embedding(block_size, token_embed_dim)
        self.attention_head = MultiHeadAttention(n_embed=token_embed_dim,
                                                 timesteps=block_size,
                                                 # head_size = token_embed_dim // n_heads, so the concatenated
                                                 # head outputs come back to width token_embed_dim
                                                 head_size=token_embed_dim // 4,
                                                 n_heads=4)  # (B, T, C) in, (B, T, C) out
        self.lm_head = torch.nn.Linear(token_embed_dim, vocab_size)  # linear over C: (B, T, C) -> (B, T, vocab_size)
self.block_size = block_size
def forward(self, idx, targets=None):
B, T = idx.shape
# idx and targets are both (B,T) tensor of integers
        token_embedding = self.token_embedding_table(idx)  # (B, T) -> (B, T, embed_dim)
        positional_embedding = self.positional_embedding(torch.arange(T, dtype=torch.long))  # (T, embed_dim)
x = token_embedding + positional_embedding # (B,T,embed_dim)
x = self.attention_head(x) # (B,T,embed_dim)
logits = self.lm_head(x)
if targets is None:
loss = None
else:
B, T, C = logits.shape
logits = logits.view(B*T, C)
targets = targets.view(B*T)
loss = F.cross_entropy(logits, targets)
return logits, loss
def generate(self, idx, max_new_tokens):
# idx is (B, T) array of indices in the current context
for _ in range(max_new_tokens):
# get the predictions
logits, loss = self(idx[:, -self.block_size:])
# focus only on the last time step
logits = logits[:, -1, :] # becomes (B, C)
# apply softmax to get probabilities
probs = F.softmax(logits, dim=-1) # (B, C)
# sample from the distribution
idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
# append sampled index to the running sequence
idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
return idx
def main():
########################
#PARAMS#################
batch_size = 32
block_size= 128
n_embed = 128
n_tf = 3
n_heads=8
head_size=16
    vocab_size = 65  # must equal the number of unique characters in input.txt
########################
    parser = argparse.ArgumentParser(
        description='Train a character-level transformer language model'
    )
    parser.add_argument('-c', '--cont', action='store_true', help='continue from saved weights')
    parser.add_argument('-e', '--eval', action='store_true', help='skip training')
    parser.add_argument('-v', '--verbose', action='store_true', help='print model summary and timing')
    args = parser.parse_args()
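    # usage (assuming this script is saved as train.py):
    #   python train.py              train from scratch
    #   python train.py -c           continue training from transformer.pth
    #   python train.py -c -e -v     load saved weights and only sample text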
text = open('input.txt').read()
characters = sorted(list(set(text)))
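    # character-level tokenizer: map each unique character to an integer id and back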
decoder = dict(enumerate(characters))
encoder = {v: k for k, v in decoder.items()}
encode = lambda x: encoder[x]
decode = lambda x: decoder[x]
text_tensor = torch.tensor([encode(c) for c in text])
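    # 80/20 train/validation split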
train_tensor = text_tensor[:int(len(text_tensor) * 0.8)]
val_tensor = text_tensor[int(len(text_tensor) * 0.8):]
    model = Transformer(vocab_size=vocab_size, n_tf=n_tf, block_size=block_size, token_embed_dim=n_embed)
    if args.verbose:
print(model)
num_params: int = sum(p.numel() for p in model.parameters() if p.requires_grad)
print('parameters:', num_params)
# if -c is passed we will load the model from the file
    if args.cont:
state_dict = torch.load('transformer.pth')
model.load_state_dict(state_dict)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-5)
s = time.time()
    if not args.eval:
try:
            train(model, optimizer, batch_size=batch_size, block_size=block_size, train_ds=train_tensor, val_ds=val_tensor, steps=100000)
except KeyboardInterrupt:
torch.save(model.state_dict(), 'transformer.pth')
exit()
    if args.verbose:
print('training time: ', time.time() - s)
torch.save(model.state_dict(), 'transformer.pth')
model.eval()
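    # seed generation with a context of 32 zero tokens, then decode and print 1000 new characters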
    context = torch.zeros(1, 32, dtype=torch.long)
    generated = model.generate(context, 1000)[0].tolist()[32:]
    print(''.join(decode(c) for c in generated))
# 2.57 adam
if __name__ == '__main__':
main()