import argparse
import datetime
import json
import math
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter

import gradio as gr

parser = argparse.ArgumentParser(description='GPT CLI')
parser.add_argument('--gui', action='store_true', help='Enable Gradio UI mode')
parser.add_argument('--config', default='./gpt_config.json',
                    help='Path to the config file')
subparsers = parser.add_subparsers(dest='command', help='Choose a command')

train_parser = subparsers.add_parser('train', help='Train the model')
train_parser.add_argument('--load-from-restore', action='store_true',
                          help='Load from restore path instead of training from scratch')

eval_parser = subparsers.add_parser('eval', help='Evaluate the model')
eval_parser.add_argument('--data', default='./data/evaluation_data.txt',
                         help='Path to the evaluation data file')

infer_parser = subparsers.add_parser(
    'infer', help='Generate text from the model')
infer_parser.add_argument('--text', type=str, required=True,
                          help='Input text for generating continuation')
infer_parser.add_argument('--length', type=int,
                          default=100, help='Number of characters to generate')
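
# Example invocations (assuming this script is saved as gpt.py; adjust the
# filename to match your checkout):
#   python gpt.py train
#   python gpt.py train --load-from-restore
#   python gpt.py eval --data ./data/evaluation_data.txt
#   python gpt.py infer --text "To be, or not to be" --length 200
#   python gpt.py --gui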

torch.manual_seed(1337)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class GPTConfig:
    def __init__(self, config_file_path):
        with open(config_file_path, 'r') as f:
            config = json.load(f)

        architecture_config = config['architecture']
        self.embedding_dim = architecture_config['embedding_dim']
        self.vocab_size = architecture_config['vocab_size']
        self.context_size = architecture_config['context_size']
        self.num_heads = architecture_config['num_heads']
        self.num_layers = architecture_config['num_layers']

        training_config = config['training']
        self.batch_size = training_config['batch_size']
        self.training_data_path = training_config['training_data_path']
        self.save_folder = training_config['save_folder']
        self.learning_rate = training_config['learning_rate']
        self.num_steps = training_config['num_steps']
        self.val_interval = training_config['val_interval']

        generation_config = config['generation']
        self.top_k = generation_config['top_k']
        self.top_p = generation_config['top_p']
        self.temp = generation_config['temp']

        self.restore_path = config['restore_path']
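
# A minimal sketch of the gpt_config.json layout GPTConfig expects. The keys
# mirror what is read above; the values are illustrative assumptions, not
# project-supplied defaults:
#
# {
#     "architecture": {
#         "embedding_dim": 256,
#         "vocab_size": 128,
#         "context_size": 256,
#         "num_heads": 4,
#         "num_layers": 4
#     },
#     "training": {
#         "batch_size": 64,
#         "training_data_path": "./data/training_data.txt",
#         "save_folder": "./checkpoints",
#         "learning_rate": 0.0003,
#         "num_steps": 5000,
#         "val_interval": 100
#     },
#     "generation": {
#         "top_k": 10,
#         "top_p": 0.9,
#         "temp": 0.8
#     },
#     "restore_path": "./checkpoints/model.pt"
# }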

def encode_text(text):
    # Character-level encoding: map each character to its Unicode code point.
    return [ord(t) for t in text]


def decode_text(indices):
    # Inverse of encode_text: map code points back to characters.
    return [chr(x) for x in indices]

class TextDataset(Dataset):
    def __init__(self, data_tensor, context_size):
        self.data_tensor = data_tensor
        self.context_size = context_size

    def __len__(self):
        return len(self.data_tensor) - self.context_size

    def __getitem__(self, index):
        # The target is the input sequence shifted one character to the right.
        x = self.data_tensor[index:index + self.context_size]
        y = self.data_tensor[index + 1:index + self.context_size + 1]
        return x, y

def load_dataset(data_path, context_size):
    with open(data_path, 'r', encoding='utf-8') as f:
        text = f.read()

    # Use long (int64) indices so the tensor feeds nn.Embedding and cross_entropy directly.
    data = torch.tensor(encode_text(text), dtype=torch.long)

    # 80% train, 10% test, 10% validation.
    test_split_idx = int(0.8 * len(data))
    val_split_idx = int(0.9 * len(data))
    train_data = data[:test_split_idx]
    test_data = data[test_split_idx:val_split_idx]
    val_data = data[val_split_idx:]

    train_dataset = TextDataset(train_data, context_size)
    test_dataset = TextDataset(test_data, context_size)
    val_dataset = TextDataset(val_data, context_size)
    return train_dataset, val_dataset, test_dataset

class MultiheadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, dropout=0.0, bias=True, device=None, dtype=None):
        super(MultiheadAttention, self).__init__()

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.d_k = embed_dim // num_heads

        self.Q = nn.Linear(embed_dim, embed_dim, bias=False)
        self.K = nn.Linear(embed_dim, embed_dim, bias=False)
        self.V = nn.Linear(embed_dim, embed_dim, bias=False)

        self.dropout = nn.Dropout(dropout)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)

    def forward(self, query, key, value, attn_mask=None):
        batch_size = query.size(0)

        q = self.Q(query)
        k = self.K(key)
        v = self.V(value)

        # Split the embedding dimension into heads:
        # (batch, seq, embed_dim) -> (batch, num_heads, seq, d_k)
        q = q.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        k = k.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        v = v.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

        # Scaled dot-product attention scores: (batch, num_heads, seq, seq)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)

        if attn_mask is not None:
            scores = scores.masked_fill(attn_mask, float('-inf'))

        attn = F.softmax(scores, dim=-1)
        attn = self.dropout(attn)
        out = attn @ v

        # Merge the heads back: (batch, num_heads, seq, d_k) -> (batch, seq, embed_dim)
        out = out.transpose(1, 2).contiguous().view(
            batch_size, -1, self.embed_dim)

        out = self.out_proj(out)
        # Return (output, attention weights) to mirror nn.MultiheadAttention's
        # interface; the weights are not tracked here.
        return out, None

class FeedForward(nn.Module):
    def __init__(self, embed_dim, dropout):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(embed_dim, 4 * embed_dim),
            nn.GELU(),
            nn.Linear(4 * embed_dim, embed_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)

class Block(nn.Module):
    """A transformer block: masked self-attention followed by a feed-forward
    network, each wrapped in a pre-layer-norm residual connection."""

    def __init__(self, embed_dim, num_heads, mask, dropout=0.2):
        super(Block, self).__init__()
        self.register_buffer("mask", mask)
        self.head = MultiheadAttention(
            embed_dim=embed_dim, num_heads=num_heads, dropout=dropout)

        self.ffwd = FeedForward(embed_dim=embed_dim, dropout=dropout)
        self.ln1 = nn.LayerNorm(embed_dim)
        self.ln2 = nn.LayerNorm(embed_dim)

    def forward(self, x):
        x = self.ln1(x)
        attn_output, _ = self.head(x, x, x, attn_mask=self.mask)
        x = x + attn_output
        out = x + self.ffwd(self.ln2(x))
        return out

class GPT(nn.Module):
    def __init__(self, embedding_dim, vocab_size, context_size, num_heads=4, num_layers=4):
        super(GPT, self).__init__()

        self.embedding_dim = embedding_dim
        self.output_dim = vocab_size
        self.context_size = context_size

        self.tok_embed = nn.Embedding(vocab_size, embedding_dim)
        self.pos_embed = nn.Embedding(context_size, embedding_dim)

        # Causal mask: True above the diagonal marks positions that may not be attended to.
        mask = torch.tril(torch.ones(
            self.context_size, self.context_size)).bool()
        mask = ~mask
        self.register_buffer("mask", mask)

        self.blocks = nn.Sequential(
            *[Block(embed_dim=embedding_dim, num_heads=num_heads, mask=mask, dropout=0.2)
              for _ in range(num_layers)]
        )

        self.ln_f = nn.LayerNorm(self.embedding_dim)

        self.ffwd = nn.Linear(
            embedding_dim, out_features=vocab_size, bias=False)

    def forward(self, x):
        # Inputs are expected to be exactly context_size tokens long
        # (generate() pads shorter prompts accordingly).
        tok_embed = self.tok_embed(x)
        pos_embed = self.pos_embed(
            torch.arange(0, self.context_size).to(x.device)
        )
        x = tok_embed + pos_embed

        x = self.blocks(x)
        x = self.ln_f(x)

        logits = self.ffwd(x)
        return logits

    def infer(self, x):
        with torch.no_grad():
            self.eval()
            res = self.forward(x)
        return res

    def num_params(self):
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
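
# A quick interactive sketch (illustrative only) for inspecting the model size:
#
#     config = GPTConfig('./gpt_config.json')
#     model = GPT(embedding_dim=config.embedding_dim,
#                 vocab_size=config.vocab_size,
#                 context_size=config.context_size,
#                 num_heads=config.num_heads,
#                 num_layers=config.num_layers)
#     print(f"{model.num_params():,} trainable parameters")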

def load_checkpoint(model, optimizer, path, device=device):
    """
    Loads a saved checkpoint file into the model and optimizer.

    Args:
        model (nn.Module): The PyTorch model to load the checkpoint into.
        optimizer (torch.optim.Optimizer): The PyTorch optimizer to load the checkpoint into.
        path (str): The path to the saved checkpoint file.
        device (torch.device): The device to map the checkpoint tensors onto.

    Returns:
        Tuple[nn.Module, torch.optim.Optimizer, int]: The model, the optimizer, and the
        number of completed training steps stored in the checkpoint.
    """
    checkpoint = torch.load(path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return model, optimizer, checkpoint['steps']


def save_checkpoint(model, optimizer, path, steps):
    """
    Saves a checkpoint of the model and optimizer to disk.

    Args:
        model (nn.Module): The PyTorch model to save the checkpoint of.
        optimizer (torch.optim.Optimizer): The PyTorch optimizer to save the checkpoint of.
        path (str): The path to save the checkpoint file.
        steps (int): The number of training steps that have been completed.

    Returns:
        None
    """
    torch.save({
        'steps': steps,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }, path)

def compute_loss(model, criterion, x, y):
    # Flatten (batch, context, vocab) logits and (batch, context) targets so the
    # criterion sees one prediction per character position.
    logits = model(x)
    B, C, V = logits.shape
    logits = logits.view(B * C, V)
    y = y.view(B * C)
    loss = criterion(logits, y.long())
    return loss

def print_model_devices(model):
    print("Model Parameters:")
    for name, param in model.named_parameters():
        print(f"{name}: {param.device}")

    print("\nModel Buffers:")
    for name, buffer in model.named_buffers():
        print(f"{name}: {buffer.device}")

def train(model, optimizer, config: GPTConfig, global_step):
    model = model.to(device)
    criterion = F.cross_entropy

    train_dataset, val_dataset, _ = load_dataset(
        config.training_data_path, model.context_size)

    train_dataloader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=4
    )

    val_dataloader = DataLoader(
        val_dataset, batch_size=512, num_workers=4, shuffle=True)

    model.train()

    EPOCHS = 1
    STEPS = config.num_steps
    VAL_INTERVAL = config.val_interval

    writer = SummaryWriter()

    step = 0

    for epoch in range(EPOCHS):
        for data, target in train_dataloader:
            data = data.to(device)
            target = target.to(device)

            loss = compute_loss(model, criterion, data, target)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            writer.add_scalar("Loss/train", loss.item(), global_step)
            global_step += 1

            if step % VAL_INTERVAL == 0:
                total_loss = 0
                total_samples = 0

                with torch.no_grad():
                    model.eval()
                    for x, y in val_dataloader:
                        x = x.to(device)
                        y = y.to(device)

                        batch_loss = compute_loss(model, criterion, x, y)
                        total_loss += batch_loss.item() * x.size(0)
                        total_samples += x.size(0)
                        # Keep validation cheap: one batch is enough for a progress signal.
                        if total_samples > 10:
                            break

                model.train()
                average_loss = total_loss / total_samples

                print(f"Step {step}; loss: {average_loss}")
                writer.add_scalar("Loss/val", average_loss, global_step)

            step += 1
            if step >= STEPS:
                break

    writer.close()

def evaluate_model(model, val_dataset, block_size=512, max_samples=100000):
    model.eval()
    total_loss = 0.0
    total_samples = 0
    criterion = F.cross_entropy

    val_dataloader = DataLoader(
        val_dataset, batch_size=block_size, num_workers=4)
    with torch.no_grad():
        for inputs, targets in val_dataloader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            batch_loss = compute_loss(model, criterion, inputs, targets)
            total_loss += batch_loss.item() * inputs.size(0)
            total_samples += inputs.size(0)
            if total_samples > max_samples:
                break

    average_loss = total_loss / total_samples
    return average_loss

def generate(model, config, prompt, gen_length, temp=1, top_k=10, top_p=None, device=device):
    g_cuda = torch.Generator(device=device)
    contexts = torch.tensor(encode_text(prompt), dtype=torch.long).to(device)

    model.eval()
    for i in range(gen_length):
        # Use the last context_size characters, left-padding short prompts with zeros.
        x = contexts[-config.context_size:]
        if x.size(0) < config.context_size:
            x = F.pad(x, (config.context_size - x.size(0), 0),
                      "constant", 0).unsqueeze(0)
        else:
            x = x.unsqueeze(0)

        preds = model.infer(x)
        preds = preds.squeeze(0)
        preds = preds / temp
        probs = F.softmax(preds, dim=-1)

        if top_p is not None:
            # Nucleus sampling: keep the smallest set of characters whose
            # cumulative probability stays below top_p.
            sorted_probs, sorted_indices = torch.sort(
                probs[-1, :], descending=True)
            cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

            idx_top_p = (cumulative_probs < top_p).sum().item()
            top_probs = sorted_probs[:idx_top_p]
            top_indices = sorted_indices[:idx_top_p]

            if top_probs.size(0) == 0:
                top_probs = sorted_probs[:1]
                top_indices = sorted_indices[:1]

            next_char = torch.multinomial(
                top_probs, num_samples=1, generator=g_cuda)
            next_char = top_indices[next_char]
        elif top_k is not None:
            # Top-k sampling: restrict to the k most likely characters.
            top_k_probs, top_k_indices = torch.topk(probs[-1, :], k=int(top_k))
            next_char = torch.multinomial(
                top_k_probs, num_samples=1, generator=g_cuda)
            next_char = top_k_indices[next_char]
        else:
            # Unrestricted sampling over the last position's full distribution.
            next_char = torch.multinomial(
                probs[-1, :], num_samples=1, generator=g_cuda)

        contexts = torch.cat((contexts, next_char), dim=0)
        print(decode_text(next_char.cpu().numpy())[-1], end="")

    return "".join(decode_text(contexts.cpu().numpy()))
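
# Calling generate() directly from a Python session (illustrative; assumes a
# trained checkpoint has already been loaded into `model`):
#
#     text = generate(model, config, "ROMEO:", gen_length=200,
#                     temp=config.temp, top_k=config.top_k, top_p=config.top_p)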

def main():
    args = parser.parse_args()

    config = GPTConfig(args.config)

    model = GPT(
        vocab_size=config.vocab_size,
        context_size=config.context_size,
        embedding_dim=config.embedding_dim,
        num_heads=config.num_heads,
        num_layers=config.num_layers
    )
    model.to(device)

    optimizer = optim.AdamW(model.parameters(), lr=config.learning_rate)
    if args.gui or args.command is None:
        load_checkpoint(model, optimizer, config.restore_path)
        demo = gr.Interface(
            fn=lambda *args: generate(model, config, *args),
            inputs=[
                gr.Textbox(lines=2, placeholder="Prompt here...", label="Prompt"),
                gr.Number(precision=0, value=256, label="Length"),
                gr.Number(value=0.8, label="Temperature"),
                gr.Slider(maximum=128, value=10, step=1, label="Top-k"),
                gr.Slider(maximum=1, value=1, label="Top-p")
            ],
            outputs="text",
            title="Shakespeare-GPT",
            description="Putting theater kids out of their nonexistent jobs since 2023"
        )

        demo.launch()
    elif args.command == "train":
        if args.load_from_restore:
            _, _, global_steps = load_checkpoint(
                model, optimizer, config.restore_path)
        else:
            global_steps = 0

        train(model, optimizer, config, global_step=global_steps)

        timestamp = datetime.datetime.now().isoformat()
        checkpoint_name = f"model_{timestamp}.pt"
        save_checkpoint(model, optimizer, path=os.path.join(config.save_folder, checkpoint_name),
                        steps=global_steps + config.num_steps)
    elif args.command == "eval":
        load_checkpoint(model, optimizer, config.restore_path)
        # Evaluation runs on the held-out test split of the training corpus.
        _, _, test_dataset = load_dataset(
            config.training_data_path, model.context_size)
        test_loss = evaluate_model(model, test_dataset)
        print(f"Test loss: {test_loss}")
    elif args.command == "infer":
        load_checkpoint(model, optimizer, config.restore_path)
        prompt = args.text
        generated_text = generate(model, config, prompt, args.length,
                                  temp=config.temp, top_k=config.top_k, top_p=config.top_p)
        print(generated_text)


if __name__ == "__main__":
    main()