# gpt-moe-mcts / q_star.py
# Upload metadata: author RobbiePasquale, commit 8e083dc ("Upload 3 files", verified), 19.2 kB.
from dataclasses import dataclass
import torch
import torch.nn as nn
from torch.nn import functional as F
import math
import os
import numpy as np
import time
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from torch.optim.lr_scheduler import ReduceLROnPlateau
import random
from collections import defaultdict
from torch.cuda.amp import autocast
from typing import List, Tuple
from torch.nn.utils.rnn import pad_sequence
import inspect
# Define your dataset and dataloader classes
class NpyDataset(Dataset):
    """Map-style dataset in which every item is one complete ``.npy`` file.

    Files in ``data_dir`` are selected when their name starts with
    ``file_prefix`` and ends with ``.npy``; they are kept in sorted-name order
    so that index ``i`` always refers to the same file.
    """

    def __init__(self, data_dir, file_prefix):
        self.data_dir = data_dir
        matching = [
            name
            for name in sorted(os.listdir(data_dir))
            if name.startswith(file_prefix) and name.endswith('.npy')
        ]
        self.file_names = [os.path.join(data_dir, name) for name in matching]

    def __len__(self):
        return len(self.file_names)

    def __getitem__(self, idx):
        # The whole array stored in the file is loaded and converted to int64.
        return torch.tensor(np.load(self.file_names[idx]), dtype=torch.long)
