# gpt / train.py
# Committed by: kevinwang676
# Commit message: "Add files using upload-large-folder tool"
# Commit: 955f8d8 (verified)
import argparse
import os
import sys
import shutil
import random
import numpy as np
import time
import copy
import math
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.autograd import Variable
import transformers
from transformers import GPT2TokenizerFast
# ---------------------------
# Utility masks & helpers
# ---------------------------
def subsequent_mask(size):
    """Build a causal (autoregressive) attention mask.

    Args:
        size: Sequence length.

    Returns:
        A boolean tensor of shape ``(1, size, size)`` that is ``True`` where
        query position ``i`` may attend to key position ``j`` (``j <= i``) and
        ``False`` for future positions (``j > i``).  ``euclidean_attention``
        fills the scores with ``-1e9`` wherever the mask equals 0 / ``False``,
        so ``True`` must mean "allowed".
    """
    attn_shape = (1, size, size)
    # The strict upper triangle (diagonal=1) marks the *future* positions.
    # Invert it: returning it directly would block the past and let every
    # position attend only to tokens that come after it (target leakage).
    future = torch.triu(torch.ones(attn_shape), diagonal=1)
    return future == 0
def read_corpus(filename, tokenizer, encoding="utf-8"):
    """Tokenise a plain-text corpus into a single long id sequence.

    Each line has its trailing newline removed and is tokenised on its own;
    the resulting ids are concatenated in file order.

    Args:
        filename: Path of the text file to read.
        tokenizer: Callable returning a mapping with an ``"input_ids"`` list
            (``main`` passes a ``GPT2TokenizerFast``).
        encoding: Text encoding of the file.  Defaults to UTF-8 so the result
            does not depend on the platform's locale encoding.

    Returns:
        A flat ``list`` of token ids.
    """
    seq = []
    with open(filename, "rt", encoding=encoding) as f:
        for line in f:
            seq.extend(tokenizer(line.rstrip("\n"))["input_ids"])
    return seq
# ---------------------------
# Embedding & positional code
# ---------------------------
class Embedder(nn.Module):
    """Token-embedding lookup table.

    A thin wrapper around ``nn.Embedding`` whose ``forward`` casts the ids to
    ``long`` before the lookup.

    Args:
        vocab_size: Number of rows in the table.
        d_model: Width of each embedding vector.
    """

    def __init__(self, vocab_size, d_model):
        super().__init__()
        self.d_model = d_model
        table = nn.Embedding(vocab_size, d_model)
        self.embed = table

    def forward(self, x):
        ids = x.long()
        return self.embed(ids)
class PositionalEncoder(nn.Module):
    """Scale token embeddings and add fixed sinusoidal position encodings.

    ``forward`` multiplies its input by ``sqrt(d_model)``, adds the first
    ``seq_len`` rows of a precomputed sine/cosine table and applies dropout.

    Args:
        d_model: Embedding width.  Odd widths are supported: the last sine
            column simply has no cosine partner.
        max_seq_len: Number of positions precomputed in the table.
        dropout: Dropout probability applied to the result.
    """

    def __init__(self, d_model, max_seq_len: int = 4096, dropout: float = 0.1):
        super().__init__()
        self.d_model = d_model
        self.dropout = nn.Dropout(dropout)
        pe = torch.zeros(max_seq_len, d_model)
        position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer odd column than even column, so only
        # the first d_model // 2 frequencies get a cosine column (identical to the
        # full div_term when d_model is even).
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer("pe", pe.unsqueeze(0))  # (1, max_seq_len, d_model)

    def forward(self, x):
        # x is indexed as (batch, seq_len, d_model): dim 1 selects table rows.
        seq_len = x.size(1)
        max_len = self.pe.size(1)
        if seq_len > max_len:
            raise ValueError(f"sequence length {seq_len} exceeds max_seq_len {max_len}")
        x = x * math.sqrt(self.d_model)
        x = x + self.pe[:, :seq_len]
        return self.dropout(x)
