# =============================================================================
# training/trainer.py
# =============================================================================
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from typing import Dict, List, Optional
import time
import logging
from pathlib import Path

from core.config import MambaConfig
from routing.tlm_manager import TLMManager
from routing.aggregator import AttentionAggregator
from training.optimizer import MambaOptimizer
from training.loss import MambaLoss
from training.data_loader import create_data_loaders
from core.tokenizer import MambaTokenizer
from core.preprocess import TextPreprocessor


class MambaSwarmTrainer:
    """Multi-phase trainer for the Mamba swarm architecture"""

    def __init__(self, config: MambaConfig):
        self.config = config
        self.device = config.device

        # Initialize components
        self.tokenizer = MambaTokenizer(config)
        self.preprocessor = TextPreprocessor(config)

        # Initialize TLM manager and aggregator
        self.tlm_manager = TLMManager(config)
        self.aggregator = AttentionAggregator(config)
        self.aggregator.to(self.device)

        # Initialize loss function
        self.loss_fn = MambaLoss(config, config.vocab_size)

        # Create data loaders
        self.data_loaders = create_data_loaders(config, self.tokenizer, self.preprocessor)

        # Persistent iterators over the data loaders (see _get_batch)
        self._loader_iters = {}

        # Training state
        self.global_step = 0
        self.phase = "foundation"  # foundation, specialists, aggregator, end_to_end

        # Setup logging
        self.setup_logging()

    def setup_logging(self):
        """Setup training logging"""
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(levelname)s - %(message)s',
            handlers=[
                logging.FileHandler('training.log'),
                logging.StreamHandler()
            ]
        )
        self.logger = logging.getLogger(__name__)

    def _get_batch(self, loader):
        """Fetch the next batch from a data loader, restarting the iterator when exhausted.

        Calling next(iter(loader)) directly would recreate the iterator on every step
        and always return the first batch, so one iterator is kept per loader.
        """
        if loader not in self._loader_iters:
            self._loader_iters[loader] = iter(loader)
        try:
            return next(self._loader_iters[loader])
        except StopIteration:
            self._loader_iters[loader] = iter(loader)
            return next(self._loader_iters[loader])

    def train_foundation_phase(self, num_steps: int = 10000):
        """Phase 1: Train shared foundation weights"""
        self.logger.info("Starting foundation training phase...")
        self.phase = "foundation"

        # Get a reference specialist for foundation training
        reference_specialist = list(self.tlm_manager.specialists.values())[0]
        optimizer = MambaOptimizer(reference_specialist.model, self.config)

        reference_specialist.model.train()

        for step in range(num_steps):
            batch = self._get_batch(self.data_loaders['main'])

            # Move to device
            input_ids = batch['input_ids'].to(self.device)
            target_ids = batch['target_ids'].to(self.device)

            # Forward pass
            logits, loss = reference_specialist.model(input_ids, target_ids)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            lr = optimizer.step()

            self.global_step += 1

            if step % 100 == 0:
                self.logger.info(f"Foundation step {step}, loss: {loss.item():.4f}, lr: {lr:.6f}")

        # Copy foundation weights to all specialists
        self._copy_foundation_weights(reference_specialist)

        self.logger.info("Foundation training phase completed!")

    def _copy_foundation_weights(self, reference_specialist):
        """Copy foundation weights to all specialists"""
        reference_state = reference_specialist.model.state_dict()

        for specialist in self.tlm_manager.specialists.values():
            if specialist is not reference_specialist:
                # Copy shared layers (first half of the model)
                specialist_state = specialist.model.state_dict()

                for name, param in reference_state.items():
                    if 'layers.' in name:
                        # Extract layer number
                        layer_num = int(name.split('.')[1])
                        if layer_num < self.config.n_layers // 2:  # Share first half
                            specialist_state[name] = param.clone()
                    elif 'embedding' in name:  # Share embeddings
                        specialist_state[name] = param.clone()

                specialist.model.load_state_dict(specialist_state)

    def train_specialists_phase(self, num_steps: int = 5000):
        """Phase 2: Train domain specialists in parallel"""
        self.logger.info("Starting specialist training phase...")
        self.phase = "specialists"

        # Create optimizers for each specialist
        specialist_optimizers = {}
        for specialist_id, specialist in self.tlm_manager.specialists.items():
            specialist_optimizers[specialist_id] = MambaOptimizer(
                specialist.model, self.config
            )
            specialist.model.train()

        # Train specialists in parallel (simplified - could use actual parallel training)
        for step in range(num_steps):
            total_loss = 0.0

            # Train each specialist on its domain data
            for specialist_id in range(min(10, self.config.num_specialists)):  # Limit for demo
                if specialist_id in self.data_loaders['domains']:
                    try:
                        batch = self._get_batch(self.data_loaders['domains'][specialist_id])

                        # Move to device
                        input_ids = batch['input_ids'].to(self.device)
                        target_ids = batch['target_ids'].to(self.device)

                        # Get specialist and optimizer
                        specialist = self.tlm_manager.specialists[specialist_id]
                        optimizer = specialist_optimizers[specialist_id]

                        # Forward pass
                        logits, loss = specialist.model(input_ids, target_ids)

                        # Backward pass
                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()

                        total_loss += loss.item()

                    except Exception as e:
                        self.logger.warning(f"Error training specialist {specialist_id}: {e}")
                        continue

            self.global_step += 1

            if step % 100 == 0:
                avg_loss = total_loss / min(10, self.config.num_specialists)
                self.logger.info(f"Specialists step {step}, avg loss: {avg_loss:.4f}")

        self.logger.info("Specialist training phase completed!")

    def train_aggregator_phase(self, num_steps: int = 3000):
        """Phase 3: Train aggregator to combine specialist outputs"""
        self.logger.info("Starting aggregator training phase...")
        self.phase = "aggregator"

        # Freeze specialist models
        for specialist in self.tlm_manager.specialists.values():
            specialist.model.eval()
            for param in specialist.model.parameters():
                param.requires_grad = False

        # Create optimizer for aggregator
        aggregator_optimizer = MambaOptimizer(self.aggregator, self.config)
        self.aggregator.train()

        for step in range(num_steps):
            try:
                batch = self._get_batch(self.data_loaders['main'])

                # Simulate specialist outputs (simplified for demo)
                specialist_outputs = self._simulate_specialist_outputs(batch)

                # Get target text for comparison
                target_ids = batch['target_ids'].to(self.device)

                # Forward pass through aggregator
                logits = self.aggregator(specialist_outputs)

                # Compute loss
                loss_dict = self.loss_fn(logits, target_ids)
                loss = loss_dict['total_loss']

                # Backward pass
                aggregator_optimizer.zero_grad()
                loss.backward()
                aggregator_optimizer.step()

                self.global_step += 1

                if step % 100 == 0:
                    self.logger.info(f"Aggregator step {step}, loss: {loss.item():.4f}")

            except Exception as e:
                self.logger.warning(f"Error in aggregator training step {step}: {e}")
                continue

        self.logger.info("Aggregator training phase completed!")

    def _simulate_specialist_outputs(self, batch) -> Dict[int, List[Dict]]:
        """Simulate specialist outputs for aggregator training"""
        # This is a simplified simulation - in real training, you'd run
        # the text through the router and specialists
        input_ids = batch['input_ids'].to(self.device)

        # Simulate 3 chunks with 2-3 specialists each
        specialist_outputs = {}
        for chunk_id in range(3):
            chunk_results = []

            # Simulate 2-3 specialists working on this chunk
            for i in range(2 + chunk_id % 2):
                specialist_id = (chunk_id * 3 + i) % self.config.num_specialists

                if specialist_id in self.tlm_manager.specialists:
                    specialist = self.tlm_manager.specialists[specialist_id]

                    # Get encoding from specialist
                    with torch.no_grad():
                        encoding = specialist.encode(input_ids[:1])  # Single sample

                    chunk_results.append({
                        'chunk_id': chunk_id,
                        'specialist_id': specialist_id,
                        'confidence': 0.8 + 0.2 * torch.rand(1).item(),
                        'encoding': encoding[0],  # Remove batch dim
                        'domain': f'domain_{specialist_id}'
                    })

            specialist_outputs[chunk_id] = chunk_results

        return specialist_outputs

    def train_end_to_end_phase(self, num_steps: int = 2000):
        """Phase 4: End-to-end fine-tuning of the entire system"""
        self.logger.info("Starting end-to-end training phase...")
        self.phase = "end_to_end"

        # Unfreeze all parameters
        for specialist in self.tlm_manager.specialists.values():
            specialist.model.train()
            for param in specialist.model.parameters():
                param.requires_grad = True

        self.aggregator.train()

        # Create system-wide optimizer with lower learning rate
        all_params = []

        # Add specialist parameters
        for specialist in self.tlm_manager.specialists.values():
            all_params.extend(specialist.model.parameters())

        # Add aggregator parameters
        all_params.extend(self.aggregator.parameters())

        # Use a reduced learning rate for fine-tuning without mutating the shared config
        end_to_end_lr = self.config.learning_rate * 0.1
        system_optimizer = torch.optim.AdamW(
            all_params,
            lr=end_to_end_lr,
            weight_decay=self.config.weight_decay
        )

        for step in range(num_steps):
            try:
                batch = self._get_batch(self.data_loaders['main'])

                # Full system forward pass (simplified)
                specialist_outputs = self._simulate_specialist_outputs(batch)
                logits = self.aggregator(specialist_outputs)

                # Compute loss
                target_ids = batch['target_ids'].to(self.device)
                loss_dict = self.loss_fn(logits, target_ids)
                loss = loss_dict['total_loss']

                # Backward pass
                system_optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(all_params, max_norm=1.0)
                system_optimizer.step()

                self.global_step += 1

                if step % 100 == 0:
                    self.logger.info(f"End-to-end step {step}, loss: {loss.item():.4f}")

            except Exception as e:
                self.logger.warning(f"Error in end-to-end training step {step}: {e}")
                continue

        self.logger.info("End-to-end training phase completed!")

    def full_training_pipeline(self):
        """Run the complete 4-phase training pipeline"""
        self.logger.info("Starting full Mamba swarm training pipeline...")
        start_time = time.time()

        try:
            # Phase 1: Foundation training
            self.train_foundation_phase(num_steps=1000)  # Reduced for demo

            # Phase 2: Specialist training
            self.train_specialists_phase(num_steps=500)  # Reduced for demo

            # Phase 3: Aggregator training
            self.train_aggregator_phase(num_steps=300)  # Reduced for demo

            # Phase 4: End-to-end fine-tuning
            self.train_end_to_end_phase(num_steps=200)  # Reduced for demo

            total_time = time.time() - start_time
            self.logger.info(f"Training completed in {total_time:.2f} seconds!")

        except Exception as e:
            self.logger.error(f"Training failed: {e}")
            raise

    def save_checkpoint(self, checkpoint_path: str):
        """Save training checkpoint"""
        checkpoint = {
            'global_step': self.global_step,
            'phase': self.phase,
            'config': self.config.__dict__,
            'aggregator_state': self.aggregator.state_dict(),
            'specialist_states': {}
        }

        # Save specialist states
        for specialist_id, specialist in self.tlm_manager.specialists.items():
            checkpoint['specialist_states'][specialist_id] = specialist.model.state_dict()

        torch.save(checkpoint, checkpoint_path)
        self.logger.info(f"Checkpoint saved to {checkpoint_path}")

    def load_checkpoint(self, checkpoint_path: str):
        """Load training checkpoint"""
        checkpoint = torch.load(checkpoint_path, map_location=self.device)

        self.global_step = checkpoint['global_step']
        self.phase = checkpoint['phase']

        # Load aggregator state
        self.aggregator.load_state_dict(checkpoint['aggregator_state'])

        # Load specialist states
        for specialist_id, state_dict in checkpoint['specialist_states'].items():
            if specialist_id in self.tlm_manager.specialists:
                self.tlm_manager.specialists[specialist_id].model.load_state_dict(state_dict)

        self.logger.info(f"Checkpoint loaded from {checkpoint_path}")

    def evaluate(self, eval_steps: int = 100) -> Dict[str, float]:
        """Evaluate the trained model"""
        self.logger.info("Starting evaluation...")

        # Set models to eval mode
        for specialist in self.tlm_manager.specialists.values():
            specialist.model.eval()
        self.aggregator.eval()

        total_loss = 0.0
        num_steps = 0

        with torch.no_grad():
            for step in range(eval_steps):
                try:
                    batch = self._get_batch(self.data_loaders['main'])

                    # Forward pass
                    specialist_outputs = self._simulate_specialist_outputs(batch)
                    logits = self.aggregator(specialist_outputs)

                    # Compute loss
                    target_ids = batch['target_ids'].to(self.device)
                    loss_dict = self.loss_fn(logits, target_ids)

                    total_loss += loss_dict['total_loss'].item()
                    num_steps += 1

                except Exception as e:
                    self.logger.warning(f"Error in evaluation step {step}: {e}")
                    continue

        avg_loss = total_loss / max(num_steps, 1)
        perplexity = torch.exp(torch.tensor(avg_loss)).item()

        results = {
            'eval_loss': avg_loss,
            'perplexity': perplexity,
            'num_steps': num_steps
        }

        self.logger.info(f"Evaluation results: {results}")
        return results
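

# -----------------------------------------------------------------------------
# Example usage: a minimal, illustrative sketch (not part of the trainer API).
# It assumes MambaConfig can be constructed with its defaults; adjust fields
# such as device, num_specialists, or learning_rate to match your setup.
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    config = MambaConfig()  # assumption: default construction is valid
    trainer = MambaSwarmTrainer(config)

    # Run the full 4-phase pipeline, then evaluate and checkpoint the result
    trainer.full_training_pipeline()
    eval_results = trainer.evaluate(eval_steps=50)
    trainer.save_checkpoint("mamba_swarm_checkpoint.pt")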