class CustomDataLoaderLite:
    """Sequential, non-shuffling batch iterator over a token dataset.

    Every dataset item is cut to its first ``seq_len`` tokens. Inputs are those
    tokens without the last one and targets are the same tokens shifted left by
    one, so both have ``seq_len - 1`` columns when the item is at least
    ``seq_len`` tokens long. The final batch may hold fewer than ``batch_size``
    rows. ``torch.stack`` requires all truncated items of a batch to have the
    same length.
    """

    def __init__(self, dataset, batch_size, seq_len):
        self.dataset = dataset
        self.batch_size = batch_size
        self.seq_len = seq_len
        self.current_position = 0

    def __iter__(self):
        # Starting a new iteration rewinds to the first item.
        self.current_position = 0
        return self

    def __next__(self):
        total = len(self.dataset)
        if self.current_position >= total:
            raise StopIteration
        stop = min(self.current_position + self.batch_size, total)
        rows = [self.dataset[pos][:self.seq_len] for pos in range(self.current_position, stop)]
        self.current_position = stop
        inputs = torch.stack([row[:-1] for row in rows])
        targets = torch.stack([row[1:] for row in rows])
        return inputs, targets

    def __len__(self):
        # Ceiling division: a trailing partial batch still counts as a batch.
        return -(-len(self.dataset) // self.batch_size)
# Define the FlashAttention3 module
class FlashAttention3(nn.Module):
    """Multi-head self-attention computed tile by tile with an online softmax.

    Queries are processed in tiles of ``block_size_q`` rows and keys/values in
    tiles of ``block_size_kv`` rows. For each query tile a running row max, a
    running softmax denominator and a running unnormalized output are kept; all
    of them are rescaled whenever a new key tile raises the row max. The
    result equals ``softmax(Q K^T / sqrt(head_dim) + mask) V`` without building
    the full ``T x T`` score matrix. This is a pure-PyTorch implementation, not
    a fused kernel.

    Args:
        d_model: embedding width ``C``; expected to be divisible by ``n_heads``.
        n_heads: number of attention heads.
        block_size_q: query tile length.
        block_size_kv: key/value tile length.
        num_blocks_kv: stored on the module; not used by ``forward``.
        device: device the projection layers are created on.
        causal: when True (default) position ``i`` attends only to positions
            ``<= i``, as required for autoregressive language modelling.
    """

    def __init__(self, d_model, n_heads, block_size_q, block_size_kv, num_blocks_kv, device='cuda', causal=True):
        super(FlashAttention3, self).__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.block_size_q = block_size_q
        self.block_size_kv = block_size_kv
        self.num_blocks_kv = num_blocks_kv
        self.device = device
        self.causal = causal
        self.q_proj = nn.Linear(d_model, d_model).to(device)
        self.k_proj = nn.Linear(d_model, d_model).to(device)
        self.v_proj = nn.Linear(d_model, d_model).to(device)
        self.out_proj = nn.Linear(d_model, d_model).to(device)

    def forward(self, x):
        """Attend over ``x`` of shape ``(B, T, C)`` and return a ``(B, T, C)`` tensor."""
        B, T, C = x.size()
        n_heads = self.n_heads
        head_dim = C // n_heads
        scale = 1.0 / math.sqrt(head_dim)

        q = self.q_proj(x).view(B, T, n_heads, head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, n_heads, head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, n_heads, head_dim).transpose(1, 2)

        # Allocate on the input's device rather than self.device so the module
        # keeps working after it has been moved with .to(...).
        out = torch.empty_like(q)
        positions = torch.arange(T, device=x.device)

        for q_start in range(0, T, self.block_size_q):
            q_end = min(q_start + self.block_size_q, T)
            q_blk = q[:, :, q_start:q_end]
            rows = q_end - q_start
            # Running statistics in float32 for numerical stability under AMP.
            acc = torch.zeros(B, n_heads, rows, head_dim, device=x.device, dtype=torch.float32)
            row_sum = torch.zeros(B, n_heads, rows, device=x.device, dtype=torch.float32)
            row_max = torch.full((B, n_heads, rows), float('-inf'), device=x.device, dtype=torch.float32)

            # Under a causal mask, key tiles that start after the last query
            # row of this tile are entirely masked and can be skipped.
            kv_limit = q_end if self.causal else T
            for kv_start in range(0, kv_limit, self.block_size_kv):
                kv_end = min(kv_start + self.block_size_kv, T)
                k_blk = k[:, :, kv_start:kv_end]
                v_blk = v[:, :, kv_start:kv_end]

                scores = torch.matmul(q_blk, k_blk.transpose(-2, -1)).float() * scale
                if self.causal:
                    future = positions[kv_start:kv_end].unsqueeze(0) > positions[q_start:q_end].unsqueeze(1)
                    scores = scores.masked_fill(future, float('-inf'))

                new_max = torch.maximum(row_max, scores.amax(dim=-1))
                # Rescale everything accumulated so far to the new row max;
                # without this, earlier tiles are weighted inconsistently.
                correction = torch.exp(row_max - new_max)
                probs = torch.exp(scores - new_max.unsqueeze(-1))
                row_sum = correction * row_sum + probs.sum(dim=-1)
                acc = correction.unsqueeze(-1) * acc + torch.matmul(probs.to(v_blk.dtype), v_blk).float()
                row_max = new_max

            out[:, :, q_start:q_end] = (acc / row_sum.unsqueeze(-1)).to(out.dtype)

        out = out.transpose(1, 2).contiguous().view(B, T, n_heads * head_dim)
        return self.out_proj(out)
# Define the MLP module
class MLP(nn.Module):
    """Position-wise feed-forward layer: Linear -> GELU(tanh) -> Linear -> Dropout.

    The hidden width is ``4 * config.n_embd``. ``c_proj`` carries a
    ``scale_init`` attribute, which ``GPTWithMoE._init_weights`` checks to
    shrink that layer's initialization std.
    """

    def __init__(self, config):
        super().__init__()
        width = config.n_embd
        hidden = 4 * width
        self.c_fc = nn.Linear(width, hidden)
        self.gelu = nn.GELU(approximate='tanh')
        self.c_proj = nn.Linear(hidden, width)
        self.dropout = nn.Dropout(config.dropout)
        # Marker read by the model-level weight initializer.
        self.c_proj.scale_init = 1

    def forward(self, x):
        hidden = self.gelu(self.c_fc(x))
        return self.dropout(self.c_proj(hidden))
# Define the MixtureOfExperts module
class MixtureOfExperts(nn.Module):
    """Dense (soft) mixture of experts.

    Every expert runs on the full input, and the expert outputs are combined
    with per-token softmax gate weights, so the cost grows linearly with
    ``num_experts``. Each expert is an ``nn.Sequential`` of ``expert_layers``
    repetitions of ``FlashAttention3 -> LayerNorm -> MLP``.

    Args:
        config: model config; ``n_embd``, ``n_head`` and ``dropout`` are read.
        num_experts: number of experts (and of gate outputs).
        expert_layers: attention/LayerNorm/MLP repetitions inside each expert.
        device: device passed to each expert's ``FlashAttention3``. ``None``
            (default) selects ``'cuda'`` when CUDA is available and ``'cpu'``
            otherwise. Previously the attention layers always fell back to
            their ``'cuda'`` default, which fails on machines without CUDA.
    """

    def __init__(self, config, num_experts, expert_layers, device=None):
        super().__init__()
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.num_experts = num_experts
        self.expert_layers = expert_layers
        self.device = device
        self.experts = nn.ModuleList([self._create_expert(config) for _ in range(num_experts)])
        self.gate = nn.Linear(config.n_embd, num_experts)

    def _create_expert(self, config):
        """Build one expert: ``expert_layers`` x (FlashAttention3, LayerNorm, MLP)."""
        layers = []
        for _ in range(self.expert_layers):
            layers.append(FlashAttention3(d_model=config.n_embd, n_heads=config.n_head, block_size_q=32, block_size_kv=32, num_blocks_kv=4, device=self.device))
            layers.append(nn.LayerNorm(config.n_embd))
            layers.append(MLP(config))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Mix expert outputs for ``x`` of shape ``(B, T, C)``; returns ``(B, T, C)``."""
        # (B, T, E): one non-negative weight per expert per token, summing to 1.
        gate_probs = F.softmax(self.gate(x), dim=-1)
        # (B, T, E, C): every expert's output for every token.
        expert_outputs = torch.stack([expert(x) for expert in self.experts], dim=2)
        return (gate_probs.unsqueeze(-1) * expert_outputs).sum(dim=2)
# Define the BlockWithMoE module
class BlockWithMoE(nn.Module):
    """Pre-LayerNorm transformer block: attention, then mixture of experts, then MLP.

    Each sub-layer reads a LayerNorm of the residual stream, and its output
    (after dropout) is added back to the stream:

        x = x + dropout1(attn(ln_1(x)))
        x = x + dropout2(moe(ln_2(x)))
        x = x + dropout3(mlp(ln_3(x)))
    """

    def __init__(self, config, num_experts=4, expert_layers=2, block_size_q=32, block_size_kv=32, num_blocks_kv=4, device='cuda'):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = FlashAttention3(d_model=config.n_embd, n_heads=config.n_head, block_size_q=block_size_q, block_size_kv=block_size_kv, num_blocks_kv=num_blocks_kv, device=device)
        self.dropout1 = nn.Dropout(config.dropout)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.moe = MixtureOfExperts(config, num_experts, expert_layers)
        self.dropout2 = nn.Dropout(config.dropout)
        self.ln_3 = nn.LayerNorm(config.n_embd)
        self.mlp = MLP(config)
        self.dropout3 = nn.Dropout(config.dropout)

    def forward(self, x):
        # ln_1 used to be created but never applied. Dropout used to be applied
        # to the whole residual stream, which also dropped the skip path; it
        # now regularizes only each sub-layer's contribution.
        x = x + self.dropout1(self.attn(self.ln_1(x)))
        x = x + self.dropout2(self.moe(self.ln_2(x)))
        x = x + self.dropout3(self.mlp(self.ln_3(x)))
        return x
# Define the GPT configuration dataclass
@dataclass
class GPTConfig:
    """Hyper-parameters shared by GPTWithMoE and its sub-modules."""
    block_size: int = 512  # maximum sequence length; rows of the position-embedding table
    vocab_size: int = 50257  # number of token ids; size of the token embedding and output layer
    n_layer: int = 6  # number of BlockWithMoE layers
    n_head: int = 4  # attention heads in each FlashAttention3
    n_embd: int = 256  # embedding / residual-stream width
    dropout: float = 0.2  # probability used by every nn.Dropout in the model
# Define the GPTWithMoE model
class GPTWithMoE(nn.Module):
    """GPT-style language model built from ``BlockWithMoE`` layers.

    Token and learned position embeddings are summed, passed through
    ``config.n_layer`` blocks and a final LayerNorm, and projected to
    vocabulary logits by ``lm_head``. The ``lm_head`` weight is shared with
    the token embedding.
    """

    def __init__(self, config, num_experts=2, expert_layers=2, block_size_q=32, block_size_kv=32, num_blocks_kv=4, device='cuda'):
        super().__init__()
        self.config = config
        self.transformer = nn.ModuleDict(dict(
            wte=nn.Embedding(config.vocab_size, config.n_embd),
            wpe=nn.Embedding(config.block_size, config.n_embd),
            h=nn.ModuleList([
                BlockWithMoE(config, num_experts, expert_layers, block_size_q, block_size_kv, num_blocks_kv, device)
                for _ in range(config.n_layer)
            ]),
            ln_f=nn.LayerNorm(config.n_embd),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # Weight tying: the output projection reuses the token-embedding matrix.
        self.transformer.wte.weight = self.lm_head.weight
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialize Linear/Embedding weights ~ N(0, 0.02) and Linear biases to zero.

        Linear layers tagged with a ``scale_init`` attribute get their std
        multiplied by ``(2 * n_layer) ** -0.5``.
        """
        if isinstance(module, nn.Linear):
            std = 0.02
            if hasattr(module, 'scale_init'):
                std *= (2 * self.config.n_layer) ** -0.5
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        """Compute logits (and optionally the cross-entropy loss) for token ids ``idx``.

        Args:
            idx: integer tensor of shape ``(B, T)`` with ``T <= config.block_size``.
            targets: optional integer tensor with ``B * T`` elements, compared
                position by position with the logits. It need not be contiguous.

        Returns:
            ``(logits, loss)``: logits have shape ``(B, T, vocab_size)``; ``loss``
            is a scalar tensor, or ``None`` when ``targets`` is None.
        """
        B, T = idx.size()
        assert T <= self.config.block_size, f"Cannot forward sequence of length {T}, block size is only {self.config.block_size}"
        pos = torch.arange(0, T, dtype=torch.long, device=idx.device)
        pos_emb = self.transformer.wpe(pos)
        tok_emb = self.transformer.wte(idx)
        x = tok_emb + pos_emb
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)
        logits = self.lm_head(x)
        loss = None
        if targets is not None:
            # reshape (not view) so non-contiguous targets are accepted.
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.reshape(-1))
        return logits, loss

    def configure_optimizers(self, weight_decay, learning_rate, device):
        """Create an AdamW optimizer with separate decay and no-decay groups.

        Parameters with two or more dimensions (weight matrices, embeddings)
        use ``weight_decay``. Parameters with fewer dimensions (biases,
        LayerNorm weights) use no decay. The fused AdamW kernel is requested
        only when the installed PyTorch supports it and ``device`` names CUDA.

        Args:
            weight_decay: decay for the >= 2-D parameter group.
            learning_rate: AdamW learning rate.
            device: device string or ``torch.device`` the model lives on.

        Returns:
            The configured ``torch.optim.AdamW`` instance.
        """
        param_dict = {pn: p for pn, p in self.named_parameters() if p.requires_grad}
        decay_params = [p for p in param_dict.values() if p.dim() >= 2]
        non_decay_params = [p for p in param_dict.values() if p.dim() < 2]
        optim_groups = [
            {'params': decay_params, 'weight_decay': weight_decay},
            {'params': non_decay_params, 'weight_decay': 0},
        ]
        fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
        # str() so a torch.device works as well as a plain string.
        use_fused = fused_available and 'cuda' in str(device)
        print(f" Using fused AdamW: {use_fused}")
        # Only pass ``fused`` when the installed AdamW accepts it.
        extra_args = {'fused': True} if use_fused else {}
        # The groups must be passed here; passing self.parameters() silently
        # discarded them and applied AdamW's default decay to every parameter.
        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=(0.9, 0.95), eps=1e-8, **extra_args)
        return optimizer