class Norm(nn.Module):
    """Layer normalisation over the last dimension with learnable gain and bias.

    Computes ``alpha * (x - mean) / (std + eps) + bias``, where ``std`` is the
    default ``Tensor.std`` estimate (Bessel-corrected / unbiased) and ``eps``
    is added to the standard deviation.  This is *not* numerically identical
    to ``nn.LayerNorm``: that layer uses the biased variance and adds ``eps``
    inside the square root.  The difference shrinks as ``d_model`` grows.

    Args:
        d_model: Size of the normalised (last) dimension.
        eps: Small constant added to the standard deviation.
    """

    def __init__(self, d_model: int, eps: float = 1e-6):
        super().__init__()
        self.size = d_model
        self.alpha = nn.Parameter(torch.ones(d_model))  # learnable gain
        self.bias = nn.Parameter(torch.zeros(d_model))  # learnable shift
        self.eps = eps

    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        std = x.std(dim=-1, keepdim=True)
        return self.alpha * (x - mean) / (std + self.eps) + self.bias
# ---------------------------
# Attention (Euclidean metric)
# ---------------------------
def euclidean_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, d_k: int, mask=None, dropout=None):
    """Attend with weights derived from negative scaled squared Euclidean distance.

    For query ``i`` and key ``j`` the logit is ``-||q_i - k_j||^2 / sqrt(d_k)``,
    so nearer keys receive more weight after the softmax over keys.  The
    squared distance is expanded as ``||q||^2 + ||k||^2 - 2 q.k`` so it can be
    obtained from a single batched matrix product.

    Args:
        q, k, v: Tensors laid out as ``(batch, heads, length, d_k)`` (the
            layout ``MultiHeadAttention`` produces).
        d_k: Per-head width, used for the ``sqrt(d_k)`` scaling.
        mask: Optional tensor.  A head axis is inserted at dim 1 and the result
            is broadcast against the logits; every position where it equals 0
            (``False``) is set to ``-1e9`` before the softmax.
        dropout: Optional callable (e.g. ``nn.Dropout``) applied to the
            attention weights.

    Returns:
        The attention-weighted combination of ``v``; the weights themselves
        are not returned.
    """
    inner = torch.matmul(q, k.transpose(-2, -1))  # (bs, h, len_q, len_k)
    q_sq = (q ** 2).sum(dim=-1, keepdim=True)      # (bs, h, len_q, 1)
    k_sq = (k ** 2).sum(dim=-1).unsqueeze(-2)      # (bs, h, 1, len_k)
    dist_sq = q_sq + k_sq - 2 * inner
    logits = -dist_sq / math.sqrt(d_k)  # smaller distance -> larger logit
    if mask is not None:
        # One mask is shared by every head, hence the inserted head axis.
        logits = logits.masked_fill(mask.unsqueeze(1) == 0, -1e9)
    weights = F.softmax(logits, dim=-1)
    if dropout is not None:
        weights = dropout(weights)
    return torch.matmul(weights, v)
