import json
import os
import re
import multiprocessing
from collections import Counter

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from tokenizers import Tokenizer
from tqdm import tqdm

multiprocessing.set_start_method("spawn", force=True)


class ChatDataset(Dataset):
    """Sliding-window language-modeling dataset built from a JSONL chat file."""

    def __init__(self, data, tokenizer, block_size=64):
        self.tokenizer = tokenizer
        self.block_size = block_size
        self.data = self.tokenize_data(data)

    def tokenize_data(self, data):
        chunks = []
        with open(data, "r", encoding="utf-8") as f:
            for d in f:
                line = json.loads(d.strip())

                # Format each example as a single prompt/response string.
                text = "^User: " + line["instruction"].strip() + " MiniGPT: " + line["output"].strip() + " <END>"
                encoding = self.tokenizer.encode(text)
                tokens = encoding.ids

                if len(tokens) < self.block_size:
                    print(f"Skipping short example (length {len(tokens)} < block_size {self.block_size}): {text[:50]}...")
                    continue

                # Slide a window of block_size tokens over the sequence.
                stride = 1
                for i in range(0, len(tokens) - self.block_size + 1, stride):
                    chunk = tokens[i:i + self.block_size]
                    if len(chunk) == self.block_size:
                        chunks.append(chunk)

        print(f"Dataset created with {len(chunks)} total training chunks.")
        return chunks

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Next-token prediction: inputs are the chunk shifted one position left of the targets.
        chunk = self.data[idx]
        x = torch.tensor(chunk[:-1], dtype=torch.long)
        y = torch.tensor(chunk[1:], dtype=torch.long)
        return x, y


class MiniBPETokenizr:
    """Minimal BPE-style tokenizer: learns merges over character sequences with an end-of-word marker."""

    def __init__(self):
        self.stoi = {}
        self.itos = {}
        self.vocab_size = 0

    def tokenize(self, text):
        # Split into alphanumeric words and single punctuation marks, then break
        # each word into characters followed by an end-of-word marker.
        text = text.lower().strip()
        words = re.findall(r"[a-zA-Z0-9]+|[^\w\s]", text)
        return [list(w) + ['</w>'] if w.isalnum() else [w] for w in words]

    def get_stats(self, corpus):
        # Count adjacent symbol pairs across the whole corpus.
        pairs = Counter()
        for tokens in corpus:
            for i in range(len(tokens) - 1):
                pairs[(tokens[i], tokens[i + 1])] += 1
        return pairs

    def merge_vocab(self, corpus, pair_to_merge):
        # Replace every occurrence of the chosen pair with its merged symbol.
        bigram = re.escape(' '.join(pair_to_merge))
        pattern = re.compile(r'(?<!\S)' + bigram + r'(?!\S)')
        merged = []
        for tokens in corpus:
            token_str = ' '.join(tokens)
            token_str = pattern.sub(''.join(pair_to_merge), token_str)
            merged.append(token_str.split())
        return merged

    def train(self, texts, merge_limit=1000):
        corpus = [sum(self.tokenize(t), []) for t in texts]
        merges_done = 0
        loop = tqdm(total=merge_limit, desc="Training BPE")

        while merges_done < merge_limit:
            pairs = self.get_stats(corpus)
            if not pairs:
                break
            best = max(pairs, key=pairs.get)
            corpus = self.merge_vocab(corpus, best)
            merges_done += 1
            loop.update(1)
        loop.close()

        vocab = set(tok for seq in corpus for tok in seq)
        vocab.update(["<PAD>", "<UNK>", "<END>", "^user:", "minigpt:"])
        self.stoi = {tok: i for i, tok in enumerate(sorted(vocab))}
        self.itos = {i: tok for tok, i in self.stoi.items()}
        self.vocab_size = len(self.stoi)

    def encode(self, text):
        # Greedy longest-match of the symbol sequence against the learned vocabulary.
        tokens = sum(self.tokenize(text), [])
        output = []
        i = 0
        while i < len(tokens):
            j = len(tokens)
            while j > i:
                candidate = ''.join(tokens[i:j])
                if candidate in self.stoi:
                    output.append(self.stoi[candidate])
                    i = j
                    break
                j -= 1
            else:
                # No match found: emit <UNK> and advance one symbol.
                output.append(self.stoi.get("<UNK>", 1))
                i += 1
        return output

    def decode(self, token_ids):
        tokens = [self.itos.get(i, "<UNK>") for i in token_ids]
        text = ' '.join(t.replace('</w>', '') for t in tokens if t not in {"<PAD>", "<END>", "<UNK>"})
        text = re.sub(r'\s([?.!,:;])', r'\1', text)
        return text.strip()

    def save(self, path):
        with open(path, "w", encoding="utf-8") as f:
            json.dump({"stoi": self.stoi, "itos": self.itos}, f)

    def load(self, path):
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
        self.stoi = {k: int(v) for k, v in data["stoi"].items()}
        self.itos = {int(v): k for k, v in self.stoi.items()}
        self.vocab_size = len(self.stoi)


class SimpleTokenizr:
    """Word-level tokenizer: every distinct word or punctuation mark is one token."""

    def __init__(self):
        self.stoi = {}
        self.itos = {}

    def tokenize(self, text):
        return re.findall(r"[a-zA-Z']+|\d+|[^\w\s]", text.lower())

    def train(self, texts):
        vocab = set()
        for text in texts:
            tokens = self.tokenize(text)
            vocab.update(tokens)
        vocab.update(["<PAD>", "<UNK>", "<END>", "^user :", "minigpt :", "MiniGPT :", ":"])
        sorted_vocab = sorted(vocab)
        self.stoi = {token: idx for idx, token in enumerate(sorted_vocab)}
        self.itos = {idx: token for token, idx in self.stoi.items()}

    def encode(self, text):
        tokens = self.tokenize(text)
        return [self.stoi.get(tok, self.stoi["<UNK>"]) for tok in tokens] + [self.stoi["<END>"]]

    def decode(self, token_ids):
        tokens = [self.itos.get(i, "<UNK>") for i in token_ids]
        clean_tokens = [tok for tok in tokens if tok not in {"<PAD>", "<UNK>", "<END>"}]
        text = ''
        for i, tok in enumerate(clean_tokens):
            if re.match(r"[.,!?;:]", tok):
                # Attach punctuation directly to the preceding word.
                text += tok
            elif i > 0:
                text += ' ' + tok
            else:
                text += tok
        return text.strip().capitalize()

    def save(self, path):
        with open(path, "w", encoding="utf-8") as f:
            json.dump({"stoi": self.stoi, "itos": self.itos}, f)

    def load(self, path):
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
        self.stoi = {k: int(v) for k, v in data["stoi"].items()}
        self.itos = {idx: token for token, idx in self.stoi.items()}

    def __len__(self):
        return len(self.stoi)

    @property
    def vocab_size(self):
        return len(self.stoi)


def validate(model, dataloader, device):
    """Compute average loss and token-level accuracy over the validation loader."""
    model.eval()
    total_loss, correct, total = 0, 0, 0
    with torch.no_grad():
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1))
            total_loss += loss.item()

            preds = torch.argmax(logits, dim=-1)
            correct += (preds == y).sum().item()
            total += y.numel()

    avg_loss = total_loss / len(dataloader)
    accuracy = 100 * correct / total
    return avg_loss, accuracy


def train(model, dataset, tokenizer, epochs, filepath, start_epoch=0, start_step=0, learning_rate=5e-5):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Hold out 10% of the chunks for validation.
    val_size = int(0.1 * len(dataset))
    train_size = len(dataset) - val_size
    train_set, val_set = random_split(dataset, [train_size, val_size])

    train_loader = DataLoader(train_set, batch_size=1, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_set, batch_size=1, shuffle=False, num_workers=0)

    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

    # Resume from a checkpoint if one exists.
    checkpoint_path = "./trained-mini-gpt/checkpoint-mini-gpt.pth"
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
    if os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location=device)
        if "model_state_dict" in checkpoint:
            model.load_state_dict(checkpoint["model_state_dict"])
            optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
            start_epoch = checkpoint["epoch"]
            start_step = checkpoint["step"]
        else:
            # Older checkpoints stored only the raw model state dict.
            model.load_state_dict(checkpoint)
    else:
        print("Starting from scratch.")

    total_steps = start_step

    for epoch in range(start_epoch, epochs):
        model.train()
        total_loss, correct, total = 0, 0, 0

        loop = tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Epoch {epoch+1}/{epochs}")
        for step, (x, y) in loop:
            x, y = x.to(device), y.to(device)

            # Debug: decode the current batch so the training data can be inspected.
            if step % 1 == 0:
                input_ids_cpu = x[0].cpu().tolist()
                target_ids_cpu = y[0].cpu().tolist()

                decoded_input = tokenizer.decode(input_ids_cpu)
                decoded_target = tokenizer.decode(target_ids_cpu)

                print(f"\n--- Epoch {epoch+1}, Step {step} ---")
                print(f"Input (decoded): '{decoded_input}'")
                print(f"Target (decoded): '{decoded_target}'")

            logits = model(x)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_steps += 1

            total_loss += loss.item()
            preds = torch.argmax(logits, dim=-1)
            correct += (preds == y).sum().item()
            total += y.numel()
            acc = 100 * correct / total

            loop.set_postfix(loss=loss.item(), acc=acc)

            # Debug: decode the model's prediction for the current batch.
            if step % 1 == 0:
                predicted_logits_cpu = logits[0, :, :].cpu()
                predicted_ids = torch.argmax(predicted_logits_cpu, dim=-1).tolist()
                decoded_predicted = tokenizer.decode(predicted_ids)
                print(f"Predicted (decoded): '{decoded_predicted}'")
                print(f"Current Batch Loss: {loss.item():.4f}")
                print(f"Current Batch Accuracy: {100 * (preds == y).float().mean().item():.2f}%")

        val_loss, val_acc = validate(model, val_loader, device)
        print(f"Val Loss: {val_loss:.4f} | Val Accuracy: {val_acc:.2f}%")

        # Save a resumable checkpoint after every epoch.
        torch.save({
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "epoch": epoch,
            "step": total_steps
        }, checkpoint_path)

    torch.save(model.state_dict(), "./trained-mini-gpt/mini-gpt.pth")
    print("Training complete.")
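

# Example usage sketch. Assumptions: a pretrained tokenizer file at "./tokenizer.json",
# a JSONL dataset at "./data/chat.jsonl" with "instruction"/"output" fields, and a
# MiniGPT model class defined elsewhere; adjust these names and paths to the actual
# project layout before running.
if __name__ == "__main__":
    # Load a Hugging Face `tokenizers` tokenizer from disk (path is hypothetical).
    hf_tokenizer = Tokenizer.from_file("./tokenizer.json")

    # Build sliding-window training chunks from the chat data (path is hypothetical).
    dataset = ChatDataset("./data/chat.jsonl", hf_tokenizer, block_size=64)

    # The model class is not defined in this file; shown only to illustrate the call:
    # model = MiniGPT(vocab_size=hf_tokenizer.get_vocab_size())
    # train(model, dataset, hf_tokenizer, epochs=3, filepath="./trained-mini-gpt")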