# MCTS Implementation
@dataclass
class MCTSNode:
    """A node of the Monte-Carlo search tree.

    ``state`` holds this node's tensor (the root stores the prompt, other
    nodes store one token), ``children`` maps a token id to its child node, and
    ``visits`` / ``value`` accumulate the statistics written by backpropagation.
    """
    state: torch.Tensor
    parent: 'MCTSNode' = None
    children: dict = None
    visits: int = 0
    value: float = 0.0

    def __post_init__(self):
        # Give each node its own dict; a shared mutable default would make
        # every node alias the same children mapping.
        self.children = {} if self.children is None else self.children
# Define scriptable functions separately
def select_node(node: MCTSNode, c_puct: float) -> MCTSNode:
    """Return the child of ``node`` with the highest UCT score.

    score = value / visits + c_puct * sqrt(log(parent_visits + 1) / visits),
    where 1e-8 is added to the child's visit count to avoid dividing by zero.
    Once the parent has been visited, this gives unvisited children a very
    large exploration bonus. Ties go to the first child in insertion order.
    A node without children is returned unchanged.
    """
    if not node.children:
        return node
    candidates = list(node.children.values())
    log_parent_visits = math.log(node.visits + 1)

    def uct(child):
        visits = child.visits + 1e-8
        return child.value / visits + c_puct * math.sqrt(log_parent_visits / visits)

    # The scores still go through a float32 tensor + argmax, so rounding and
    # tie-breaking are unchanged.
    best = torch.argmax(torch.tensor([uct(child) for child in candidates])).item()
    return candidates[best]