class MultiHeadAttention(nn.Module):
    """Multi-head attention that scores with ``euclidean_attention``.

    Queries, keys and values are each linearly projected to ``d_model``, split
    into ``heads`` chunks of width ``d_k = d_model // heads``, attended per
    head, re-merged and passed through an output projection.

    Args:
        heads: Number of attention heads; must divide ``d_model``.
        d_model: Model width.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, heads: int, d_model: int, dropout: float = 0.1):
        super().__init__()
        assert d_model % heads == 0, "d_model must be divisible by heads"
        self.d_k = d_model // heads
        self.h = heads
        # Creation order fixes the order in which random initialisation is
        # drawn, so it is kept stable.
        self.q_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
        self.out = nn.Linear(d_model, d_model)

    def _split_heads(self, projected, batch):
        """Reshape ``(batch, len, d_model)`` into ``(batch, heads, len, d_k)``."""
        return projected.view(batch, -1, self.h, self.d_k).transpose(1, 2)

    def forward(self, q, k, v, mask=None):
        batch = q.size(0)
        heads_q = self._split_heads(self.q_linear(q), batch)
        heads_k = self._split_heads(self.k_linear(k), batch)
        heads_v = self._split_heads(self.v_linear(v), batch)
        attended = euclidean_attention(heads_q, heads_k, heads_v, self.d_k, mask, self.dropout)
        # Back to (batch, len, heads * d_k) before the output projection.
        merged = attended.transpose(1, 2).contiguous().view(batch, -1, self.h * self.d_k)
        return self.out(merged)
# ---------------------------
# Feed‑forward & decoder
# ---------------------------
class FeedForward(nn.Module):
    """Position-wise two-layer MLP: Linear -> ReLU -> Dropout -> Linear.

    Args:
        d_model: Input and output width.
        d_ff: Hidden width.
        dropout: Dropout probability applied to the hidden activations.
    """

    def __init__(self, d_model: int, d_ff: int = 2048, dropout: float = 0.1):
        super().__init__()
        self.linear_1 = nn.Linear(d_model, d_ff)
        self.dropout = nn.Dropout(dropout)
        self.linear_2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        hidden = F.relu(self.linear_1(x))
        hidden = self.dropout(hidden)
        return self.linear_2(hidden)
def get_clones(module, N):
    """Return an ``nn.ModuleList`` of ``N`` deep copies of *module*.

    The copies share no parameters, but they all start from the same weights
    because each one is a deep copy of the same already-built module.
    """
    copies = (copy.deepcopy(module) for _ in range(N))
    return nn.ModuleList(copies)
class DecoderLayer(nn.Module):
    """Pre-norm decoder block: masked self-attention followed by a feed-forward.

    Each sub-layer reads a normalised copy of the residual stream, and its
    dropped-out output is added back to the un-normalised stream.

    Args:
        d_model: Model width.
        heads: Number of attention heads.
        dropout: Dropout probability used throughout the block.
    """

    def __init__(self, d_model: int, heads: int, dropout: float = 0.1):
        super().__init__()
        self.norm_1 = Norm(d_model)
        self.norm_2 = Norm(d_model)
        self.attn = MultiHeadAttention(heads, d_model, dropout)
        self.ff = FeedForward(d_model, dropout=dropout)
        self.dropout_1 = nn.Dropout(dropout)
        self.dropout_2 = nn.Dropout(dropout)

    def forward(self, x, trg_mask):
        normed = self.norm_1(x)
        attn_out = self.attn(normed, normed, normed, trg_mask)
        x = x + self.dropout_1(attn_out)
        normed = self.norm_2(x)
        return x + self.dropout_2(self.ff(normed))
class Decoder(nn.Module):
    """Token embedding + positional encoding, ``N`` decoder layers, final norm.

    Args:
        vocab_size: Size of the token vocabulary.
        d_model: Model width.
        N: Number of stacked decoder layers (deep copies of one layer).
        heads: Number of attention heads per layer.
        dropout: Dropout probability.
    """

    def __init__(self, vocab_size: int, d_model: int, N: int, heads: int, dropout: float):
        super().__init__()
        self.embed = Embedder(vocab_size, d_model)
        self.pe = PositionalEncoder(d_model, dropout=dropout)
        self.layers = get_clones(DecoderLayer(d_model, heads, dropout), N)
        self.norm = Norm(d_model)

    def forward(self, x, trg_mask):
        hidden = self.pe(self.embed(x))
        for block in self.layers:
            hidden = block(hidden, trg_mask)
        return self.norm(hidden)
class GPT2LM(nn.Module):
    """Decoder-only language model: a ``Decoder`` followed by a vocabulary projection.

    With ``tie_weights`` enabled, the output projection reuses the token
    embedding matrix (both are ``(vocab_size, d_model)``); the projection's
    bias remains a separate parameter.
    """

    def __init__(self, vocab_size: int, d_model: int, N: int, heads: int, dropout: float, tie_weights: bool = False):
        super().__init__()
        self.decoder = Decoder(vocab_size, d_model, N, heads, dropout)
        self.out = nn.Linear(d_model, vocab_size)
        if not tie_weights:
            return
        self.out.weight = self.decoder.embed.embed.weight
        print("✅ Tied embeddings enabled.")

    def forward(self, x, mask):
        hidden = self.decoder(x, mask)
        return self.out(hidden)
# ---------------------------
# Data batcher
# ---------------------------
def batchify(data, batch_size, seq_len):
    """Yield ``(src, tgt)`` chunks for language-model training.

    The id sequence is truncated to a multiple of ``batch_size`` and laid out
    as ``batch_size`` contiguous rows.  Columns are then walked in steps of
    ``seq_len``; ``tgt`` is ``src`` shifted one token to the right.  The
    final chunk may be shorter than ``seq_len``.

    Args:
        data: Sliceable sequence of integer token ids.
        batch_size: Number of parallel rows.
        seq_len: Maximum chunk length.

    Yields:
        ``(src, tgt)`` pairs of ``torch.long`` tensors of shape
        ``(batch_size, L)`` with ``1 <= L <= seq_len``.  Nothing is yielded
        when a row would hold fewer than two tokens, including when ``data``
        has fewer than ``batch_size`` ids.
    """
    row_len = len(data) // batch_size  # tokens per row after truncation
    # Use the explicit row length rather than view(batch_size, -1): for an
    # empty tensor -1 is ambiguous and PyTorch raises instead of yielding nothing.
    grid = torch.tensor(data[: row_len * batch_size], dtype=torch.long).view(batch_size, row_len)
    last_start = row_len - 1  # the final column can only serve as a target
    for start in range(0, last_start, seq_len):
        width = min(seq_len, last_start - start)
        yield grid[:, start : start + width], grid[:, start + 1 : start + 1 + width]
# ---------------------------
# Train / eval loops
# ---------------------------
def train_model(model, opt):
    """Train *model* for ``opt.epochs`` epochs, validating after each epoch.

    Reads from *opt*: ``epochs``, ``train``, ``valid``, ``batchsize``,
    ``seqlen``, ``device``, ``vocab_size``, ``src_pad`` (used as the
    cross-entropy ``ignore_index``), ``optimizer`` and ``dir_name``.

    Train perplexity is token-weighted: every non-ignored target token counts
    equally, so the shorter final chunk of an epoch is not over-weighted.
    After training, the final weights, a learning-curve plot and a text log
    are written under ``saved/<opt.dir_name>/``.

    Raises:
        ValueError: if the training data yields no target tokens.
    """
    print("Starting training (Euclidean attention)…")
    model.train()
    train_ppls, valid_ppls = [], []
    for epoch in range(opt.epochs):
        train_ppl = _train_one_epoch(model, opt)
        train_ppls.append(train_ppl)
        print(f"Epoch {epoch+1}/{opt.epochs} • Train PPL: {train_ppl:.2f}")
        valid_ppl = evaluate(model, opt.valid, opt, tag=f"valid‑e{epoch+1}")
        valid_ppls.append(valid_ppl)
    _save_run_artifacts(model, opt, train_ppls, valid_ppls)


def _train_one_epoch(model, opt):
    """Run one optimisation pass over ``opt.train``; return its token-weighted perplexity."""
    total_nll, total_tokens = 0.0, 0
    for src, tgt in batchify(opt.train, opt.batchsize, opt.seqlen):
        src, tgt = src.to(opt.device), tgt.to(opt.device)
        mask = subsequent_mask(src.size(1)).to(opt.device)
        output = model(src, mask)
        nll_sum = F.cross_entropy(
            output.view(-1, opt.vocab_size),
            tgt.reshape(-1),
            ignore_index=opt.src_pad,
            reduction="sum",
        )
        n_tokens = int((tgt != opt.src_pad).sum().item())
        # Same value as reduction="mean" whenever n_tokens > 0; max() avoids a
        # NaN loss (and NaN gradients) for a chunk whose targets are all ignored.
        loss = nll_sum / max(n_tokens, 1)
        opt.optimizer.zero_grad()
        loss.backward()
        opt.optimizer.step()
        total_nll += nll_sum.item()
        total_tokens += n_tokens
    if total_tokens == 0:
        raise ValueError("training data produced no target tokens; check batchsize/seqlen")
    return math.exp(total_nll / total_tokens)


def _save_run_artifacts(model, opt, train_ppls, valid_ppls):
    """Write final weights, a learning-curve PNG and a perplexity log under saved/<dir_name>."""
    dir_name = os.path.join("saved", opt.dir_name)
    os.makedirs(dir_name, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(dir_name, "gpt2lm_euclid.pth"))
    epochs = range(1, len(train_ppls) + 1)
    plt.plot(epochs, train_ppls, label="Train PPL")
    plt.plot(epochs, valid_ppls, label="Valid PPL")
    plt.xlabel("Epoch")
    plt.ylabel("Perplexity")
    plt.title("Euclidean‑Attention GPT‑2 on WikiText‑2")
    plt.legend()
    plt.savefig(os.path.join(dir_name, "learning_curve.png"))
    plt.close()
    with open(os.path.join(dir_name, "perplexity_log.txt"), "w", encoding="utf-8") as f:
        for i, (train_ppl, valid_ppl) in enumerate(zip(train_ppls, valid_ppls), start=1):
            f.write(f"Epoch {i}: Train {train_ppl:.2f} Valid {valid_ppl:.2f}\n")
def evaluate(model, data, opt, tag="valid"):
    """Return the token-weighted perplexity of *model* on *data*.

    Runs without gradients and in eval mode, then restores whatever
    train/eval mode the model was in before the call, even if an error is
    raised.  Every target token not equal to ``opt.src_pad`` (the
    cross-entropy ``ignore_index``) counts equally, so the shorter final
    chunk is not over-weighted.

    Args:
        model: Called as ``model(src, mask)`` and expected to return logits.
        data: Flat sequence of token ids.
        opt: Namespace supplying ``batchsize``, ``seqlen``, ``device``,
            ``vocab_size`` and ``src_pad``.
        tag: Label printed (capitalised) before the perplexity.

    Raises:
        ValueError: if *data* yields no target tokens.
    """
    was_training = model.training
    model.eval()
    total_nll, total_tokens = 0.0, 0
    try:
        with torch.no_grad():
            for src, tgt in batchify(data, opt.batchsize, opt.seqlen):
                src, tgt = src.to(opt.device), tgt.to(opt.device)
                mask = subsequent_mask(src.size(1)).to(opt.device)
                output = model(src, mask)
                total_nll += F.cross_entropy(
                    output.view(-1, opt.vocab_size),
                    tgt.reshape(-1),
                    ignore_index=opt.src_pad,
                    reduction="sum",
                ).item()
                total_tokens += int((tgt != opt.src_pad).sum().item())
    finally:
        model.train(was_training)
    if total_tokens == 0:
        raise ValueError(f"no target tokens to evaluate for {tag!r}")
    ppl = math.exp(total_nll / total_tokens)
    print(f"{tag.capitalize()} PPL: {ppl:.2f}")
    return ppl
# ---------------------------
# Main entry
# ---------------------------
def main():
    """Parse CLI options, train on WikiText-2 and report test perplexity.

    Reads ``wiki2.train.txt``, ``wiki2.valid.txt`` and ``wiki2.test.txt`` from
    the working directory and tokenises them with the pretrained ``gpt2``
    tokenizer.
    """
    random.seed(10)
    # Weight initialisation and dropout draw from torch's RNG, which
    # random.seed() does not affect; seed it as well for reproducible runs.
    torch.manual_seed(10)
    parser = argparse.ArgumentParser()
    parser.add_argument("-no_cuda", action="store_true")
    parser.add_argument("-epochs", type=int, default=20)
    parser.add_argument("-d_model", type=int, default=512)
    parser.add_argument("-n_layers", type=int, default=6)
    parser.add_argument("-heads", type=int, default=8)
    parser.add_argument("-dropout", type=float, default=0.1)
    parser.add_argument("-batchsize", type=int, default=1)
    parser.add_argument("-lr", type=float, default=1e-5)
    parser.add_argument("-seqlen", type=int, default=512)
    parser.add_argument("-tied", type=int, default=1)
    parser.add_argument("-dir_name", type=str, default="model_euclid")
    opt = parser.parse_args()
    opt.device = torch.device("cuda:0" if (not opt.no_cuda and torch.cuda.is_available()) else "cpu")
    tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
    opt.train = read_corpus("wiki2.train.txt", tokenizer)
    opt.valid = read_corpus("wiki2.valid.txt", tokenizer)
    opt.test = read_corpus("wiki2.test.txt", tokenizer)
    # Take the vocabulary size from the tokenizer instead of hard-coding it.
    opt.vocab_size = tokenizer.vocab_size
    # batchify() never pads, so the loss should ignore nothing.  0 is an
    # ordinary GPT-2 vocabulary id, so using it as ignore_index silently
    # dropped real tokens from the loss and the perplexity.  -100 (the
    # F.cross_entropy default) can never match a token id.
    opt.src_pad = opt.trg_pad = -100
    model = GPT2LM(opt.vocab_size, opt.d_model, opt.n_layers, opt.heads, opt.dropout, tie_weights=(opt.tied == 1)).to(opt.device)
    # parameters() yields a tied weight only once, so it is not double-counted.
    print(f"Model parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.1f}M")
    opt.optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr, betas=(0.9, 0.98), eps=1e-9)
    train_model(model, opt)
    evaluate(model, opt.test, opt, tag="test")


if __name__ == "__main__":
    main()