import itertools
import json
import multiprocessing
import os
import re
from collections import Counter

import torch
import torch.nn as nn
import torch.nn.functional as F
from tokenizers import Tokenizer
from torch.utils.data import DataLoader, Dataset, random_split
from tqdm import tqdm

# Process-wide setting: child processes (including the DataLoader workers
# created in train(), num_workers=2) are started with "spawn"; force=True
# overrides any start method chosen earlier.
multiprocessing.set_start_method("spawn", force=True)

# NOTE(review): in the reviewed copy of this file every "<...>" literal had
# been blanked out (special tokens appeared as "" and part of MiniBPETokenizr
# was lost). These values are a reconstruction -- confirm they match the
# vocabularies of any tokenizer JSON files already saved to disk.
UNK_TOKEN = "<unk>"
PAD_TOKEN = "<pad>"
EOS_TOKEN = "<eos>"
END_OF_WORD = "</w>"  # appended after the last character of a word (MiniBPETokenizr)
SPECIAL_TOKENS = frozenset({UNK_TOKEN, PAD_TOKEN, EOS_TOKEN})


class ChatDataset(Dataset):
    """Next-token-prediction samples cut from a JSONL instruction/response file.

    Every non-blank line must be a JSON object with string fields
    ``"instruction"`` and ``"output"``. Each record is rendered as
    ``"^User: <instruction> MiniGPT: <output> "``, tokenized, and split into
    non-overlapping windows of exactly ``block_size`` token ids. Records shorter
    than ``block_size``, and the trailing remainder of longer ones, are dropped.

    Args:
        data: Path to the JSONL file.
        tokenizer: Object whose ``encode(text)`` returns either an object with an
            ``.ids`` list (e.g. a ``tokenizers.Tokenizer`` encoding) or a plain
            list of ids (e.g. ``SimpleTokenizr`` / ``MiniBPETokenizr``).
        block_size: Window length; each sample yields ``block_size - 1``
            input/target positions.
    """

    def __init__(self, data, tokenizer, block_size=64):
        self.tokenizer = tokenizer
        self.block_size = block_size
        self.data = self.tokenize_data(data)

    def tokenize_data(self, data):
        """Read the JSONL file at ``data`` and return ``block_size``-long id lists."""
        chunks = []
        with open(data, "r", encoding="utf-8") as f:
            for raw_line in f:
                raw_line = raw_line.strip()
                if not raw_line:
                    # json.loads("") raises; a blank/trailing line should not
                    # abort dataset construction.
                    continue
                line = json.loads(raw_line)
                # NOTE(review): the trailing " " looks like it lost an
                # end-of-sequence marker when this file was damaged -- confirm.
                text = (
                    "^User: " + line["instruction"].strip()
                    + " MiniGPT: " + line["output"].strip() + " "
                )
                encoding = self.tokenizer.encode(text)
                # tokenizers.Tokenizer returns an Encoding with .ids; the
                # tokenizers defined in this file return the id list directly.
                tokens = getattr(encoding, "ids", encoding)
                # The range is empty when len(tokens) < block_size, so short
                # records contribute nothing; leftover tail tokens are dropped.
                last_start = len(tokens) - self.block_size
                for start in range(0, last_start + 1, self.block_size):
                    chunks.append(list(tokens[start:start + self.block_size]))
        return chunks

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(x, y)`` where ``y`` is ``x`` shifted left by one token."""
        chunk = self.data[idx]
        x = torch.tensor(chunk[:-1])
        y = torch.tensor(chunk[1:])
        return x, y


class MiniBPETokenizr:
    """Small character-level byte-pair-encoding tokenizer.

    Text is lowercased and split into alphanumeric words and single punctuation
    marks. Words are spelled out as characters followed by ``END_OF_WORD``.
    ``train`` learns merges over that corpus, and ``encode`` greedily matches
    the longest known token at each position.
    """

    def __init__(self):
        self.stoi = {}  # token string -> id
        self.itos = {}  # id -> token string
        self.vocab_size = 0

    def tokenize(self, text):
        """Split ``text`` into per-word symbol lists.

        Alphanumeric words become ``[c1, c2, ..., END_OF_WORD]``; every other
        non-space character matched by the regex becomes a one-element list.
        """
        text = text.lower().strip()
        words = re.findall(r"[a-zA-Z0-9]+|[^\w\s]", text)
        return [list(w) + [END_OF_WORD] if w.isalnum() else [w] for w in words]

    def get_stats(self, corpus):
        """Count adjacent symbol pairs across every symbol list in ``corpus``."""
        pairs = Counter()
        for tokens in corpus:
            pairs.update(zip(tokens, tokens[1:]))
        return pairs

    def merge_vocab(self, corpus, pair_to_merge):
        """Return ``corpus`` with each adjacent occurrence of ``pair_to_merge`` fused.

        NOTE(review): everything after ``bigram`` in this method, and the start
        of ``train``, was lost in the reviewed copy and has been reconstructed.
        """
        bigram = re.escape(' '.join(pair_to_merge))
        # Match the pair only as whole space-separated symbols.
        pattern = re.compile(r'(?<!\S)' + bigram + r'(?!\S)')
        merged = ''.join(pair_to_merge)
        # A callable replacement stops a symbol such as "\" from being read as
        # a regex escape sequence.
        return [
            pattern.sub(lambda _match: merged, ' '.join(tokens)).split()
            for tokens in corpus
        ]

    def train(self, texts, num_merges=1000):
        """Learn up to ``num_merges`` BPE merges from ``texts`` and build the vocab.

        NOTE(review): the signature and merge loop are reconstructed (the
        original parameter list was lost) -- check callers.
        """
        corpus = [word for text in texts for word in self.tokenize(text)]
        # Keep every base symbol so encode() can always fall back to characters.
        vocab = {symbol for word in corpus for symbol in word}
        for _ in range(num_merges):
            pairs = self.get_stats(corpus)
            if not pairs:
                break
            best = max(pairs, key=pairs.get)
            corpus = self.merge_vocab(corpus, best)
            vocab.add(''.join(best))
        vocab.update([UNK_TOKEN, PAD_TOKEN, EOS_TOKEN, "^user:", "minigpt:"])
        self.stoi = {tok: i for i, tok in enumerate(sorted(vocab))}
        self.itos = {i: tok for tok, i in self.stoi.items()}
        self.vocab_size = len(self.stoi)

    def encode(self, text):
        """Greedy longest-match encoding of ``text`` into token ids.

        Matching runs over the flattened symbol sequence, so a match may span a
        word boundary. Symbols with no match map to ``UNK_TOKEN``'s id (1 if the
        vocab has no such entry). Quadratic in the number of symbols.
        """
        symbols = list(itertools.chain.from_iterable(self.tokenize(text)))
        unk_id = self.stoi.get(UNK_TOKEN, 1)
        output = []
        i = 0
        n = len(symbols)
        while i < n:
            for j in range(n, i, -1):
                candidate = ''.join(symbols[i:j])
                if candidate in self.stoi:
                    output.append(self.stoi[candidate])
                    i = j
                    break
            else:
                output.append(unk_id)
                i += 1
        return output

    def decode(self, token_ids):
        """Turn ids back into text, rejoining subword pieces into words.

        Pieces are concatenated until one ending in ``END_OF_WORD`` closes the
        word; non-alphanumeric tokens (punctuation, role markers) stand alone.
        Special tokens and unknown ids are dropped, and the space before
        ``? . ! , : ;`` is removed.
        """
        words, current = [], []
        for tok in (self.itos.get(i, UNK_TOKEN) for i in token_ids):
            if tok in SPECIAL_TOKENS:
                continue
            if tok.endswith(END_OF_WORD):
                current.append(tok[:-len(END_OF_WORD)])
                words.append(''.join(current))
                current = []
            elif tok.isalnum():
                current.append(tok)
            else:
                if current:
                    words.append(''.join(current))
                    current = []
                words.append(tok)
        if current:
            words.append(''.join(current))
        text = ' '.join(words)
        text = re.sub(r'\s([?.!,:;])', r'\1', text)
        return text.strip()

    def save(self, path):
        """Write ``stoi``/``itos`` to ``path`` as JSON (itos keys become strings)."""
        with open(path, "w", encoding="utf-8") as f:
            json.dump({"stoi": self.stoi, "itos": self.itos}, f)

    def load(self, path):
        """Load a vocab written by ``save``; ``itos`` is rebuilt from ``stoi``."""
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
        self.stoi = {k: int(v) for k, v in data["stoi"].items()}
        self.itos = {i: tok for tok, i in self.stoi.items()}
        self.vocab_size = len(self.stoi)


class SimpleTokenizr:
    """Word-level tokenizer: lowercase words, digit runs and punctuation marks."""

    def __init__(self):
        self.stoi = {}  # token string -> id
        self.itos = {}  # id -> token string

    def tokenize(self, text):
        """Lowercase ``text`` and split it into words, digit runs and punctuation."""
        return re.findall(r"[a-zA-Z']+|\d+|[^\w\s]", text.lower())

    def train(self, texts):
        """Build a sorted vocabulary from every token in ``texts`` plus the specials."""
        vocab = set()
        for text in texts:
            vocab.update(self.tokenize(text))
        vocab.update([PAD_TOKEN, UNK_TOKEN, EOS_TOKEN, "^user :", "minigpt :", "MiniGPT :", ":"])
        sorted_vocab = sorted(vocab)
        self.stoi = {token: idx for idx, token in enumerate(sorted_vocab)}
        self.itos = {idx: token for token, idx in self.stoi.items()}

    def encode(self, text):
        """Map ``text`` to ids (unknown words -> ``UNK_TOKEN``) and append ``EOS_TOKEN``."""
        unk_id = self.stoi[UNK_TOKEN]
        ids = [self.stoi.get(tok, unk_id) for tok in self.tokenize(text)]
        return ids + [self.stoi[EOS_TOKEN]]

    def decode(self, token_ids):
        """Join tokens with spaces, gluing punctuation to the previous token.

        Special tokens and unknown ids are dropped; the result is passed through
        ``str.capitalize``.
        """
        tokens = [self.itos.get(i, UNK_TOKEN) for i in token_ids]
        clean_tokens = [tok for tok in tokens if tok not in SPECIAL_TOKENS]
        pieces = []
        for i, tok in enumerate(clean_tokens):
            if i > 0 and not re.match(r"[.,!?;:]", tok):
                pieces.append(' ')
            pieces.append(tok)
        return ''.join(pieces).strip().capitalize()

    def save(self, path):
        """Write ``stoi``/``itos`` to ``path`` as JSON (itos keys become strings)."""
        with open(path, "w", encoding="utf-8") as f:
            json.dump({"stoi": self.stoi, "itos": self.itos}, f)

    def load(self, path):
        """Load a vocab written by ``save``; ``itos`` is rebuilt from ``stoi``."""
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
        self.stoi = {k: int(v) for k, v in data["stoi"].items()}
        self.itos = {i: tok for tok, i in self.stoi.items()}

    def __len__(self):
        return len(self.stoi)

    @property
    def vocab_size(self):
        return len(self.stoi)


def validate(model, dataloader, device):
    """Return ``(mean batch loss, token accuracy in %)`` of ``model`` on ``dataloader``.

    Leaves the model in eval mode. Returns ``(nan, nan)`` when the loader yields
    no batches (e.g. an empty validation split) instead of dividing by zero.
    Assumes ``model(x)`` returns logits shaped like ``y`` plus a trailing
    vocabulary dimension, as the argmax comparison below requires.
    """
    model.eval()
    total_loss, correct, total, num_batches = 0.0, 0, 0, 0
    with torch.no_grad():
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), y.reshape(-1))
            total_loss += loss.item()
            preds = torch.argmax(logits, dim=-1)
            correct += (preds == y).sum().item()
            total += y.numel()
            num_batches += 1
    if num_batches == 0 or total == 0:
        return float("nan"), float("nan")
    return total_loss / num_batches, 100 * correct / total


def train(model, dataset, tokenizer, epochs, filepathh, start_epoch=0, start_step=0):
    """Train ``model`` on a random 90/10 train/validation split of ``dataset``.

    Resumes from ``./trained-mini-gpt/checkpoint-mini-gpt.pth`` when it exists,
    overriding ``start_epoch``/``start_step``. After every epoch, it runs
    validation and overwrites that checkpoint. It finally writes the bare
    state dict to ``./trained-mini-gpt/mini-gpt.pth``.

    Args:
        model: Module mapping an id batch to logits (see ``validate``).
        dataset: Dataset yielding ``(x, y)`` id tensors, e.g. ``ChatDataset``.
        tokenizer: Unused here; kept for interface compatibility.
        epochs: Total epoch count; epochs ``start_epoch .. epochs - 1`` run.
        filepathh: Unused here; kept for interface compatibility.
        start_epoch: First epoch to run when no full checkpoint is found.
        start_step: Initial optimizer-step counter when no full checkpoint is found.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # 🔀 Proper train/val split
    val_size = int(0.1 * len(dataset))
    train_size = len(dataset) - val_size
    train_set, val_set = random_split(dataset, [train_size, val_size])
    train_loader = DataLoader(train_set, batch_size=10, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_set, batch_size=10, shuffle=False, num_workers=2)

    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

    checkpoint_path = "./trained-mini-gpt/checkpoint-mini-gpt.pth"
    # torch.save does not create missing directories.
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
    if os.path.exists(checkpoint_path):
        # map_location lets a checkpoint saved on GPU load on a CPU-only host.
        checkpoint = torch.load(checkpoint_path, map_location=device)
        if "model_state_dict" in checkpoint:
            model.load_state_dict(checkpoint["model_state_dict"])
            optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
            # Checkpoints are written after an epoch finishes, so resume at the
            # following epoch instead of repeating the saved one.
            start_epoch = checkpoint["epoch"] + 1
            start_step = checkpoint["step"]
        else:
            # Older format: a bare model state dict, no optimizer/progress info.
            model.load_state_dict(checkpoint)
    else:
        print("🚀 Starting from scratch.")

    total_steps = start_step
    for epoch in range(start_epoch, epochs):
        model.train()
        correct, total = 0, 0
        loop = tqdm(train_loader, total=len(train_loader), desc=f"Epoch {epoch+1}/{epochs}")
        for x, y in loop:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), y.reshape(-1))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_steps += 1

            preds = torch.argmax(logits, dim=-1)
            correct += (preds == y).sum().item()
            total += y.numel()
            loop.set_postfix(loss=loss.item(), acc=100 * correct / total)

        # 🔍 Validate after each epoch
        val_loss, val_acc = validate(model, val_loader, device)
        print(f"✅ Val Loss: {val_loss:.4f} | Val Accuracy: {val_acc:.2f}%")

        # 💾 Save checkpoint ("epoch" is the epoch that just finished)
        torch.save({
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "epoch": epoch,
            "step": total_steps,
        }, checkpoint_path)

    torch.save(model.state_dict(), "./trained-mini-gpt/mini-gpt.pth")
    print("🎉 Training complete.")