| | """ |
| | Artist Style Embedding - Trainer |
| | """ |
| | from pathlib import Path |
| | from typing import Dict |
| | from collections import defaultdict |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from torch.cuda.amp import GradScaler, autocast |
| | from torch.optim import AdamW |
| | from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, LinearLR, SequentialLR |
| | from tqdm import tqdm |
| | import numpy as np |
| |
|
| | try: |
| | import wandb |
| | WANDB_AVAILABLE = True |
| | except ImportError: |
| | WANDB_AVAILABLE = False |
| |
|
| |
|
class AverageMeter:
    """Track the most recent value and the weighted running mean of a metric.

    Attributes:
        val: The value passed to the latest ``update`` call.
        sum: Running total of ``val * n`` over all updates.
        count: Running total of the weights ``n``.
        avg: ``sum / count``, or ``0`` while ``count`` is zero.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Reset every statistic to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Fold a new observation into the running statistics.

        Args:
            val: The observed value (e.g. a per-batch mean).
            n: Weight of the observation, typically the batch size.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # A zero total weight (e.g. the first update came from an empty
        # batch with n=0) must not raise ZeroDivisionError.
        self.avg = self.sum / self.count if self.count else 0
| |
|
| |
|
class Trainer:
    """Training / validation driver for the artist style embedding model.

    The trainer relies on the following interfaces, all of which are read
    directly by the methods below:

    * ``model.encoder`` exposes ``get_backbone_params()``,
      ``get_head_params()``, ``freeze_backbone()`` and
      ``unfreeze_backbone()``; ``model.arcface_weight`` is a parameter.
    * ``model(full, face, eye, has_face, has_eye)`` returns a mapping with
      the keys ``'embeddings'`` and ``'cosine'``.
    * ``loss_fn(embeddings, cosine, labels)`` returns ``(loss, loss_dict)``
      and exposes a ``center_loss`` sub-module.
    * Each batch is a mapping with the keys ``'full'``, ``'face'``,
      ``'eye'``, ``'has_face'``, ``'has_eye'`` and ``'label'``.
    * ``config`` has ``model`` and ``train`` namespaces holding the
      hyper-parameters referenced in this class.
    """

    def __init__(self, model, loss_fn, train_loader, val_loader, config, artist_to_idx):
        """Move model and loss to the device and build the optimisation state.

        Args:
            model: The network to train.
            loss_fn: Loss module; see the class docstring for its contract.
            train_loader: Iterable of training batches.
            val_loader: Iterable of validation batches.
            config: Configuration object with ``model`` and ``train`` parts.
            artist_to_idx: Mapping from artist name to class index; it is
                stored in every checkpoint.
        """
        self.model = model
        self.loss_fn = loss_fn
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.config = config
        self.artist_to_idx = artist_to_idx
        self.idx_to_artist = {v: k for k, v in artist_to_idx.items()}

        self.device = torch.device(config.train.device)
        self.model = self.model.to(self.device)
        self.loss_fn = self.loss_fn.to(self.device)

        self.optimizer = self._create_optimizer()
        self.scheduler = self._create_scheduler()

        self.use_amp = config.train.use_amp
        self.scaler = GradScaler() if self.use_amp else None

        self.save_dir = Path(config.train.save_dir)
        self.save_dir.mkdir(parents=True, exist_ok=True)

        # ``current_epoch`` is the epoch that train() will run next (or is
        # currently running); ``best_metric`` is the best validation accuracy.
        self.current_epoch = 0
        self.global_step = 0
        self.best_metric = 0.0
        self.patience_counter = 0

        # Coerce to a real bool so the flag is never the project string/None.
        self.use_wandb = bool(WANDB_AVAILABLE and config.train.wandb_project)
        if self.use_wandb:
            wandb.init(
                project=config.train.wandb_project,
                name=config.train.wandb_run_name,
                config={'model': config.model.__dict__, 'train': config.train.__dict__}
            )

    def _create_optimizer(self):
        """Build an AdamW optimizer with one parameter group per component.

        The backbone uses ``learning_rate * backbone_lr_multiplier``, the head
        and the ArcFace weight use ``learning_rate``, and the center-loss
        parameters use half of ``learning_rate``.
        """
        base_lr = self.config.train.learning_rate
        backbone_params = self.model.encoder.get_backbone_params()
        head_params = self.model.encoder.get_head_params()
        arcface_params = [self.model.arcface_weight]
        loss_params = list(self.loss_fn.center_loss.parameters())

        return AdamW([
            {'params': backbone_params, 'lr': base_lr * self.config.train.backbone_lr_multiplier},
            {'params': head_params, 'lr': base_lr},
            {'params': arcface_params, 'lr': base_lr},
            {'params': loss_params, 'lr': base_lr * 0.5},
        ], weight_decay=self.config.train.weight_decay)

    def _create_scheduler(self):
        """Build a linear warm-up followed by cosine annealing with restarts.

        The scheduler is stepped once per epoch by ``train``.
        """
        warmup_epochs = self.config.train.warmup_epochs
        # CosineAnnealingWarmRestarts rejects T_0 < 1, which would happen when
        # the warm-up covers the whole run; clamp so construction succeeds.
        cosine_epochs = max(1, self.config.train.epochs - warmup_epochs)

        warmup = LinearLR(self.optimizer, start_factor=0.01, end_factor=1.0, total_iters=warmup_epochs)
        main = CosineAnnealingWarmRestarts(self.optimizer, T_0=cosine_epochs, eta_min=self.config.train.min_lr)
        return SequentialLR(self.optimizer, [warmup, main], milestones=[warmup_epochs])

    def _batch_to_device(self, batch):
        """Return the six tensors the model and loss consume, on ``self.device``."""
        keys = ('full', 'face', 'eye', 'has_face', 'has_eye', 'label')
        return tuple(batch[key].to(self.device) for key in keys)

    def _forward(self, batch):
        """Run model and loss on *batch* under autocast.

        Returns:
            ``(output, loss, loss_dict, labels, batch_size)`` where
            ``batch_size`` is ``full.size(0)``.
        """
        full, face, eye, has_face, has_eye, labels = self._batch_to_device(batch)
        with autocast(enabled=self.use_amp):
            output = self.model(full, face, eye, has_face, has_eye)
            loss, loss_dict = self.loss_fn(output['embeddings'], output['cosine'], labels)
        return output, loss, loss_dict, labels, full.size(0)

    @staticmethod
    def _count_topk_hits(scores, labels, k=5):
        """Count rows of *scores* whose label is among the top-*k* entries of dim 1.

        ``k`` is clamped to the size of dim 1 so the call does not fail when
        there are fewer than ``k`` classes.
        """
        k = min(k, scores.size(1))
        _, topk_preds = scores.topk(k, dim=1)
        hits = topk_preds.eq(labels.view(-1, 1).expand_as(topk_preds))
        return hits.any(dim=1).sum().item()

    def train_epoch(self) -> Dict[str, float]:
        """Train for one pass over ``train_loader``.

        Returns:
            The running average of every entry of the loss dict over the epoch.
        """
        self.model.train()
        loss_meters = defaultdict(AverageMeter)

        # The backbone stays frozen for the first ``freeze_backbone_epochs`` epochs.
        if self.current_epoch < self.config.model.freeze_backbone_epochs:
            self.model.encoder.freeze_backbone()
        else:
            self.model.encoder.unfreeze_backbone()

        pbar = tqdm(self.train_loader, desc=f"Epoch {self.current_epoch}")

        for batch in pbar:
            _, loss, loss_dict, _, batch_size = self._forward(batch)

            self.optimizer.zero_grad()
            if self.use_amp:
                self.scaler.scale(loss).backward()
                # Gradients must be unscaled before clipping so the threshold
                # applies to their true magnitude.
                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.train.max_grad_norm)
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.train.max_grad_norm)
                self.optimizer.step()

            for k, v in loss_dict.items():
                loss_meters[k].update(v, batch_size)

            pbar.set_postfix({'loss': f"{loss_meters['loss_total'].avg:.4f}"})

            self.global_step += 1
            if self.global_step % self.config.train.log_every_n_steps == 0 and self.use_wandb:
                wandb.log({f"train/{k}": v.avg for k, v in loss_meters.items()}, step=self.global_step)

        return {k: v.avg for k, v in loss_meters.items()}

    @torch.no_grad()
    def validate(self) -> Dict[str, float]:
        """Evaluate on ``val_loader``.

        Returns:
            A dict with ``'accuracy'`` (top-1), ``'accuracy_top5'`` and the
            running average of every entry of the loss dict. Both accuracies
            are ``0`` when the loader yields no samples.
        """
        self.model.eval()

        total_correct = 0
        total_samples = 0
        total_correct_top5 = 0
        loss_meters = defaultdict(AverageMeter)

        for batch in tqdm(self.val_loader, desc="Validation"):
            output, _, loss_dict, labels, batch_size = self._forward(batch)
            scores = output['cosine']

            preds = scores.argmax(dim=1)
            total_correct += (preds == labels).sum().item()
            total_correct_top5 += self._count_topk_hits(scores, labels, k=5)
            total_samples += labels.size(0)

            for k, v in loss_dict.items():
                loss_meters[k].update(v, batch_size)

        accuracy = total_correct / total_samples if total_samples > 0 else 0
        accuracy_top5 = total_correct_top5 / total_samples if total_samples > 0 else 0

        metrics = {
            'accuracy': accuracy,
            'accuracy_top5': accuracy_top5,
        }
        metrics.update({k: v.avg for k, v in loss_meters.items()})

        if self.use_wandb:
            wandb.log({f"val/{k}": v for k, v in metrics.items()}, step=self.global_step)

        return metrics

    def save_checkpoint(self, filename: str, is_best: bool = False):
        """Write the full training state to ``save_dir / filename``.

        Args:
            filename: File name inside ``save_dir``.
            is_best: When true, the same state is also written to
                ``best_model.pt`` (unless that already is *filename*).
        """
        checkpoint = {
            'epoch': self.current_epoch,
            'global_step': self.global_step,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'best_metric': self.best_metric,
            # Stored so early stopping continues correctly after a resume.
            'patience_counter': self.patience_counter,
            'config': {'model': self.config.model.__dict__, 'train': self.config.train.__dict__},
            'artist_to_idx': self.artist_to_idx,
        }
        if self.use_amp:
            checkpoint['scaler_state_dict'] = self.scaler.state_dict()

        torch.save(checkpoint, self.save_dir / filename)
        # Avoid serialising the identical file twice when the caller already
        # asked for 'best_model.pt'.
        if is_best and filename != 'best_model.pt':
            torch.save(checkpoint, self.save_dir / 'best_model.pt')

    def load_checkpoint(self, path: str):
        """Restore the training state written by ``save_checkpoint``.

        Checkpoints are written by ``train`` after an epoch has completed and
        the scheduler has been stepped, so training resumes at the epoch
        *after* the stored one.

        Args:
            path: Path of the checkpoint file.
        """
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from trusted sources.
        checkpoint = torch.load(path, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

        saved_epoch = checkpoint['epoch']
        # The stored epoch is already finished; do not train it a second time.
        self.current_epoch = saved_epoch + 1
        self.global_step = checkpoint['global_step']
        self.best_metric = checkpoint['best_metric']
        # Older checkpoints lack this key; keep the current counter for them.
        self.patience_counter = checkpoint.get('patience_counter', self.patience_counter)
        if self.use_amp and 'scaler_state_dict' in checkpoint:
            self.scaler.load_state_dict(checkpoint['scaler_state_dict'])
        print(f"Loaded checkpoint from epoch {saved_epoch}")

    def train(self):
        """Run the training loop from ``current_epoch`` up to ``config.train.epochs``.

        Each epoch trains, validates, steps the scheduler, updates the best
        validation accuracy and the early-stopping counter, and writes
        checkpoints. ``final_model.pt`` is always written at the end.
        """
        print(f"Training for {self.config.train.epochs} epochs on {self.device}")
        print(f"Artists: {len(self.artist_to_idx)}")

        missing = float('nan')  # shown when a loader produced no 'loss_total'

        for epoch in range(self.current_epoch, self.config.train.epochs):
            self.current_epoch = epoch

            train_metrics = self.train_epoch()
            print(f"\nEpoch {epoch} - Train Loss: {train_metrics.get('loss_total', missing):.4f}")

            val_metrics = self.validate()
            print(f"Epoch {epoch} - Val Loss: {val_metrics.get('loss_total', missing):.4f}, "
                  f"Acc: {val_metrics['accuracy']:.4f}, "
                  f"Top5: {val_metrics['accuracy_top5']:.4f}")

            self.scheduler.step()

            is_best = val_metrics['accuracy'] > self.best_metric
            if is_best:
                self.best_metric = val_metrics['accuracy']
                self.patience_counter = 0
            else:
                self.patience_counter += 1

            if (epoch + 1) % self.config.train.save_every_n_epochs == 0:
                self.save_checkpoint(f'checkpoint_epoch_{epoch}.pt', is_best)
            elif is_best:
                self.save_checkpoint('best_model.pt', is_best=True)

            if self.patience_counter >= self.config.train.patience:
                print(f"Early stopping at epoch {epoch}")
                break

        self.save_checkpoint('final_model.pt')
        if self.use_wandb:
            wandb.finish()
        print(f"Training complete. Best Accuracy: {self.best_metric:.4f}")