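"""Character-level GPT in PyTorch: a small CLI for training, evaluation, and
text generation, with an optional Gradio demo UI (--gui)."""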
import argparse
import datetime
import json
import math
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
import gradio as gr
# Define the command-line arguments
parser = argparse.ArgumentParser(description='GPT CLI')
parser.add_argument('--gui', action='store_true', help='Enable Gradio UI mode')
parser.add_argument('--config', default='./gpt_config.json',
help='Path to the config file')
subparsers = parser.add_subparsers(dest='command', help='Choose a command')
# Define the training command
train_parser = subparsers.add_parser('train', help='Train the model')
train_parser.add_argument('--load-from-restore', action='store_true',
help='Load from restore path instead of training from scratch')
# Define the evaluation command
eval_parser = subparsers.add_parser('eval', help='Evaluate the model')
eval_parser.add_argument('--data', default='./data/evaluation_data.txt',
help='Path to the evaluation data file')
# Define the inference command
infer_parser = subparsers.add_parser(
'infer', help='Generate text from the model')
infer_parser.add_argument('--text', type=str, required=True,
help='Input text for generating continuation')
infer_parser.add_argument('--length', type=int,
default=100, help='Number of characters to generate')
torch.manual_seed(1337)
# Set device to CUDA if available, otherwise use CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class GPTConfig:
def __init__(self, config_file_path):
with open(config_file_path, 'r') as f:
config = json.load(f)
# Architecture configuration
architecture_config = config['architecture']
self.embedding_dim = architecture_config['embedding_dim']
self.vocab_size = architecture_config['vocab_size']
self.context_size = architecture_config['context_size']
self.num_heads = architecture_config['num_heads']
self.num_layers = architecture_config['num_layers']
# Training configuration
training_config = config['training']
self.batch_size = training_config['batch_size']
self.training_data_path = training_config['training_data_path']
self.save_folder = training_config['save_folder']
self.learning_rate = training_config['learning_rate']
self.num_steps = training_config['num_steps']
self.val_interval = training_config['val_interval']
# Generation hyperparameters
generation_config = config['generation']
self.top_k = generation_config['top_k']
self.top_p = generation_config['top_p']
self.temp = generation_config['temp']
# Checkpoint restore configuration
self.restore_path = config['restore_path']
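# Illustrative gpt_config.json layout (keys match the reads above; values are placeholders):
# {
#     "architecture": {"embedding_dim": 128, "vocab_size": 128,
#                      "context_size": 256, "num_heads": 4, "num_layers": 4},
#     "training": {"batch_size": 64, "training_data_path": "./data/input.txt",
#                  "save_folder": "./checkpoints", "learning_rate": 0.0003,
#                  "num_steps": 5000, "val_interval": 100},
#     "generation": {"top_k": 10, "top_p": 0.9, "temp": 0.8},
#     "restore_path": "./checkpoints/model.pt"
# }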
def encode_text(text):
# Simple dumb ASCII character-level "encoding" since all training data is ASCII.
# Consider better tokenization if moving off character-level
return ([ord(t) for t in text])
def decode_text(indices):
return ([chr(x) for x in indices])
class TextDataset(Dataset):
def __init__(self, data_tensor, context_size):
self.data_tensor = data_tensor
self.context_size = context_size
def __len__(self):
return len(self.data_tensor) - self.context_size
def __getitem__(self, index):
x = self.data_tensor[index:index + self.context_size]
y = self.data_tensor[index + 1:index + self.context_size + 1]
return x, y
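# Each TextDataset item is a next-character prediction pair: for index i,
# x = data[i:i+context_size] and y = data[i+1:i+context_size+1]
# (e.g. with context_size=4 and the text "hello!", index 0 pairs the codes
# for "hell" with the codes for "ello").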
def load_dataset(data_path, context_size):
with open(data_path, 'r', encoding='utf-8') as f:
text = f.read()
# Tensorify data, put it in dataset
data = torch.tensor(encode_text(text), dtype=torch.int32)
test_split_idx = int(0.8 * len(data))
val_split_idx = int(0.9 * len(data))
train_data = data[:test_split_idx]
test_data = data[test_split_idx:val_split_idx]
val_data = data[val_split_idx:]
# print(f"{len(data)} chars of data")
train_dataset = TextDataset(train_data, context_size)
test_dataset = TextDataset(test_data, context_size)
    val_dataset = TextDataset(val_data, context_size)
return ((train_dataset, val_dataset, test_dataset))
class MultiheadAttention(nn.Module):
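    """Multi-head scaled dot-product attention.

    Returns an (output, attn_weights) tuple to match nn.MultiheadAttention's
    interface; attention weights are not tracked, so the second element is
    always None.
    """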
def __init__(self, embed_dim, num_heads, dropout=0.0, bias=True, device=None, dtype=None):
super(MultiheadAttention, self).__init__()
# Save variables
self.embed_dim = embed_dim
self.num_heads = num_heads
self.d_k = embed_dim // num_heads
self.Q = nn.Linear(embed_dim, embed_dim, bias=False)
self.K = nn.Linear(embed_dim, embed_dim, bias=False)
self.V = nn.Linear(embed_dim, embed_dim, bias=False)
self.dropout = nn.Dropout(dropout)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)
def forward(self, query, key, value, attn_mask=None):
batch_size = query.size(0)
# Apply linear layers
q = self.Q(query) # [B, C, E]
k = self.K(key) # [B, C, E]
v = self.V(value) # [B, C, E]
        # Split into heads so the attention matmul contracts over the per-head dimension d_k
# [batch_size, num_heads, C, d_k]
q = q.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
# [batch_size, num_heads, C, d_k]
k = k.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
# [batch_size, num_heads, C, d_k]
v = v.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
# Get raw attention scores
scores = torch.matmul(q, k.transpose(-2, -1)) / \
math.sqrt(self.d_k) # [B, num_heads, C, C]
# Apply mask, if necessary
if attn_mask is not None:
scores = scores.masked_fill(attn_mask, float('-inf'))
        # Softmax over keys to get attention weights (scaling by sqrt(d_k) was applied above)
        attn = F.softmax(scores, dim=-1)
attn = self.dropout(attn)
out = attn @ v # [B, num_heads, C, d_k]
# Concat and project
# Swap C and num_heads, force memory to coalesce, then fuse back num_heads and d_k together
out = out.transpose(1, 2).contiguous().view(
batch_size, -1, self.embed_dim)
        # Output projection: mix information across heads back into embed_dim
        out = self.out_proj(out)
        # (output, attn_weights) to mirror nn.MultiheadAttention
        return out, None
class FeedForward(nn.Module):
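    """Position-wise feed-forward network with the standard 4x hidden expansion and GELU."""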
def __init__(self, embed_dim, dropout):
super().__init__()
self.net = nn.Sequential(
nn.Linear(embed_dim, 4 * embed_dim),
nn.GELU(),
nn.Linear(4 * embed_dim, embed_dim),
nn.Dropout(dropout),
)
def forward(self, x):
return (self.net(x))
class Block(nn.Module):
    """Transformer block: pre-norm causal self-attention followed by a feed-forward layer, each with a residual connection."""
def __init__(self, embed_dim, num_heads, mask, dropout=0.2):
super(Block, self).__init__()
self.register_buffer("mask", mask)
self.head = MultiheadAttention(
embed_dim=embed_dim, num_heads=num_heads, dropout=dropout)
# self.head = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, dropout=dropout, batch_first=True)
self.ffwd = FeedForward(embed_dim=embed_dim, dropout=dropout)
self.ln1 = nn.LayerNorm(embed_dim)
self.ln2 = nn.LayerNorm(embed_dim)
def forward(self, x):
# Residual connections
x = self.ln1(x)
attn_output, _ = self.head(x, x, x, attn_mask=self.mask)
x = x + attn_output
out = x + self.ffwd(self.ln2(x))
return out
class GPT(nn.Module):
    def __init__(self, embedding_dim, vocab_size, context_size, num_heads=4, num_layers=4):
super(GPT, self).__init__()
self.embedding_dim = embedding_dim
self.output_dim = vocab_size
self.context_size = context_size
# Initialize layers
self.tok_embed = nn.Embedding(vocab_size, embedding_dim)
self.pos_embed = nn.Embedding(context_size, embedding_dim)
mask = torch.tril(torch.ones(
self.context_size, self.context_size)).bool()
mask = ~mask
self.register_buffer("mask", mask)
self.blocks = nn.Sequential(
            *[Block(embed_dim=embedding_dim, num_heads=num_heads, mask=mask, dropout=0.2) for _ in range(num_layers)]
)
self.ln_f = nn.LayerNorm(self.embedding_dim)
# Final feed-forward layer from embeddings
self.ffwd = nn.Linear(
embedding_dim, out_features=vocab_size, bias=False)
def forward(self, x):
        tok_embed = self.tok_embed(x)
        # Inputs are expected to be exactly context_size tokens long
        # (generate() left-pads shorter prompts to that length).
        pos_embed = self.pos_embed(
            torch.arange(0, self.context_size, device=x.device)
        )
x = tok_embed + pos_embed
x = self.blocks(x)
x = self.ln_f(x)
logits = self.ffwd(x)
return (logits)
def infer(self, x):
with torch.no_grad():
self.eval()
res = self.forward(x)
return (res)
def num_params(self):
return sum(p.numel() for p in self.parameters() if p.requires_grad)
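# Illustrative shape check (sizes are placeholders, not from a config file):
#   model = GPT(embedding_dim=128, vocab_size=128, context_size=256)
#   logits = model(torch.randint(0, 128, (2, 256)))  # -> [2, 256, 128]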
def load_checkpoint(model, optimizer, path, device=device):
"""
Loads a saved checkpoint file into the model and optimizer.
Args:
model (nn.Module): The PyTorch model to load the checkpoint into.
optimizer (torch.optim.Optimizer): The PyTorch optimizer to load the checkpoint into.
        path (str): The path to the saved checkpoint file.
        device (torch.device): Device to map the checkpoint tensors onto (defaults to the module-level device).
Returns:
Tuple[nn.Module, torch.optim.Optimizer, int]: The model and optimizer, loaded with the checkpoint state.
"""
checkpoint = torch.load(path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
if optimizer is not None:
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
return (model, optimizer, checkpoint['steps'])
def save_checkpoint(model, optimizer, path, steps):
"""
Saves a checkpoint of the model and optimizer to disk.
Args:
model (nn.Module): The PyTorch model to save the checkpoint of.
optimizer (torch.optim.Optimizer): The PyTorch optimizer to save the checkpoint of.
path (str): The path to save the checkpoint file.
steps (int): The number of training steps that have been completed.
Returns:
None
"""
torch.save({
'steps': steps,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
}, path)
def compute_loss(model, criterion, x, y):
    logits = model(x)
    B, C, V = logits.shape
    # Flatten batch and context dims: cross-entropy expects [N, V] logits and [N] targets
    logits = logits.view(B*C, V)
    y = y.view(B*C)
    loss = criterion(logits, y.long())
    return loss
def print_model_devices(model):
print("Model Parameters:")
for name, param in model.named_parameters():
print(f"{name}: {param.device}")
print("\nModel Buffers:")
for name, buffer in model.named_buffers():
print(f"{name}: {buffer.device}")
def train(model, optimizer, config: GPTConfig, global_step):
model = model.to(device)
criterion = F.cross_entropy
train_dataset, val_dataset, _ = load_dataset(
config.training_data_path, model.context_size)
train_dataloader = DataLoader(
train_dataset,
batch_size=config.batch_size,
shuffle=True,
num_workers=4
)
val_dataloader = DataLoader(
val_dataset, batch_size=512, num_workers=4, shuffle=True)
model.train()
EPOCHS = 1
STEPS = config.num_steps
    VAL_INTERVAL = config.val_interval
writer = SummaryWriter()
step = 0
for epoch in range(EPOCHS):
for data, target in train_dataloader:
data = data.to(device)
target = target.to(device)
loss = compute_loss(model, criterion, data, target)
# Backward pass
optimizer.zero_grad()
loss.backward()
optimizer.step()
writer.add_scalar(
"Loss/train", loss.cpu().detach().numpy(), global_step)
global_step += 1
if step % VAL_INTERVAL == 0:
total_loss = 0
total_samples = 0
with torch.no_grad():
model.eval()
for x, y in val_dataloader:
x = x.to(device)
y = y.to(device)
batch_loss = compute_loss(model, criterion, x, y)
                        total_loss += batch_loss.item() * x.size(0)
                        total_samples += x.size(0)
                        # Only a rough estimate over a small validation subsample is needed here
                        if total_samples > 10:
break
model.train()
average_loss = total_loss / total_samples
print(f"Step {step}; loss: {average_loss}")
writer.add_scalar("Loss/val", average_loss, global_step)
step += 1
if step >= STEPS:
break
writer.close()
def evaluate_model(model, val_dataset, batch_size=512, max_samples=100000):
model.eval()
total_loss = 0.0
total_samples = 0
criterion = F.cross_entropy
    val_dataloader = DataLoader(
        val_dataset, batch_size=batch_size, num_workers=4)
with torch.no_grad():
for inputs, targets in val_dataloader:
inputs = inputs.to(device)
targets = targets.to(device)
batch_loss = compute_loss(model, criterion, inputs, targets)
total_loss += batch_loss.item() * inputs.size(0)
total_samples += inputs.size(0)
if total_samples > max_samples:
break
average_loss = total_loss / total_samples
return average_loss
def generate(model, config, prompt, gen_length, temp=1, top_k=10, top_p=None, device=device):
g_cuda = torch.Generator(device=device)
contexts = torch.tensor(encode_text(prompt), dtype=torch.int32).to(device)
model.eval()
for i in range(gen_length):
x = contexts[-config.context_size:]
if x.size(0) < config.context_size:
x = F.pad(x, (config.context_size - x.size(0), 0),
"constant", 0).unsqueeze(0) # B*T
else:
x = x.unsqueeze(0)
preds = model.infer(x)
preds = preds.squeeze(0)
preds = preds / temp
probs = F.softmax(preds, dim=-1)
if top_p is not None:
# Apply top-p
sorted_probs, sorted_indices = torch.sort(
probs[-1, :], descending=True)
cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
# find cutoff
idx_top_p = (cumulative_probs < top_p).sum().item()
top_probs = sorted_probs[:idx_top_p]
top_indices = sorted_indices[:idx_top_p]
# Null case
if top_probs.size(0) == 0:
top_probs = sorted_probs[:1]
top_indices = sorted_indices[:1]
next_char = torch.multinomial(
top_probs, num_samples=1, generator=g_cuda)
next_char = top_indices[next_char]
elif top_k is not None:
top_k_probs, top_k_indices = torch.topk(probs[-1, :], k=top_k)
next_char = torch.multinomial(
top_k_probs, num_samples=1, generator=g_cuda)
next_char = top_k_indices[next_char]
        else:
            # Sample from the full distribution at the last position
            next_char = torch.multinomial(
                probs[-1, :], num_samples=1, generator=g_cuda)
contexts = torch.cat((contexts, next_char), dim=0)
print(decode_text(next_char.cpu().numpy())[-1], end="")
return ("".join(decode_text(contexts.cpu().numpy())))
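# Example call (illustrative prompt and settings):
#   generate(model, config, "ROMEO:", 100, temp=0.8, top_k=10, top_p=None)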
def main():
# Parse the command-line arguments
args = parser.parse_args()
# args_is_empty = all(value is None for value in vars(args).values())
config = GPTConfig(args.config)
# Create the GPT model
    model = GPT(
        vocab_size=config.vocab_size,
        context_size=config.context_size,
        embedding_dim=config.embedding_dim,
        num_heads=config.num_heads,
        num_layers=config.num_layers
    )
model.to(device)
optimizer = optim.AdamW(model.parameters(), lr=config.learning_rate)
if args.gui or args.command is None:
load_checkpoint(model, optimizer, config.restore_path)
demo = gr.Interface(
            fn=lambda *inputs: generate(model, config, *inputs),
            inputs=[
                gr.Textbox(lines=2, placeholder="Prompt here...", label="Prompt"),
                gr.Number(precision=0, value=256, label="Length (characters)"),
                gr.Number(value=0.8, label="Temperature"),
                gr.Slider(maximum=128, value=10, label="Top-k"),
                gr.Slider(maximum=1, value=1, label="Top-p")
],
outputs="text",
title="Shakespeare-GPT",
description="Putting theater kids out of their nonexistent jobs since 2023"
)
demo.launch()
elif args.command == "train":
if args.load_from_restore:
            _, _, global_steps = load_checkpoint(model, optimizer, config.restore_path)
else:
global_steps = 0
train(model, optimizer, config, global_step=global_steps)
# Persist model
timestamp = datetime.datetime.now().isoformat()
checkpoint_name = f"model_{timestamp}.pt"
save_checkpoint(model, optimizer, path=os.path.join(config.save_folder, checkpoint_name),
steps=global_steps+config.num_steps)
    elif args.command == "eval":
        # Restore weights before evaluating, as the other commands do
        load_checkpoint(model, optimizer, config.restore_path)
        _, _, test_dataset = load_dataset(
            config.training_data_path, model.context_size)
        test_loss = evaluate_model(model, test_dataset)
        print(f"Test loss: {test_loss}")
elif args.command == "infer":
load_checkpoint(model, optimizer, config.restore_path)
prompt = args.text
generated_text = generate(model, config, prompt, args.length,
temp=config.temp, top_k=config.top_k, top_p=config.top_p)
print(generated_text)
if __name__ == "__main__":
main()