from dataclasses import dataclass
from collections import defaultdict
from typing import List, Tuple
import inspect
import math
import os
import random
import time

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.cuda.amp import autocast
from torch.nn.utils.rnn import pad_sequence
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import Dataset, DataLoader


class NpyDataset(Dataset):
    def __init__(self, data_dir, file_prefix):
        self.data_dir = data_dir
        self.file_names = [
            os.path.join(data_dir, f)
            for f in sorted(os.listdir(data_dir))
            if f.startswith(file_prefix) and f.endswith('.npy')
        ]

    def __len__(self):
        return len(self.file_names)

    def __getitem__(self, idx):
        tokens_np = np.load(self.file_names[idx])
        # Shards are often stored as uint16; cast so torch.tensor can ingest them.
        tokens_np = tokens_np.astype(np.int64)
        tokens_tensor = torch.tensor(tokens_np, dtype=torch.long)
        return tokens_tensor


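# Illustrative sketch (assumption: not part of the original training flow): writes a
# couple of tiny token shards to a temporary directory and reads them back through
# NpyDataset. It is never called by train_model(); it exists only as a usage example.
def _demo_npy_dataset():
    import tempfile

    with tempfile.TemporaryDirectory() as tmp_dir:
        for shard_idx in range(2):
            np.save(os.path.join(tmp_dir, f'edufineweb_train_{shard_idx:06d}.npy'),
                    np.arange(32, dtype=np.uint16))
        dataset = NpyDataset(tmp_dir, 'edufineweb_train')
        assert len(dataset) == 2
        assert dataset[0].dtype == torch.long

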
class CustomDataLoaderLite:
    def __init__(self, dataset, batch_size, seq_len):
        self.dataset = dataset
        self.batch_size = batch_size
        self.seq_len = seq_len
        self.current_position = 0

    def __iter__(self):
        self.current_position = 0
        return self

    def __next__(self):
        if self.current_position >= len(self.dataset):
            raise StopIteration

        # Each dataset item is truncated to seq_len tokens before batching.
        batch = []
        for _ in range(self.batch_size):
            if self.current_position >= len(self.dataset):
                break
            tokens = self.dataset[self.current_position]
            batch.append(tokens[:self.seq_len])
            self.current_position += 1

        # Next-token prediction targets: y is x shifted left by one position.
        x = torch.stack([tokens[:-1] for tokens in batch])
        y = torch.stack([tokens[1:] for tokens in batch])

        return x, y

    def __len__(self):
        return (len(self.dataset) + self.batch_size - 1) // self.batch_size


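# Illustrative sketch (assumption: not part of the original script's flow): drives
# CustomDataLoaderLite with an in-memory list of token tensors instead of .npy shards,
# just to show the (x, y) shapes it yields. Never called by train_model().
def _demo_custom_loader():
    fake_dataset = [torch.randint(0, 50257, (512,)) for _ in range(4)]
    loader = CustomDataLoaderLite(fake_dataset, batch_size=2, seq_len=512)
    for x, y in loader:
        # x and y are shifted views of the same tokens, each seq_len - 1 long
        assert x.shape == y.shape == (2, 511)
        assert torch.equal(x[:, 1:], y[:, :-1])

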
class FlashAttention3(nn.Module):
    def __init__(self, d_model, n_heads, block_size_q, block_size_kv, num_blocks_kv, device='cuda'):
        super(FlashAttention3, self).__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.block_size_q = block_size_q
        self.block_size_kv = block_size_kv
        self.num_blocks_kv = num_blocks_kv
        self.device = device

        # Submodules are left on the default device; the parent module's .to(device)
        # call moves them, which keeps this layer usable on cpu/mps as well as cuda.
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, x):
        B, T, C = x.size()
        head_dim = C // self.n_heads
        Q = self.q_proj(x).view(B, T, self.n_heads, head_dim).transpose(1, 2)
        K = self.k_proj(x).view(B, T, self.n_heads, head_dim).transpose(1, 2)
        V = self.v_proj(x).view(B, T, self.n_heads, head_dim).transpose(1, 2)

        scale = 1.0 / math.sqrt(head_dim)
        O = torch.zeros(B, self.n_heads, T, head_dim, device=x.device, dtype=Q.dtype)

        # Block-wise attention with an online (streaming) softmax over the KV blocks.
        # Note: no causal mask is applied here, so attention is bidirectional.
        for i in range(0, T, self.block_size_q):
            Q_block = Q[:, :, i:i + self.block_size_q]
            O_block = torch.zeros_like(Q_block)
            L_block = torch.zeros(B, self.n_heads, Q_block.size(2), device=x.device, dtype=Q.dtype)
            M_block = torch.full((B, self.n_heads, Q_block.size(2)), -float('inf'), device=x.device, dtype=Q.dtype)

            for j in range(0, T, self.block_size_kv):
                K_block = K[:, :, j:j + self.block_size_kv]
                V_block = V[:, :, j:j + self.block_size_kv]

                S_block = torch.matmul(Q_block, K_block.transpose(-2, -1)) * scale
                M_block_old = M_block
                M_block = torch.max(M_block, S_block.max(dim=-1).values)

                # Rescale the running statistics to the new row maximum before
                # accumulating this KV block's contribution.
                exp_S_block = torch.exp(S_block - M_block.unsqueeze(-1))
                correction = torch.exp(M_block_old - M_block)
                L_block = correction * L_block + exp_S_block.sum(dim=-1)
                O_block = correction.unsqueeze(-1) * O_block + torch.matmul(exp_S_block, V_block)

            O_block = O_block / L_block.unsqueeze(-1)
            O[:, :, i:i + self.block_size_q] = O_block

        O = O.transpose(1, 2).contiguous().view(B, T, C)
        O = self.out_proj(O)

        return O


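# Illustrative sketch (assumption: not invoked anywhere in this script): a quick CPU
# smoke test for FlashAttention3 on a dummy batch, checking only the output shape.
def _demo_flash_attention():
    attn = FlashAttention3(d_model=64, n_heads=4, block_size_q=8,
                           block_size_kv=8, num_blocks_kv=4, device='cpu')
    x = torch.randn(2, 32, 64)
    out = attn(x)
    assert out.shape == (2, 32, 64)

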
class MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd)
        self.gelu = nn.GELU(approximate='tanh')
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)
        self.dropout = nn.Dropout(config.dropout)
        self.c_proj.scale_init = 1

    def forward(self, x):
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.c_proj(x)
        x = self.dropout(x)
        return x


class MixtureOfExperts(nn.Module):
    def __init__(self, config, num_experts, expert_layers):
        super().__init__()
        self.num_experts = num_experts
        self.expert_layers = expert_layers

        self.experts = nn.ModuleList([self._create_expert(config) for _ in range(num_experts)])
        self.gate = nn.Linear(config.n_embd, num_experts)

    def _create_expert(self, config):
        layers = []
        for _ in range(self.expert_layers):
            layers.append(FlashAttention3(d_model=config.n_embd, n_heads=config.n_head,
                                          block_size_q=32, block_size_kv=32, num_blocks_kv=4))
            layers.append(nn.LayerNorm(config.n_embd))
            layers.append(MLP(config))
        return nn.Sequential(*layers)

    def forward(self, x):
        B, T, C = x.size()

        # Dense soft routing: every expert processes every token and the outputs are
        # combined with the per-token gate probabilities (no top-k sparsity).
        gate_scores = self.gate(x)
        gate_probs = F.softmax(gate_scores, dim=-1)                               # (B, T, E)

        expert_outputs = torch.stack([expert(x) for expert in self.experts], dim=1)  # (B, E, T, C)

        gate_probs = gate_probs.unsqueeze(-1)        # (B, T, E, 1)
        gate_probs = gate_probs.permute(0, 2, 1, 3)  # (B, E, T, 1)

        output = torch.sum(gate_probs * expert_outputs, dim=1)                    # (B, T, C)

        return output


class BlockWithMoE(nn.Module):
    def __init__(self, config, num_experts=4, expert_layers=2, block_size_q=32, block_size_kv=32, num_blocks_kv=4, device='cuda'):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = FlashAttention3(d_model=config.n_embd, n_heads=config.n_head,
                                    block_size_q=block_size_q, block_size_kv=block_size_kv,
                                    num_blocks_kv=num_blocks_kv, device=device)
        self.dropout1 = nn.Dropout(config.dropout)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.moe = MixtureOfExperts(config, num_experts, expert_layers)
        self.dropout2 = nn.Dropout(config.dropout)
        self.ln_3 = nn.LayerNorm(config.n_embd)
        self.mlp = MLP(config)
        self.dropout3 = nn.Dropout(config.dropout)

    def forward(self, x):
        # Pre-norm residual sub-layers: attention, mixture-of-experts, then MLP.
        x = x + self.attn(self.ln_1(x))
        x = self.dropout1(x)
        x = x + self.moe(self.ln_2(x))
        x = self.dropout2(x)
        x = x + self.mlp(self.ln_3(x))
        x = self.dropout3(x)
        return x


@dataclass
class GPTConfig:
    block_size: int = 512
    vocab_size: int = 50257
    n_layer: int = 6
    n_head: int = 4
    n_embd: int = 256
    dropout: float = 0.2


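# Illustrative sketch (assumption: not used by train_model()): runs a single
# BlockWithMoE on CPU with a small config to confirm the residual block preserves
# the (B, T, n_embd) shape.
def _demo_block_with_moe():
    cfg = GPTConfig(block_size=64, vocab_size=128, n_layer=1, n_head=2, n_embd=32, dropout=0.0)
    block = BlockWithMoE(cfg, num_experts=2, expert_layers=1,
                         block_size_q=16, block_size_kv=16, num_blocks_kv=4, device='cpu')
    x = torch.randn(2, 64, 32)
    assert block(x).shape == (2, 64, 32)

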
class GPTWithMoE(nn.Module):
    def __init__(self, config, num_experts=2, expert_layers=2, block_size_q=32, block_size_kv=32, num_blocks_kv=4, device='cuda'):
        super().__init__()
        self.config = config

        self.transformer = nn.ModuleDict(dict(
            wte=nn.Embedding(config.vocab_size, config.n_embd),
            wpe=nn.Embedding(config.block_size, config.n_embd),
            h=nn.ModuleList([BlockWithMoE(config, num_experts, expert_layers, block_size_q, block_size_kv, num_blocks_kv, device) for _ in range(config.n_layer)]),
            ln_f=nn.LayerNorm(config.n_embd),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Weight tying between the token embedding and the output projection.
        self.transformer.wte.weight = self.lm_head.weight

        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            std = 0.02
            if hasattr(module, 'scale_init'):
                std *= (2 * self.config.n_layer) ** -0.5
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        B, T = idx.size()
        assert T <= self.config.block_size, f"Cannot forward sequence of length {T}, block size is only {self.config.block_size}"
        pos = torch.arange(0, T, dtype=torch.long, device=idx.device)
        pos_emb = self.transformer.wpe(pos)
        tok_emb = self.transformer.wte(idx)
        x = tok_emb + pos_emb
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)
        logits = self.lm_head(x)
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss

    def configure_optimizers(self, weight_decay, learning_rate, device):
        # Apply weight decay only to parameters with dim >= 2 (matrices and embeddings);
        # biases and LayerNorm parameters are left undecayed.
        param_dict = {pn: p for pn, p in self.named_parameters() if p.requires_grad}
        decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
        non_decay_params = [p for n, p in param_dict.items() if p.dim() < 2]

        optim_groups = [
            {'params': decay_params, 'weight_decay': weight_decay},
            {'params': non_decay_params, 'weight_decay': 0.0},
        ]

        fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
        use_fused = fused_available and 'cuda' in device
        print(f"Using fused AdamW: {use_fused}")
        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=(0.9, 0.95), eps=1e-8, fused=use_fused)
        return optimizer


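# Illustrative sketch (assumption: not called anywhere in this script): builds a tiny
# GPTWithMoE on CPU and runs one forward pass to check the logits shape and that a
# scalar loss is produced when targets are supplied.
def _demo_gpt_with_moe():
    cfg = GPTConfig(block_size=32, vocab_size=128, n_layer=1, n_head=2, n_embd=16, dropout=0.0)
    model = GPTWithMoE(cfg, num_experts=2, expert_layers=1,
                       block_size_q=8, block_size_kv=8, num_blocks_kv=4, device='cpu')
    idx = torch.randint(0, cfg.vocab_size, (2, 16))
    logits, loss = model(idx, targets=idx)
    assert logits.shape == (2, 16, cfg.vocab_size)
    assert loss.dim() == 0

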
@dataclass
class MCTSNode:
    state: torch.Tensor
    parent: 'MCTSNode' = None
    children: dict = None
    visits: int = 0
    value: float = 0.0

    def __post_init__(self):
        if self.children is None:
            self.children = {}


def select_node(node: MCTSNode, c_puct: float) -> MCTSNode:
    if not node.children:
        return node

    # UCT-style score: mean value (exploitation) plus an exploration bonus that
    # decays as a child accumulates visits.
    scores = torch.tensor([
        child.value / (child.visits + 1e-8) +
        c_puct * math.sqrt(math.log(node.visits + 1) / (child.visits + 1e-8))
        for child in node.children.values()
    ])

    best_child_idx = torch.argmax(scores).item()
    return list(node.children.values())[best_child_idx]


def expand_node(node: MCTSNode, logits: torch.Tensor, top_k: int) -> None:
    probs = F.softmax(logits, dim=-1)
    top_k_probs, top_k_indices = torch.topk(probs, k=top_k)

    for prob, token in zip(top_k_probs, top_k_indices):
        if token.item() not in node.children:
            node.children[token.item()] = MCTSNode(state=token, parent=node)


def simulate(model: torch.nn.Module, sequence: torch.Tensor, max_length: int) -> torch.Tensor:
    if sequence.dim() == 1:
        sequence = sequence.unsqueeze(0)

    with torch.no_grad():
        while sequence.size(1) < max_length:
            with autocast():
                logits, _ = model(sequence)
            probs = F.softmax(logits[0, -1], dim=-1)
            next_token = torch.multinomial(probs, 1)
            sequence = torch.cat([sequence, next_token.unsqueeze(0)], dim=1)
    return sequence.squeeze(0)


def backpropagate(node: MCTSNode, value: float) -> None:
    while node is not None:
        node.visits += 1
        node.value += value
        node = node.parent


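# Illustrative sketch (assumption: not part of the original flow): exercises the MCTS
# helpers on a dummy tree without any model. Fake logits stand in for model output.
def _demo_mcts_helpers():
    root = MCTSNode(state=torch.tensor([1, 2, 3]))
    fake_logits = torch.randn(50)          # pretend vocabulary of 50 tokens
    expand_node(root, fake_logits, top_k=5)
    assert len(root.children) == 5
    child = select_node(root, c_puct=1.0)
    backpropagate(child, value=-1.23)
    assert child.visits == 1 and root.visits == 1

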
def mcts_decode_single(model: torch.nn.Module, input_ids: torch.Tensor, max_length: int, num_simulations: int, c_puct: float, top_k: int) -> torch.Tensor:
    if input_ids.dim() == 1:
        input_ids = input_ids.unsqueeze(0)

    root = MCTSNode(state=input_ids)

    for _ in range(num_simulations):
        node = root
        current_input = input_ids.clone()

        # Selection: walk down the tree, appending the chosen token at each step.
        while node.children and current_input.size(1) < max_length:
            node = select_node(node, c_puct)
            current_input = torch.cat([current_input, node.state.unsqueeze(0).unsqueeze(0)], dim=1)

        # Expansion: add the top-k next tokens as children of the reached node.
        if current_input.size(1) < max_length:
            with torch.no_grad():
                with autocast():
                    logits, _ = model(current_input)
            expand_node(node, logits[0, -1], top_k)

        # Simulation: roll out to max_length by sampling from the model.
        simulation_sequence = simulate(model, current_input.squeeze(0), max_length)

        # Evaluation: use the negative language-modeling loss of the rollout as the value.
        with torch.no_grad():
            with autocast():
                _, loss = model(simulation_sequence.unsqueeze(0), simulation_sequence.unsqueeze(0))
        value = -loss.item()

        # Backpropagation: propagate the value up the visited path.
        backpropagate(node, value)

    # Commit to the most-visited child of the root and append its token.
    best_child = max(root.children.values(), key=lambda n: n.visits)
    result = torch.cat([input_ids.squeeze(0), best_child.state.unsqueeze(0)], dim=0)

    return result[:max_length]


def mcts_decode_batch(model: torch.nn.Module, input_ids_list: List[torch.Tensor], max_length: int, num_simulations: int, c_puct: float, top_k: int) -> List[torch.Tensor]:
    return [
        mcts_decode_single(model, input_ids.unsqueeze(0) if input_ids.dim() == 1 else input_ids,
                           max_length, num_simulations, c_puct, top_k)
        for input_ids in input_ids_list
    ]


def validate_with_mcts(model: torch.nn.Module, val_dataloader: CustomDataLoaderLite, device: torch.device, max_length: int, num_simulations: int, c_puct: float, top_k: int) -> float:
    model.eval()
    total_loss = 0.0
    num_batches = 0

    with torch.no_grad():
        for x, y in val_dataloader:
            x, y = x.to(device), y.to(device)

            decoded_sequences = mcts_decode_batch(model, x, max_length, num_simulations, c_puct, top_k)

            decoded_sequences_padded = pad_sequence(decoded_sequences, batch_first=True, padding_value=0)
            decoded_sequences_trimmed = decoded_sequences_padded[:, :y.size(1)]

            with autocast():
                logits, loss = model(decoded_sequences_trimmed, y)
            total_loss += loss.item()
            num_batches += 1

    return total_loss / num_batches if num_batches > 0 else 0.0


def train_model():
    device = 'cpu'
    if torch.cuda.is_available():
        device = 'cuda'
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        device = 'mps'
    print(f"using device : {device}")

    print("Loading datasets...")
    train_dataset = NpyDataset('edu_fineweb10B', 'edufineweb_train')
    val_dataset = NpyDataset('edu_fineweb10B', 'edufineweb_val')
    train_dataloader = CustomDataLoaderLite(train_dataset, batch_size=12, seq_len=512)
    val_dataloader = CustomDataLoaderLite(val_dataset, batch_size=12, seq_len=512)

    max_steps = 200
    total_batch_size = 262144  # tokens per optimizer step
    B = 12
    T = 512
    grad_accum_steps = total_batch_size // (B * T)

    print("Setting up model configuration...")
    config = GPTConfig(vocab_size=50304, block_size=512, n_layer=6, n_head=4, n_embd=256)

    print("Initializing model...")
    model = GPTWithMoE(config, num_experts=3, expert_layers=3, block_size_q=32, block_size_kv=32, num_blocks_kv=4, device=device)
    model.to(device)

    save_path = "C:\\Users\\Admin\\MODELS\\moe_mcts_new.pt"
    temp_save_path = "C:\\Users\\Admin\\MODELS\\moe_mcts_temp_new.pt"
    if os.path.isfile(save_path):
        print(f"Loading model weights from {save_path}...")
        model.load_state_dict(torch.load(save_path, map_location=device))
        print(f"Loaded model weights from {save_path}")

    print("Configuring optimizer...")
    optimizer = model.configure_optimizers(weight_decay=0.2, learning_rate=3e-3, device=device)

    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10, verbose=True)

    train_losses = []
    val_losses = []

    scaler = torch.cuda.amp.GradScaler(enabled=(device == 'cuda'))

    for i in range(max_steps):
        t0 = time.time()
        optimizer.zero_grad()
        train_loss_accum = 0.0

        model.train()
        print(f"Training step {i + 1}/{max_steps}...")
        # Gradient accumulation: process up to grad_accum_steps micro-batches
        # before taking a single optimizer step.
        for micro_step, (x, y) in enumerate(train_dataloader):
            if micro_step >= grad_accum_steps:
                break
            x, y = x.to(device), y.to(device)
            with torch.cuda.amp.autocast():
                logits, loss = model(x, y)
            loss = loss / grad_accum_steps
            train_loss_accum += loss.detach()
            scaler.scale(loss).backward()

        # Unscale before clipping so the threshold applies to the true gradients.
        scaler.unscale_(optimizer)
        norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()

        if device == 'cuda':
            torch.cuda.synchronize()
        t1 = time.time()
        dt = (t1 - t0) * 1000
        tokens_per_sec = (B * T * grad_accum_steps) / (t1 - t0)
        train_losses.append(train_loss_accum.item())

        if device == 'cuda':
            torch.cuda.empty_cache()

        model.eval()
        val_loss = validate_with_mcts(model, val_dataloader, device, max_length=T, num_simulations=100, c_puct=1.0, top_k=10)
        val_losses.append(val_loss)

        scheduler.step(val_loss)

        print(f"step {i} | train loss: {train_loss_accum.item():.6f} | val loss: {val_loss:.6f} | lr: {optimizer.param_groups[0]['lr']:.8f} | norm: {norm:.4f} | dt: {dt:.2f}ms | tok/sec: {tokens_per_sec:.2f}")

        torch.save(model.state_dict(), temp_save_path)
        os.replace(temp_save_path, save_path)
        print(f"Model saved at step {i+1} to {save_path}")

    plt.figure(figsize=(10, 5))
    plt.plot(train_losses, label='Training Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.xlabel('Steps')
    plt.ylabel('Loss')
    plt.legend()
    plt.show()


if __name__ == "__main__":
    train_model()