import math
import os
import logging

import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
from torch.optim.lr_scheduler import CosineAnnealingLR, OneCycleLR
from torch.cuda.amp import GradScaler, autocast  # torch.amp.autocast("cuda") in PyTorch >= 2.4
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

from tqdm import tqdm
import wandb

logger = logging.getLogger(__name__)


class AdvancedTrainer:
    """
    Training framework with mixed precision, distributed data parallelism,
    and configurable optimizers, schedulers, and multi-task losses.
    """
    def __init__(self, model, train_dataset, val_dataset, config):
        self.config = config
        self.model = model
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset

        # Distributed context, populated by the launcher (e.g. torchrun)
        self.world_size = int(os.environ.get('WORLD_SIZE', 1))
        self.rank = int(os.environ.get('RANK', 0))
        self.local_rank = int(os.environ.get('LOCAL_RANK', 0))

        self.is_distributed = self.world_size > 1
        self.is_main_process = self.rank == 0

        # Move the model to this rank's GPU before (optionally) wrapping it in DDP
        self.model = self.model.cuda(self.local_rank)
        if self.is_distributed:
            self._setup_distributed()

        # Loss scaler for mixed-precision training
        self.scaler = GradScaler() if config.use_mixed_precision else None

        self.optimizer = self._create_optimizer()
        self.scheduler = self._create_scheduler()

        # Per-task loss functions
        self.criterion = {
            'emotion': nn.CrossEntropyLoss(label_smoothing=0.1),
            'intent': nn.CrossEntropyLoss(label_smoothing=0.1),
            'engagement': self._create_regression_loss(),
            'confidence': self._create_regression_loss(),
            'contrastive': nn.CrossEntropyLoss()
        }

        # Relative weights for combining the task losses
        self.task_weights = config.task_weights

        # Experiment tracking on the main process only
        if self.is_main_process and config.use_wandb:
            wandb.init(project="emotia-training", config=config.__dict__)

    def _setup_distributed(self):
        """Initialize the NCCL process group and wrap the model in DDP."""
        torch.cuda.set_device(self.local_rank)
        dist.init_process_group(
            backend='nccl',
            init_method='env://',
            world_size=self.world_size,
            rank=self.rank
        )

        # The model is already on this rank's GPU; DDP adds gradient synchronization
        self.model = DDP(self.model, device_ids=[self.local_rank])

    def _create_optimizer(self):
        """Create the optimizer named in the config (adamw, lion, or Adam fallback)."""
        if self.config.optimizer == 'adamw':
            optimizer = optim.AdamW(
                self.model.parameters(),
                lr=self.config.lr,
                weight_decay=self.config.weight_decay,
                betas=(0.9, 0.999)
            )
        elif self.config.optimizer == 'lion':
            # Optional dependency: pip install lion-pytorch
            from lion_pytorch import Lion
            optimizer = Lion(
                self.model.parameters(),
                lr=self.config.lr,
                weight_decay=self.config.weight_decay
            )
        else:
            optimizer = optim.Adam(
                self.model.parameters(),
                lr=self.config.lr,
                weight_decay=self.config.weight_decay
            )

        return optimizer

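    # A common refinement, not wired in above: exclude biases and normalization
    # parameters from weight decay via parameter groups, e.g.:
    #
    #     decay = [p for p in model.parameters() if p.ndim >= 2]
    #     no_decay = [p for p in model.parameters() if p.ndim < 2]
    #     optim.AdamW([{'params': decay, 'weight_decay': wd},
    #                  {'params': no_decay, 'weight_decay': 0.0}], lr=lr)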
    def _create_scheduler(self):
        """Create the learning-rate scheduler named in the config."""
        if self.config.scheduler == 'cosine':
            scheduler = CosineAnnealingLR(
                self.optimizer,
                T_max=self.config.epochs,
                eta_min=self.config.min_lr
            )
        elif self.config.scheduler == 'one_cycle':
            # Round up so the final partial batch of each epoch is counted;
            # OneCycleLR errors if stepped past its total step budget.
            steps_per_epoch = math.ceil(
                len(self.train_dataset) / (self.config.batch_size * self.world_size)
            )
            scheduler = OneCycleLR(
                self.optimizer,
                max_lr=self.config.lr,
                epochs=self.config.epochs,
                steps_per_epoch=steps_per_epoch,
                pct_start=0.3,
                anneal_strategy='cos'
            )
        else:
            scheduler = None

        return scheduler

    def _create_regression_loss(self):
        """Create a heteroscedastic regression loss (Gaussian negative log-likelihood)."""
        def uncertainty_loss(pred_mean, pred_var, target):
            # NLL of N(target; pred_mean, pred_var), up to an additive constant.
            # Clamp the predicted variance so log() and the division stay finite.
            pred_var = pred_var.clamp(min=1e-6)
            loss = 0.5 * torch.log(pred_var) + 0.5 * (target - pred_mean)**2 / pred_var
            return loss.mean()

        return uncertainty_loss
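
    # Note: this closure matches PyTorch's built-in nn.GaussianNLLLoss, which
    # computes 0.5 * (log(var) + (target - mean)^2 / var) with its own eps
    # clamping, so an equivalent drop-in would be:
    #
    #     criterion = nn.GaussianNLLLoss(eps=1e-6)
    #     loss = criterion(pred_mean, target, pred_var)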
    def train_epoch(self, epoch):
        """Train for one epoch."""
        self.model.train()

        if self.is_distributed:
            sampler = DistributedSampler(self.train_dataset, shuffle=True)
            # Reseed the shard shuffling; otherwise every epoch sees the same order
            sampler.set_epoch(epoch)
            dataloader = DataLoader(
                self.train_dataset,
                batch_size=self.config.batch_size,
                sampler=sampler,
                num_workers=self.config.num_workers,
                pin_memory=True
            )
        else:
            dataloader = DataLoader(
                self.train_dataset,
                batch_size=self.config.batch_size,
                shuffle=True,
                num_workers=self.config.num_workers,
                pin_memory=True
            )

        total_loss = 0.0
        num_batches = 0

        progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}") if self.is_main_process else dataloader

        for batch in progress_bar:
            # Move tensors to this rank's GPU; leave non-tensor entries as-is
            batch = {k: v.cuda(self.local_rank) if torch.is_tensor(v) else v for k, v in batch.items()}

            self.optimizer.zero_grad()

            # Forward/backward, with autocast and loss scaling when AMP is enabled
            if self.scaler:
                with autocast():
                    outputs = self.model(**batch)
                    loss = self._compute_loss(outputs, batch)
                self.scaler.scale(loss).backward()
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                outputs = self.model(**batch)
                loss = self._compute_loss(outputs, batch)
                loss.backward()
                self.optimizer.step()

            # OneCycleLR steps once per batch
            if isinstance(self.scheduler, OneCycleLR):
                self.scheduler.step()

            total_loss += loss.item()
            num_batches += 1

            if self.is_main_process:
                progress_bar.set_postfix({'loss': f'{loss.item():.4f}'})

        avg_loss = total_loss / num_batches

        # CosineAnnealingLR steps once per epoch
        if isinstance(self.scheduler, CosineAnnealingLR):
            self.scheduler.step()

        return avg_loss
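    # Optional extension, not enabled above: gradient clipping. With AMP the
    # gradients must be unscaled before clipping, e.g.:
    #
    #     self.scaler.scale(loss).backward()
    #     self.scaler.unscale_(self.optimizer)
    #     torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
    #     self.scaler.step(self.optimizer)
    #     self.scaler.update()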
    def _compute_loss(self, outputs, batch):
        """Combine the per-task losses into one weighted scalar."""
        total_loss = 0

        # Emotion classification
        if 'emotion_logits' in outputs and 'emotion' in batch:
            emotion_loss = self.criterion['emotion'](outputs['emotion_logits'], batch['emotion'])
            total_loss += self.task_weights['emotion'] * emotion_loss

        # Intent classification
        if 'intent_logits' in outputs and 'intent' in batch:
            intent_loss = self.criterion['intent'](outputs['intent_logits'], batch['intent'])
            total_loss += self.task_weights['intent'] * intent_loss

        # Engagement regression with predicted uncertainty
        if 'engagement_mean' in outputs and 'engagement_var' in outputs and 'engagement' in batch:
            engagement_loss = self.criterion['engagement'](
                outputs['engagement_mean'], outputs['engagement_var'], batch['engagement']
            )
            total_loss += self.task_weights['engagement'] * engagement_loss

        # Confidence regression with predicted uncertainty
        if 'confidence_mean' in outputs and 'confidence_var' in outputs and 'confidence' in batch:
            confidence_loss = self.criterion['confidence'](
                outputs['confidence_mean'], outputs['confidence_var'], batch['confidence']
            )
            total_loss += self.task_weights['confidence'] * confidence_loss

        # Optional contrastive term, if the model defines one. Look through the
        # DDP wrapper, which does not forward custom attributes.
        core_model = self.model.module if isinstance(self.model, DDP) else self.model
        if hasattr(core_model, 'contrastive_loss') and 'embeddings' in outputs:
            contrastive_loss = core_model.contrastive_loss(outputs['embeddings'])
            total_loss += self.config.contrastive_weight * contrastive_loss

        return total_loss
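    # For reference, the model-output contract implied above; each head's keys
    # are only consumed when present:
    #
    #     outputs = {
    #         'emotion_logits': ...,     # (B, num_emotions)
    #         'intent_logits': ...,      # (B, num_intents)
    #         'engagement_mean': ...,    # (B,), with matching 'engagement_var'
    #         'confidence_mean': ...,    # (B,), with matching 'confidence_var'
    #         'embeddings': ...,         # optional, feeds the contrastive term
    #     }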
    def validate(self, epoch):
        """Run validation and compute classification metrics."""
        self.model.eval()

        if self.is_distributed:
            # Each rank evaluates only its shard, so the loss and metrics
            # below are per-rank rather than globally reduced
            sampler = DistributedSampler(self.val_dataset, shuffle=False)
            dataloader = DataLoader(
                self.val_dataset,
                batch_size=self.config.batch_size,
                sampler=sampler,
                num_workers=self.config.num_workers,
                pin_memory=True
            )
        else:
            dataloader = DataLoader(
                self.val_dataset,
                batch_size=self.config.batch_size,
                shuffle=False,
                num_workers=self.config.num_workers,
                pin_memory=True
            )

        total_loss = 0.0
        num_batches = 0

        all_emotion_preds = []
        all_emotion_labels = []
        all_intent_preds = []
        all_intent_labels = []

        with torch.no_grad():
            for batch in dataloader:
                batch = {k: v.cuda(self.local_rank) if torch.is_tensor(v) else v for k, v in batch.items()}

                outputs = self.model(**batch)
                loss = self._compute_loss(outputs, batch)

                total_loss += loss.item()
                num_batches += 1

                # Collect hard predictions for the classification metrics
                if 'emotion_logits' in outputs:
                    all_emotion_preds.extend(outputs['emotion_logits'].argmax(dim=1).cpu().numpy())
                    all_emotion_labels.extend(batch['emotion'].cpu().numpy())

                if 'intent_logits' in outputs:
                    all_intent_preds.extend(outputs['intent_logits'].argmax(dim=1).cpu().numpy())
                    all_intent_labels.extend(batch['intent'].cpu().numpy())

        avg_loss = total_loss / num_batches

        metrics = self._compute_metrics(all_emotion_preds, all_emotion_labels,
                                        all_intent_preds, all_intent_labels)

        return avg_loss, metrics
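    # Sketch of global metric aggregation under DDP (gathers every rank's
    # prediction list onto all ranks before scoring):
    #
    #     shards = [None] * self.world_size
    #     dist.all_gather_object(shards, all_emotion_preds)
    #     global_emotion_preds = [p for shard in shards for p in shard]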
    def _compute_metrics(self, emotion_preds, emotion_labels, intent_preds, intent_labels):
        """Compute metrics for whichever classification heads produced predictions."""
        from sklearn.metrics import accuracy_score, f1_score

        metrics = {}

        if emotion_preds and emotion_labels:
            metrics.update({
                'emotion_accuracy': accuracy_score(emotion_labels, emotion_preds),
                'emotion_f1_macro': f1_score(emotion_labels, emotion_preds, average='macro'),
                'emotion_f1_weighted': f1_score(emotion_labels, emotion_preds, average='weighted'),
            })

        if intent_preds and intent_labels:
            metrics.update({
                'intent_accuracy': accuracy_score(intent_labels, intent_preds),
                'intent_f1_macro': f1_score(intent_labels, intent_preds, average='macro'),
                'intent_f1_weighted': f1_score(intent_labels, intent_preds, average='weighted'),
            })

        return metrics

    def train(self):
        """Main training loop with checkpointing and early stopping."""
        best_val_loss = float('inf')
        patience_counter = 0

        for epoch in range(self.config.epochs):
            train_loss = self.train_epoch(epoch)

            val_loss, val_metrics = self.validate(epoch)

            # Log to console and (optionally) wandb from the main process only
            if self.is_main_process:
                logger.info(f"Epoch {epoch+1}: Train Loss={train_loss:.4f}, Val Loss={val_loss:.4f}")
                for metric_name, metric_value in val_metrics.items():
                    logger.info(f"{metric_name}: {metric_value:.4f}")

                if self.config.use_wandb:
                    wandb.log({
                        'epoch': epoch,
                        'train_loss': train_loss,
                        'val_loss': val_loss,
                        **val_metrics,
                        'lr': self.optimizer.param_groups[0]['lr']
                    })

            # Checkpoint on improvement; otherwise count toward early stopping.
            # Caveat: under DDP each rank sees a per-rank val_loss, so ranks can
            # disagree about stopping; broadcasting the decision from rank 0
            # would make this robust.
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0
                if self.is_main_process:
                    self.save_checkpoint(epoch, val_loss, val_metrics)
            else:
                patience_counter += 1

            if patience_counter >= self.config.patience:
                logger.info("Early stopping triggered")
                break

        if self.is_distributed:
            dist.destroy_process_group()

    def save_checkpoint(self, epoch, val_loss, val_metrics):
        """Save a training checkpoint."""
        # Unwrap DDP so the saved state_dict carries no 'module.' prefix
        model_to_save = self.model.module if isinstance(self.model, DDP) else self.model

        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model_to_save.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict() if self.scheduler else None,
            'scaler_state_dict': self.scaler.state_dict() if self.scaler else None,
            'val_loss': val_loss,
            'val_metrics': val_metrics,
            'config': self.config
        }

        os.makedirs(self.config.checkpoint_dir, exist_ok=True)
        checkpoint_path = os.path.join(self.config.checkpoint_dir, f"checkpoint_epoch_{epoch}.pth")
        torch.save(checkpoint, checkpoint_path)
        logger.info(f"Saved checkpoint: {checkpoint_path}")

    @staticmethod
    def load_checkpoint(checkpoint_path, model, optimizer=None, scheduler=None, scaler=None):
        """Load a checkpoint; returns (epoch, val_loss, val_metrics)."""
        # Load onto CPU first so checkpoints restore regardless of device layout
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        model.load_state_dict(checkpoint['model_state_dict'])

        # Guard on the stored values, which may legitimately be None
        if optimizer and checkpoint.get('optimizer_state_dict'):
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        if scheduler and checkpoint.get('scheduler_state_dict'):
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

        if scaler and checkpoint.get('scaler_state_dict'):
            scaler.load_state_dict(checkpoint['scaler_state_dict'])

        return checkpoint['epoch'], checkpoint['val_loss'], checkpoint['val_metrics']
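

# Launch sketch. `build_model`, `build_datasets`, and `TrainingConfig` are
# hypothetical user-supplied pieces; under torchrun the WORLD_SIZE / RANK /
# LOCAL_RANK variables read in __init__ are set automatically:
#
#     # single GPU:        python train.py
#     # 4 GPUs, one node:  torchrun --nproc_per_node=4 train.py
#
#     if __name__ == '__main__':
#         model = build_model()
#         train_ds, val_ds = build_datasets()
#         trainer = AdvancedTrainer(model, train_ds, val_ds, TrainingConfig())
#         trainer.train()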