from dataclasses import dataclass
import torch
import torch.nn as nn
from torch.nn import functional as F
import math
import os
import numpy as np
import time
from torch.utils.data import Dataset
import matplotlib.pyplot as plt
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.cuda.amp import autocast
from typing import List
from torch.nn.utils.rnn import pad_sequence
import inspect


# Define your dataset and dataloader classes
class NpyDataset(Dataset):
    """Dataset over pre-tokenized .npy shards; each item is one shard of token ids."""

    def __init__(self, data_dir, file_prefix):
        self.data_dir = data_dir
        self.file_names = [os.path.join(data_dir, f) for f in sorted(os.listdir(data_dir))
                           if f.startswith(file_prefix) and f.endswith('.npy')]

    def __len__(self):
        return len(self.file_names)

    def __getitem__(self, idx):
        tokens_np = np.load(self.file_names[idx])
        tokens_tensor = torch.tensor(tokens_np, dtype=torch.long)
        return tokens_tensor


class CustomDataLoaderLite:
    """Minimal sequential loader yielding (x, y) batches shifted by one token."""

    def __init__(self, dataset, batch_size, seq_len):
        self.dataset = dataset
        self.batch_size = batch_size
        self.seq_len = seq_len
        self.current_position = 0

    def __iter__(self):
        self.current_position = 0
        return self

    def __next__(self):
        if self.current_position >= len(self.dataset):
            raise StopIteration
        batch = []
        for _ in range(self.batch_size):
            if self.current_position >= len(self.dataset):
                break
            tokens = self.dataset[self.current_position]
            batch.append(tokens[:self.seq_len])
            self.current_position += 1
        x = torch.stack([tokens[:-1] for tokens in batch])
        y = torch.stack([tokens[1:] for tokens in batch])
        return x, y

    def __len__(self):
        return (len(self.dataset) + self.batch_size - 1) // self.batch_size


# Define the FlashAttention3 module
class FlashAttention3(nn.Module):
    """Blockwise attention with an online (streaming) softmax, in the spirit of
    FlashAttention. num_blocks_kv is kept for API compatibility but is unused."""

    def __init__(self, d_model, n_heads, block_size_q, block_size_kv, num_blocks_kv, device='cuda'):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.block_size_q = block_size_q
        self.block_size_kv = block_size_kv
        self.num_blocks_kv = num_blocks_kv
        self.device = device
        self.q_proj = nn.Linear(d_model, d_model).to(device)
        self.k_proj = nn.Linear(d_model, d_model).to(device)
        self.v_proj = nn.Linear(d_model, d_model).to(device)
        self.out_proj = nn.Linear(d_model, d_model).to(device)

    def forward(self, x):
        B, T, C = x.size()
        hd = C // self.n_heads  # per-head dimension
        Q = self.q_proj(x).view(B, T, self.n_heads, hd).transpose(1, 2)
        K = self.k_proj(x).view(B, T, self.n_heads, hd).transpose(1, 2)
        V = self.v_proj(x).view(B, T, self.n_heads, hd).transpose(1, 2)
        O = torch.zeros(B, self.n_heads, T, hd, device=x.device)
        scale = 1.0 / math.sqrt(hd)  # standard 1/sqrt(d_head) scaling
        pos = torch.arange(T, device=x.device)
        for i in range(0, T, self.block_size_q):
            Q_block = Q[:, :, i:i + self.block_size_q]
            n_q = Q_block.size(2)
            O_block = torch.zeros_like(Q_block)
            L_block = torch.zeros(B, self.n_heads, n_q, device=x.device)
            M_block = torch.full((B, self.n_heads, n_q), -float('inf'), device=x.device)
            for j in range(0, T, self.block_size_kv):
                K_block = K[:, :, j:j + self.block_size_kv]
                V_block = V[:, :, j:j + self.block_size_kv]
                S_block = torch.matmul(Q_block, K_block.transpose(-2, -1)) * scale
                # Causal mask: a query at position i+q may only attend to key
                # positions <= i+q (required for autoregressive LM training).
                q_pos = pos[i:i + n_q].unsqueeze(-1)
                k_pos = pos[j:j + K_block.size(2)].unsqueeze(0)
                S_block = S_block.masked_fill(k_pos > q_pos, -float('inf'))
                M_block_old = M_block
                M_block = torch.maximum(M_block, S_block.max(dim=-1).values)
                exp_S_block = torch.exp(S_block - M_block.unsqueeze(-1))
                correction = torch.exp(M_block_old - M_block)
                L_block = correction * L_block + exp_S_block.sum(dim=-1)
                # Rescale the running output before accumulating the new block;
                # omitting this correction is a classic online-softmax bug.
                O_block = correction.unsqueeze(-1) * O_block + torch.matmul(exp_S_block, V_block)
            O[:, :, i:i + self.block_size_q] = O_block / L_block.unsqueeze(-1)
        O = O.transpose(1, 2).contiguous().view(B, T, C)
        return self.out_proj(O)
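
# Quick shape check for FlashAttention3 (an illustrative sketch, not part of
# training; device='cpu' is assumed so it runs without a GPU):
#
#   attn = FlashAttention3(d_model=256, n_heads=4, block_size_q=32,
#                          block_size_kv=32, num_blocks_kv=4, device='cpu')
#   y = attn(torch.randn(2, 128, 256))
#   assert y.shape == (2, 128, 256)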

# Define the MLP module
class MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd)
        self.gelu = nn.GELU(approximate='tanh')
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)
        self.dropout = nn.Dropout(config.dropout)
        self.c_proj.scale_init = 1  # flag read by GPTWithMoE._init_weights

    def forward(self, x):
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.c_proj(x)
        x = self.dropout(x)
        return x


# Define the MixtureOfExperts module
class MixtureOfExperts(nn.Module):
    """Dense mixture of experts: every expert runs on every token and the
    outputs are combined with softmax gate weights (no top-k routing)."""

    def __init__(self, config, num_experts, expert_layers, device='cuda'):
        super().__init__()
        self.num_experts = num_experts
        self.expert_layers = expert_layers
        self.device = device  # forwarded so experts are not pinned to CUDA
        self.experts = nn.ModuleList([self._create_expert(config) for _ in range(num_experts)])
        self.gate = nn.Linear(config.n_embd, num_experts)

    def _create_expert(self, config):
        layers = []
        for _ in range(self.expert_layers):
            layers.append(FlashAttention3(d_model=config.n_embd, n_heads=config.n_head,
                                          block_size_q=32, block_size_kv=32,
                                          num_blocks_kv=4, device=self.device))
            layers.append(nn.LayerNorm(config.n_embd))
            layers.append(MLP(config))
        return nn.Sequential(*layers)

    def forward(self, x):
        B, T, C = x.size()
        gate_scores = self.gate(x)                                 # (B, T, E)
        gate_probs = F.softmax(gate_scores, dim=-1)
        expert_outputs = torch.stack([expert(x) for expert in self.experts], dim=1)  # (B, E, T, C)
        gate_probs = gate_probs.unsqueeze(-1).permute(0, 2, 1, 3)  # (B, E, T, 1)
        output = torch.sum(gate_probs * expert_outputs, dim=1)     # (B, T, C)
        return output


# Define the BlockWithMoE module
class BlockWithMoE(nn.Module):
    """Transformer block with attention, MoE, and MLP sub-layers, each residual."""

    def __init__(self, config, num_experts=4, expert_layers=2, block_size_q=32,
                 block_size_kv=32, num_blocks_kv=4, device='cuda'):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = FlashAttention3(d_model=config.n_embd, n_heads=config.n_head,
                                    block_size_q=block_size_q, block_size_kv=block_size_kv,
                                    num_blocks_kv=num_blocks_kv, device=device)
        self.dropout1 = nn.Dropout(config.dropout)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.moe = MixtureOfExperts(config, num_experts, expert_layers, device=device)
        self.dropout2 = nn.Dropout(config.dropout)
        self.ln_3 = nn.LayerNorm(config.n_embd)
        self.mlp = MLP(config)
        self.dropout3 = nn.Dropout(config.dropout)

    def forward(self, x):
        # Pre-norm residuals; ln_1 was previously defined but never applied.
        x = x + self.attn(self.ln_1(x))
        x = self.dropout1(x)
        x = x + self.moe(self.ln_2(x))
        x = self.dropout2(x)
        x = x + self.mlp(self.ln_3(x))
        x = self.dropout3(x)
        return x


# Define the GPT configuration dataclass
@dataclass
class GPTConfig:
    block_size: int = 512
    vocab_size: int = 50257
    n_layer: int = 6
    n_head: int = 4
    n_embd: int = 256
    dropout: float = 0.2
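
# Illustrative instantiation of one block (a sketch with assumed small sizes
# so it runs on CPU; not part of the training script):
#
#   cfg = GPTConfig(block_size=64, vocab_size=1000, n_layer=2, n_head=2,
#                   n_embd=64, dropout=0.0)
#   block = BlockWithMoE(cfg, num_experts=2, expert_layers=1, device='cpu')
#   out = block(torch.randn(2, 64, 64))   # -> (2, 64, 64)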

# Define the GPTWithMoE model
class GPTWithMoE(nn.Module):
    def __init__(self, config, num_experts=2, expert_layers=2, block_size_q=32,
                 block_size_kv=32, num_blocks_kv=4, device='cuda'):
        super().__init__()
        self.config = config
        self.transformer = nn.ModuleDict(dict(
            wte=nn.Embedding(config.vocab_size, config.n_embd),
            wpe=nn.Embedding(config.block_size, config.n_embd),
            h=nn.ModuleList([BlockWithMoE(config, num_experts, expert_layers, block_size_q,
                                          block_size_kv, num_blocks_kv, device)
                             for _ in range(config.n_layer)]),
            ln_f=nn.LayerNorm(config.n_embd),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # Weight tying between the token embedding and the output head.
        self.transformer.wte.weight = self.lm_head.weight
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            std = 0.02
            if hasattr(module, 'scale_init'):
                # Scale down residual-branch projections by depth.
                std *= (2 * self.config.n_layer) ** -0.5
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        B, T = idx.size()
        assert T <= self.config.block_size, \
            f"Cannot forward sequence of length {T}, block size is only {self.config.block_size}"
        pos = torch.arange(0, T, dtype=torch.long, device=idx.device)
        pos_emb = self.transformer.wpe(pos)
        tok_emb = self.transformer.wte(idx)
        x = tok_emb + pos_emb
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)
        logits = self.lm_head(x)
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss

    def configure_optimizers(self, weight_decay, learning_rate, device):
        param_dict = {pn: p for pn, p in self.named_parameters() if p.requires_grad}
        decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
        non_decay_params = [p for n, p in param_dict.items() if p.dim() < 2]
        optim_groups = [
            {'params': decay_params, 'weight_decay': weight_decay},
            {'params': non_decay_params, 'weight_decay': 0.0},
        ]
        fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
        use_fused = fused_available and 'cuda' in device
        print(f"Using fused AdamW: {use_fused}")
        # Pass the parameter groups so the decay/no-decay split takes effect
        # (passing self.parameters() here would silently ignore it).
        extra_args = dict(fused=True) if use_fused else dict()
        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate,
                                      betas=(0.9, 0.95), eps=1e-8, **extra_args)
        return optimizer


# MCTS Implementation
@dataclass
class MCTSNode:
    state: torch.Tensor
    parent: 'MCTSNode' = None
    children: dict = None
    visits: int = 0
    value: float = 0.0

    def __post_init__(self):
        if self.children is None:
            self.children = {}


# Define scriptable functions separately
def select_node(node: MCTSNode, c_puct: float) -> MCTSNode:
    """Pick the child maximizing the UCT score: mean value plus exploration bonus."""
    if not node.children:
        return node
    scores = torch.tensor([
        child.value / (child.visits + 1e-8)
        + c_puct * math.sqrt(math.log(node.visits + 1) / (child.visits + 1e-8))
        for child in node.children.values()
    ])
    best_child_idx = torch.argmax(scores).item()
    return list(node.children.values())[best_child_idx]


def expand_node(node: MCTSNode, logits: torch.Tensor, top_k: int) -> None:
    """Add the top-k most probable next tokens as children of the node."""
    probs = F.softmax(logits, dim=-1)
    top_k_probs, top_k_indices = torch.topk(probs, k=top_k)
    for prob, token in zip(top_k_probs, top_k_indices):
        if token.item() not in node.children:
            node.children[token.item()] = MCTSNode(state=token, parent=node)


def simulate(model: torch.nn.Module, sequence: torch.Tensor, max_length: int) -> torch.Tensor:
    """Roll the sequence out to max_length by sampling from the model."""
    # Ensure sequence is 2D
    if sequence.dim() == 1:
        sequence = sequence.unsqueeze(0)
    with torch.no_grad():
        while sequence.size(1) < max_length:
            with autocast():
                logits, _ = model(sequence)
            probs = F.softmax(logits[0, -1], dim=-1)
            next_token = torch.multinomial(probs, 1)
            sequence = torch.cat([sequence, next_token.unsqueeze(0)], dim=1)
    return sequence.squeeze(0)


def backpropagate(node: MCTSNode, value: float) -> None:
    """Propagate a rollout value up to the root, updating visit counts."""
    while node is not None:
        node.visits += 1
        node.value += value
        node = node.parent
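
# Toy walk-through of the tree statistics (illustrative values only):
#
#   root = MCTSNode(state=torch.tensor([[1, 2, 3]]))
#   root.children[7] = MCTSNode(state=torch.tensor(7), parent=root)
#   backpropagate(root.children[7], value=-2.5)
#   # now root.children[7].visits == 1 and root.visits == 1
#   leaf = select_node(root, c_puct=1.0)   # returns the only child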

def mcts_decode_single(model: torch.nn.Module, input_ids: torch.Tensor, max_length: int,
                       num_simulations: int, c_puct: float, top_k: int) -> torch.Tensor:
    """Run MCTS from input_ids and append the most-visited next token."""
    # Ensure input_ids is 2D
    if input_ids.dim() == 1:
        input_ids = input_ids.unsqueeze(0)
    root = MCTSNode(state=input_ids)
    for _ in range(num_simulations):
        node = root
        current_input = input_ids.clone()
        # Selection
        while node.children and current_input.size(1) < max_length:
            node = select_node(node, c_puct)
            current_input = torch.cat([current_input, node.state.unsqueeze(0).unsqueeze(0)], dim=1)
        # Expansion
        if current_input.size(1) < max_length:
            with torch.no_grad():
                with autocast():
                    logits, _ = model(current_input)
            expand_node(node, logits[0, -1], top_k)
        # Simulation
        simulation_sequence = simulate(model, current_input.squeeze(0), max_length)
        # Evaluation: score the rollout by its next-token loss. Targets are
        # shifted by one so the model never predicts a token it can already see.
        with torch.no_grad():
            with autocast():
                _, loss = model(simulation_sequence[:-1].unsqueeze(0),
                                simulation_sequence[1:].unsqueeze(0))
        value = -loss.item()
        # Backpropagation
        backpropagate(node, value)
    # Choose the best next token
    if not root.children:
        # Input already at max_length; nothing was expanded.
        return input_ids.squeeze(0)[:max_length]
    best_child = max(root.children.values(), key=lambda n: n.visits)
    result = torch.cat([input_ids.squeeze(0), best_child.state.unsqueeze(0)], dim=0)
    # Ensure the result doesn't exceed max_length
    return result[:max_length]


def mcts_decode_batch(model: torch.nn.Module, input_ids_list: List[torch.Tensor], max_length: int,
                      num_simulations: int, c_puct: float, top_k: int) -> List[torch.Tensor]:
    return [mcts_decode_single(model, input_ids.unsqueeze(0) if input_ids.dim() == 1 else input_ids,
                               max_length, num_simulations, c_puct, top_k)
            for input_ids in input_ids_list]


def validate_with_mcts(model: torch.nn.Module, val_dataloader: CustomDataLoaderLite,
                       device: torch.device, max_length: int, num_simulations: int,
                       c_puct: float, top_k: int) -> float:
    model.eval()
    total_loss = 0.0
    num_batches = 0
    with torch.no_grad():
        for x, y in val_dataloader:
            x, y = x.to(device), y.to(device)
            # Use MCTS for decoding (iterating a 2D batch yields 1D rows)
            decoded_sequences = mcts_decode_batch(model, x, max_length, num_simulations, c_puct, top_k)
            # Pad sequences to the same length
            decoded_sequences_padded = pad_sequence(decoded_sequences, batch_first=True, padding_value=0)
            # Trim the decoded sequences to match the target length
            decoded_sequences_trimmed = decoded_sequences_padded[:, :y.size(1)]
            # Calculate loss using the MCTS-decoded sequences
            with autocast():
                logits, loss = model(decoded_sequences_trimmed, y)
            total_loss += loss.item()
            num_batches += 1
    return total_loss / num_batches if num_batches > 0 else 0.0
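
# Illustrative decoding call (a sketch; assumes `model` has been trained and
# moved to `device`, and that the token ids are arbitrary placeholders):
#
#   prompt = torch.tensor([464, 3290, 318], device=device)
#   out = mcts_decode_single(model, prompt, max_length=32,
#                            num_simulations=16, c_puct=1.0, top_k=10)
#   # `out` is the prompt plus the single most-visited next token.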

def train_model():
    device = 'cpu'
    if torch.cuda.is_available():
        device = 'cuda'
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        device = 'mps'
    print(f"using device: {device}")

    # Load the dataset and create the data loaders
    print("Loading datasets...")
    train_dataset = NpyDataset('edu_fineweb10B', 'edufineweb_train')
    val_dataset = NpyDataset('edu_fineweb10B', 'edufineweb_val')
    train_dataloader = CustomDataLoaderLite(train_dataset, batch_size=12, seq_len=512)
    val_dataloader = CustomDataLoaderLite(val_dataset, batch_size=12, seq_len=512)

    # Training loop parameters
    max_steps = 200
    total_batch_size = 262144  # tokens per optimizer step
    B = 12
    T = 512
    grad_accum_steps = total_batch_size // (B * T)

    # Set up the configuration
    print("Setting up model configuration...")
    config = GPTConfig(vocab_size=50304, block_size=512, n_layer=6, n_head=4, n_embd=256)

    # Initialize the model
    print("Initializing model...")
    model = GPTWithMoE(config, num_experts=3, expert_layers=3, block_size_q=32,
                       block_size_kv=32, num_blocks_kv=4, device=device)
    model.to(device)

    # Load the saved model weights if they exist
    save_path = "C:\\Users\\Admin\\MODELS\\moe_mcts_new.pt"
    temp_save_path = "C:\\Users\\Admin\\MODELS\\moe_mcts_temp_new.pt"
    if os.path.isfile(save_path):
        print(f"Loading model weights from {save_path}...")
        model.load_state_dict(torch.load(save_path))
        print(f"Loaded model weights from {save_path}")

    print("Configuring optimizer...")
    optimizer = model.configure_optimizers(weight_decay=0.2, learning_rate=3e-3, device=device)
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10, verbose=True)

    train_losses = []
    val_losses = []
    scaler = torch.cuda.amp.GradScaler(enabled=(device == 'cuda'))
    train_iter = iter(train_dataloader)

    for i in range(max_steps):
        t0 = time.time()
        optimizer.zero_grad()
        train_loss_accum = 0.0
        model.train()
        print(f"Training step {i + 1}/{max_steps}...")
        # One optimizer step accumulates gradients over grad_accum_steps
        # micro-batches; the loader is re-wound when it runs out. (Previously
        # the loop walked the whole dataset per step, which broke the
        # loss normalization and the tokens/sec accounting.)
        for _ in range(grad_accum_steps):
            try:
                x, y = next(train_iter)
            except StopIteration:
                train_iter = iter(train_dataloader)
                x, y = next(train_iter)
            x, y = x.to(device), y.to(device)
            with torch.cuda.amp.autocast(enabled=(device == 'cuda')):
                logits, loss = model(x, y)
            loss = loss / grad_accum_steps
            train_loss_accum += loss.detach()
            scaler.scale(loss).backward()
        # Unscale before clipping so the norm is measured on true gradients.
        scaler.unscale_(optimizer)
        norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
        if device == 'cuda':
            torch.cuda.synchronize()
        t1 = time.time()
        dt = (t1 - t0) * 1000
        tokens_per_sec = (B * T * grad_accum_steps) / (t1 - t0)
        train_losses.append(train_loss_accum.item())
        if device == 'cuda':
            torch.cuda.empty_cache()

        # Validation with MCTS
        model.eval()
        val_loss = validate_with_mcts(model, val_dataloader, device, max_length=T,
                                      num_simulations=100, c_puct=1.0, top_k=10)
        val_losses.append(val_loss)
        scheduler.step(val_loss)

        print(f"step {i} | train loss: {train_loss_accum.item():.6f} | val loss: {val_loss:.6f} | "
              f"lr: {optimizer.param_groups[0]['lr']:.8f} | norm: {norm:.4f} | "
              f"dt: {dt:.2f}ms | tok/sec: {tokens_per_sec:.2f}")

        # Save model weights (write to a temp file, then replace atomically)
        torch.save(model.state_dict(), temp_save_path)
        os.replace(temp_save_path, save_path)
        print(f"Model saved at step {i + 1} to {save_path}")

    # Plotting the training and validation loss
    plt.figure(figsize=(10, 5))
    plt.plot(train_losses, label='Training Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.xlabel('Steps')
    plt.ylabel('Loss')
    plt.legend()
    plt.show()


if __name__ == "__main__":
    train_model()
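
# Possible inference entry point (a sketch; assumes the checkpoint saved by
# train_model() exists and that CUDA is available):
#
#   config = GPTConfig(vocab_size=50304, block_size=512, n_layer=6, n_head=4, n_embd=256)
#   model = GPTWithMoE(config, num_experts=3, expert_layers=3, device='cuda').to('cuda')
#   model.load_state_dict(torch.load("C:\\Users\\Admin\\MODELS\\moe_mcts_new.pt"))
#   model.eval()
#   out = mcts_decode_single(model, torch.tensor([50256], device='cuda'),
#                            max_length=64, num_simulations=32, c_puct=1.0, top_k=10)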