def expand_node(node: MCTSNode, logits: torch.Tensor, top_k: int) -> None:
    """Add children to ``node`` for the ``top_k`` most probable next tokens.

    A softmax is taken over the last dimension of ``logits``. When ``logits``
    is 1-D (as ``mcts_decode_single`` passes it), each new child stores its
    token id as a 0-d tensor in ``state``. Token ids already in
    ``node.children`` are left untouched. ``top_k`` is clamped to the number
    of logits; otherwise ``torch.topk`` raises when it exceeds the vocabulary.
    """
    probs = F.softmax(logits, dim=-1)
    k = min(top_k, probs.size(-1))
    _, top_indices = torch.topk(probs, k=k)
    for token in top_indices:
        token_id = token.item()
        if token_id not in node.children:
            node.children[token_id] = MCTSNode(state=token, parent=node)
def simulate(model: torch.nn.Module, sequence: torch.Tensor, max_length: int) -> torch.Tensor:
    """Extend ``sequence`` by sampling from ``model`` until it has ``max_length`` tokens.

    A 1-D ``sequence`` is treated as a batch of one. Each step runs the model
    on the whole sequence so far, applies softmax to the logits at the last
    position of batch row 0, and appends one token drawn with
    ``torch.multinomial``. Nothing is added if the sequence is already
    ``max_length`` long or longer. A leading batch dimension of size 1 is
    squeezed away on return.
    """
    seq = sequence if sequence.dim() != 1 else sequence.unsqueeze(0)
    with torch.no_grad():
        while seq.size(1) < max_length:
            # Only the forward pass runs under autocast; sampling does not.
            with autocast():
                logits, _ = model(seq)
            next_probs = F.softmax(logits[0, -1], dim=-1)
            sampled = torch.multinomial(next_probs, 1).view(1, 1)
            seq = torch.cat((seq, sampled), dim=1)
    return seq.squeeze(0)
