import argparse
import time

import torch
from torch.nn import functional as F

from attention_head import MultiHeadAttention, TransFormerBlock

torch.manual_seed(1337)


def get_batch(batch_size, dataset, block_size):
    # sample random starting offsets, then stack block_size-long slices into a batch;
    # targets are the same slices shifted one character to the right
    starts = torch.randint(high=len(dataset) - (block_size + 1), size=(batch_size,))
    xb = torch.stack([dataset[i:i + block_size] for i in starts.tolist()])
    yb = torch.stack([dataset[i + 1:i + block_size + 1] for i in starts.tolist()])
    return xb, yb


@torch.no_grad()
def evaluate(model, batch_size, block_size, dataset):
    # loss on a single randomly drawn batch from the given dataset
    xb, yb = get_batch(batch_size, dataset, block_size)
    logits, loss = model(xb, yb)
    return loss.item()


def train(model, optimizer, batch_size, block_size, train_ds, val_ds, steps):
    sumloss = 0
    for step in range(1, steps + 1):
        xb, yb = get_batch(batch_size, train_ds, block_size)
        logits, loss = model(xb, yb)
        sumloss += loss.item()
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
        if step % 1000 == 0:
            val_loss = evaluate(model, 30, block_size, val_ds)
            print(f"step {step} || train loss: {sumloss / 1000:.4f}, val loss: {val_loss:.4f}")
            sumloss = 0


class Transformer(torch.nn.Module):
    def __init__(self, vocab_size, n_tf=3, block_size=8, token_embed_dim=16,
                 n_heads=4, head_size=4) -> None:
        super().__init__()
        self.block_size = block_size
        self.token_embedding_table = torch.nn.Embedding(vocab_size, token_embed_dim)
        self.positional_embedding = torch.nn.Embedding(block_size, token_embed_dim)
        self.tf_blocks = torch.nn.Sequential(
            *[TransFormerBlock(token_embed_dim, block_size, head_size, n_heads) for _ in range(n_tf)]
        )
        self.lm_head = torch.nn.Linear(token_embed_dim, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        token_embed = self.token_embedding_table(idx)                                         # (B, T, embed_dim)
        positional_embed = self.positional_embedding(torch.arange(T, device=idx.device))      # (T, embed_dim)
        x = token_embed + positional_embed                                                    # (B, T, embed_dim)
        x = self.tf_blocks(x)                                                                 # (B, T, embed_dim)
        logits = self.lm_head(x)                                                              # (B, T, vocab_size)
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B * T, C)
            targets = targets.view(B * T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx is a (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # get the predictions for the last block_size tokens
            logits, loss = self(idx[:, -self.block_size:])
            # focus only on the last time step
            logits = logits[:, -1, :]  # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1)  # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1)
        return idx


class BigramLanguageModel(torch.nn.Module):
    def __init__(self, vocab_size, block_size=8, token_embed_dim=16):
        super().__init__()
        self.token_embedding_table = torch.nn.Embedding(vocab_size, token_embed_dim)
        self.positional_embedding = torch.nn.Embedding(block_size, token_embed_dim)
        # head outputs are concatenated, so head_size * n_heads must equal
        # token_embed_dim for the result to feed lm_head below
        self.attention_head = MultiHeadAttention(n_embed=token_embed_dim,
                                                 timesteps=block_size,
                                                 head_size=token_embed_dim // 4,
                                                 n_heads=4)  # (B, T, C) in, (B, T, C) out
        self.lm_head = torch.nn.Linear(token_embed_dim, vocab_size)  # linear over the channel dim
        self.block_size = block_size

    def forward(self, idx, targets=None):
        B, T = idx.shape
        # idx and targets are both (B, T) tensors of integers
        token_embedding = self.token_embedding_table(idx)                                       # (B, T, embed_dim)
        positional_embedding = self.positional_embedding(torch.arange(T, device=idx.device))    # (T, embed_dim)
        x = token_embedding + positional_embedding  # (B, T, embed_dim)
        x = self.attention_head(x)                  # (B, T, embed_dim)
        logits = self.lm_head(x)                    # (B, T, vocab_size)
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B * T, C)
            targets = targets.view(B * T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx is a (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # get the predictions for the last block_size tokens
            logits, loss = self(idx[:, -self.block_size:])
            # focus only on the last time step
            logits = logits[:, -1, :]  # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1)  # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1)
        return idx


def main():
    ########################
    # PARAMS ###############
    batch_size = 32
    block_size = 128
    n_embed = 128
    n_tf = 3
    n_heads = 8
    head_size = 16  # n_embed // n_heads
    vocab_size = 65  # number of distinct characters in input.txt (65 for tiny-Shakespeare)
    ########################
    parser = argparse.ArgumentParser(
        description='Train a character-level transformer language model'
    )
    parser.add_argument('-c', '--cont', action='store_true')
    parser.add_argument('-e', '--eval', action='store_true')
    parser.add_argument('-v', '--verbose', action='store_true')
    args = parser.parse_args()

    text = open('input.txt').read()
    characters = sorted(set(text))
    decoder = dict(enumerate(characters))
    encoder = {v: k for k, v in decoder.items()}
    encode = lambda c: encoder[c]  # single character -> integer
    decode = lambda i: decoder[i]  # integer -> single character

    text_tensor = torch.tensor([encode(c) for c in text])
    train_tensor = text_tensor[:int(len(text_tensor) * 0.8)]
    val_tensor = text_tensor[int(len(text_tensor) * 0.8):]

    model = Transformer(vocab_size=vocab_size, n_tf=n_tf, block_size=block_size,
                        token_embed_dim=n_embed, n_heads=n_heads, head_size=head_size)
    if args.verbose:
        print(model)
        num_params: int = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print('parameters:', num_params)

    # if -c is passed, load the model from the checkpoint file
    if args.cont:
        state_dict = torch.load('transformer.pth')
        model.load_state_dict(state_dict)

    optimizer = torch.optim.Adam(model.parameters(), lr=3e-5)
    s = time.time()
    if not args.eval:
        try:
            train(model, optimizer, batch_size=batch_size, block_size=block_size,
                  train_ds=train_tensor, val_ds=val_tensor, steps=100000)
        except KeyboardInterrupt:
            # save a checkpoint on Ctrl-C so training can be resumed later with -c
            torch.save(model.state_dict(), 'transformer.pth')
            exit()
        if args.verbose:
            print('training time: ', time.time() - s)
        torch.save(model.state_dict(), 'transformer.pth')

    model.eval()
    # seed generation with 32 copies of token 0, then drop that prompt before decoding
    print(''.join([decode(c) for c in model.generate(torch.zeros(1, 32, dtype=torch.long), 1000)[0].tolist()[32:]]))
    # ~2.57 loss with Adam


if __name__ == '__main__':
    main()
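
# ----------------------------------------------------------------------------
# Example usage, assuming this file is saved as train.py (the filename is an
# assumption) and input.txt holds the training corpus (e.g. tiny-Shakespeare).
# The flags defined in main() combine as:
#
#   python train.py -v        # train from scratch, printing the model and parameter count
#   python train.py -c        # resume training from transformer.pth
#   python train.py -c -e     # skip training and only sample from a saved checkpoint
#
# Ctrl-C during training writes the current weights to transformer.pth, so an
# interrupted run can be resumed with -c.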