"""Train a small character-level GPT on the combined ai_gf corpus (CPU only)."""
import json, pickle, torch
from model import GPT, GPTConfig
# Load both original and extra data
with open("data/ai_gf/input.txt", "r", encoding="utf-8") as f1, \
open("data/ai_gf/input_extra.txt", "r", encoding="utf-8") as f2:
text = f1.read() + "\n\n" + f2.read()
# Build a character-level vocabulary from the corpus itself
chars = sorted(set(text))
vocab_size = len(chars)
stoi = {ch: i for i, ch in enumerate(chars)}  # char -> id
itos = {i: ch for ch, i in stoi.items()}      # id -> char
def encode(s): return [stoi[c] for c in s]
def decode(l): return ''.join(itos[i] for i in l)
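# Quick sanity check (optional): encoding then decoding should round-trip
# exactly, since the vocabulary was built from this same text.
assert decode(encode(text[:100])) == text[:100]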
with open("meta.pkl", "wb") as f:
pickle.dump({'stoi': stoi, 'itos': itos}, f)
config = {
    "vocab_size": vocab_size,
    "block_size": 64,   # context window, in characters
    "n_layer": 4,
    "n_head": 4,
    "n_embd": 128,
    "dropout": 0.0,
    "bias": False,
}
with open("config.json", "w") as f:
json.dump(config, f)
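# For reference, an inference script could rebuild the model from these files
# plus the checkpoint saved at the end of this script (a sketch, assuming the
# same nanoGPT-style model.py is importable there):
#   with open("config.json") as f:
#       m = GPT(GPTConfig(**json.load(f)))
#   m.load_state_dict(torch.load("checkpoint.pt"))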
# Encode the full corpus as a single tensor of token ids
data = torch.tensor(encode(text), dtype=torch.long)
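# Note: this script trains on the full corpus with no held-out data. A common
# refinement would be to reserve the tail for validation, e.g.:
#   n = int(0.9 * len(data))
#   train_data, val_data = data[:n], data[n:]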
gpt_config = GPTConfig(**config)
model = GPT(gpt_config)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
batch_size = 4
steps = 5000
print("Training on CPU...")
for step in range(steps):
    # Sample a random batch of (input, target) sequences from the corpus;
    # targets are the inputs shifted one character to the right
    ix = torch.randint(len(data) - config["block_size"], (batch_size,))
    x = torch.stack([data[i:i+config["block_size"]] for i in ix])
    y = torch.stack([data[i+1:i+1+config["block_size"]] for i in ix])
    # Pass the targets so the model returns the cross-entropy loss itself; with
    # a nanoGPT-style GPT.forward, model(x) alone returns logits for the last
    # position only, which would break a manual loss computation here.
    logits, loss = model(x, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if step % 10 == 0:
        print(f"Step {step}/{steps}, Loss: {loss.item():.4f}")
torch.save(model.state_dict(), "checkpoint.pt")
print("Training complete.")