def backpropagate(node: MCTSNode, value: float) -> None:
    """Add one visit and ``value`` to ``node`` and to every ancestor up to the root."""
    current = node
    while current is not None:
        current.visits, current.value = current.visits + 1, current.value + value
        current = current.parent
def mcts_decode_single(model: torch.nn.Module, input_ids: torch.Tensor, max_length: int, num_simulations: int, c_puct: float, top_k: int) -> torch.Tensor:
    """Choose one next token for ``input_ids`` with Monte-Carlo tree search.

    Each of ``num_simulations`` rounds does the following:
      1. descends from the root with ``select_node``;
      2. expands the reached node with the model's ``top_k`` next tokens;
      3. completes the path to ``max_length`` tokens with ``simulate``;
      4. scores the completed sequence by the model's negative next-token
         cross-entropy on it;
      5. backpropagates that score along the path.
    The most-visited child of the root is appended to the prompt.

    Returns:
        The prompt (leading batch dim squeezed) plus the chosen token,
        truncated to ``max_length``. If the root never gets children (the
        prompt is already ``max_length`` tokens or longer), the prompt is
        returned truncated to ``max_length`` instead of raising.
    """
    if input_ids.dim() == 1:
        input_ids = input_ids.unsqueeze(0)
    root = MCTSNode(state=input_ids)
    for _ in range(num_simulations):
        node = root
        current_input = input_ids.clone()
        # Selection: descend while the node has children and there is room left.
        while node.children and current_input.size(1) < max_length:
            node = select_node(node, c_puct)
            current_input = torch.cat([current_input, node.state.unsqueeze(0).unsqueeze(0)], dim=1)
        # Expansion
        if current_input.size(1) < max_length:
            with torch.no_grad():
                with autocast():
                    logits, _ = model(current_input)
            expand_node(node, logits[0, -1], top_k)
        # Simulation
        simulation_sequence = simulate(model, current_input.squeeze(0), max_length)
        # Evaluation: the model predicts token t+1 from positions <= t, so the
        # targets are the inputs shifted left by one. The sequence used to be
        # its own (unshifted) target, which scores the wrong prediction.
        if simulation_sequence.size(-1) < 2:
            value = 0.0
        else:
            with torch.no_grad():
                with autocast():
                    _, loss = model(simulation_sequence[:-1].unsqueeze(0), simulation_sequence[1:].unsqueeze(0))
            value = -loss.item()
        # Backpropagation
        backpropagate(node, value)
    if not root.children:
        # No expansion was possible; max() below would raise on an empty sequence.
        return input_ids.squeeze(0)[:max_length]
    # Choose the most-visited next token.
    best_child = max(root.children.values(), key=lambda n: n.visits)
    result = torch.cat([input_ids.squeeze(0), best_child.state.unsqueeze(0)], dim=0)
    return result[:max_length]
def mcts_decode_batch(model: torch.nn.Module, input_ids_list: List[torch.Tensor], max_length: int, num_simulations: int, c_puct: float, top_k: int) -> List[torch.Tensor]:
    """Run ``mcts_decode_single`` on each prompt of ``input_ids_list`` independently.

    1-D prompts get a leading batch dimension first. Results are returned in
    input order, one tensor per prompt.
    """
    decoded = []
    for prompt in input_ids_list:
        batched = prompt if prompt.dim() != 1 else prompt.unsqueeze(0)
        decoded.append(mcts_decode_single(model, batched, max_length, num_simulations, c_puct, top_k))
    return decoded
def validate_with_mcts(model: torch.nn.Module, val_dataloader: CustomDataLoaderLite, device: torch.device, max_length: int, num_simulations: int, c_puct: float, top_k: int) -> float:
    """Return the mean per-batch loss of ``model`` on MCTS-decoded inputs.

    For every ``(x, y)`` batch:
      - each row of ``x`` is extended by ``mcts_decode_batch``;
      - the results are right-padded with token id 0 and cut to ``y.size(1)``
        columns;
      - ``model(decoded, y)`` gives the batch loss.
    Returns 0.0 when the loader yields no batches. Leaves the model in eval mode.

    NOTE(review): ``mcts_decode_single`` returns the prompt followed by the
    chosen token, so cutting to ``y.size(1)`` keeps only the first
    ``y.size(1)`` tokens. ``CustomDataLoaderLite`` yields ``x`` and ``y`` of
    equal width, so the trimmed tensor is exactly ``x``. The MCTS choice
    therefore never affects the loss: the result equals ordinary
    teacher-forced validation loss while still paying for every simulation.
    Confirm the intended metric before relying on this number.
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0
    with torch.no_grad():
        for x, y in val_dataloader:
            x, y = x.to(device), y.to(device)
            # Use MCTS for decoding
            decoded_sequences = mcts_decode_batch(model, x, max_length, num_simulations, c_puct, top_k)
            # Pad sequences to the same length
            decoded_sequences_padded = pad_sequence(decoded_sequences, batch_first=True, padding_value=0)
            # Trim the decoded sequences to match the target length (see NOTE above:
            # this drops the MCTS-chosen token).
            decoded_sequences_trimmed = decoded_sequences_padded[:, :y.size(1)]
            # Calculate loss using the MCTS-decoded sequences
            with autocast():
                logits, loss = model(decoded_sequences_trimmed, y)
            total_loss += loss.item()
            num_batches += 1
    return total_loss / num_batches if num_batches > 0 else 0.0
def _infinite_batches(loader):
    """Yield batches from ``loader`` forever, restarting it whenever it is exhausted.

    Raises:
        ValueError: if ``loader`` yields no batches at all, which would
            otherwise loop forever.
    """
    while True:
        produced = False
        for batch in loader:
            produced = True
            yield batch
        if not produced:
            raise ValueError("training data loader produced no batches")


def train_model():
    """Train GPTWithMoE on the edu_fineweb10B ``.npy`` shards and plot the losses.

    Each optimizer step:
      - accumulates gradients over ``grad_accum_steps`` micro-batches of
        ``B x T`` tokens, cycling through the training loader as needed;
      - clips the unscaled gradient norm to 1.0;
      - runs MCTS-based validation and steps the plateau scheduler;
      - checkpoints the weights.
    """
    device = 'cpu'
    if torch.cuda.is_available():
        device = 'cuda'
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        device = 'mps'
    print(f"using device : {device}")
    # CUDA-only facilities (AMP, synchronize, empty_cache) are enabled only on CUDA.
    use_cuda = device == 'cuda'

    # Load the dataset and create the data loader
    print("Loading datasets...")
    train_dataset = NpyDataset('edu_fineweb10B', 'edufineweb_train')
    val_dataset = NpyDataset('edu_fineweb10B', 'edufineweb_val')
    train_dataloader = CustomDataLoaderLite(train_dataset, batch_size=12, seq_len=512)
    val_dataloader = CustomDataLoaderLite(val_dataset, batch_size=12, seq_len=512)

    # Training loop settings
    max_steps = 200
    total_batch_size = 262144
    B = 12
    T = 512
    # Micro-batches whose gradients are summed before each optimizer step.
    grad_accum_steps = total_batch_size // (B * T)

    # Set up the configuration
    print("Setting up model configuration...")
    config = GPTConfig(vocab_size=50304, block_size=512, n_layer=6, n_head=4, n_embd=256)

    # Initialize the model
    print("Initializing model...")
    model = GPTWithMoE(config, num_experts=3, expert_layers=3, block_size_q=32, block_size_kv=32, num_blocks_kv=4, device=device)
    model.to(device)

    # Load the saved model weights if they exist
    save_path = "C:\\Users\\Admin\\MODELS\\moe_mcts_new.pt"
    temp_save_path = "C:\\Users\\Admin\\MODELS\\moe_mcts_temp_new.pt"
    if os.path.isfile(save_path):
        print(f"Loading model weights from {save_path}...")
        # map_location lets a checkpoint written on another device load here.
        model.load_state_dict(torch.load(save_path, map_location=device))
        print(f"Loaded model weights from {save_path}")

    print("Configuring optimizer...")
    optimizer = model.configure_optimizers(weight_decay=0.2, learning_rate=3e-3, device=device)
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10, verbose=True)
    train_losses = []
    val_losses = []
    scaler = torch.cuda.amp.GradScaler(enabled=use_cuda)
    train_batches = _infinite_batches(train_dataloader)

    for i in range(max_steps):
        t0 = time.time()
        optimizer.zero_grad()
        train_loss_accum = torch.zeros((), device=device)
        model.train()
        print(f"Training step {i + 1}/{max_steps}...")
        # Exactly grad_accum_steps micro-batches per step, matching the
        # 1/grad_accum_steps loss scaling and the tokens/sec formula below.
        for _ in range(grad_accum_steps):
            x, y = next(train_batches)
            x, y = x.to(device), y.to(device)
            with torch.cuda.amp.autocast(enabled=use_cuda):
                logits, loss = model(x, y)
            loss = loss / grad_accum_steps
            train_loss_accum += loss.detach()
            scaler.scale(loss).backward()
        # Gradients are still multiplied by the loss scale here; unscale them
        # first so clipping and the reported norm use their true magnitude.
        scaler.unscale_(optimizer)
        norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
        if use_cuda:
            torch.cuda.synchronize()
        t1 = time.time()
        dt = (t1 - t0) * 1000
        tokens_per_sec = (B * T * grad_accum_steps) / (t1 - t0)
        train_losses.append(train_loss_accum.item())
        if use_cuda:
            torch.cuda.empty_cache()

        # Validation with MCTS
        model.eval()
        val_loss = validate_with_mcts(model, val_dataloader, device, max_length=T, num_simulations=100, c_puct=1.0, top_k=10)
        val_losses.append(val_loss)
        scheduler.step(val_loss)
        print(f"step {i} | train loss: {train_loss_accum.item():.6f} | val loss: {val_loss:.6f} | lr: {optimizer.param_groups[0]['lr']:.8f} | norm: {norm:.4f} | dt: {dt:.2f}ms | tok/sec: {tokens_per_sec}")

        # Save model weights atomically: write a temp file, then replace.
        torch.save(model.state_dict(), temp_save_path)
        os.replace(temp_save_path, save_path)
        print(f"Model saved at step {i+1} to {save_path}")

    # Plotting the training and validation loss
    plt.figure(figsize=(10, 5))
    plt.plot(train_losses, label='Training Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.xlabel('Steps')
    plt.ylabel('Loss')
    plt.legend()
    plt.show()
# Start training only when the file is executed as a script, not when imported.
if __name__ == "__main__":
    train_model()