import os
import json
import datetime

import numpy as np
import torch
import wandb
from tqdm import tqdm
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from segmentation_models_pytorch.base.modules import Activation

from SemanticModel.data_loader import SegmentationDataset
from SemanticModel.metrics import compute_mean_iou
from SemanticModel.image_preprocessing import get_training_augmentations, get_validation_augmentations
from SemanticModel.utilities import list_images, validate_dimensions


class ModelTrainer:
    def __init__(self, model_config, root_dir, epochs=40, train_size=1024,
                 val_size=None, workers=2, batch_size=2, learning_rate=1e-4,
                 step_count=2, decay_factor=0.8, wandb_config=None,
                 optimizer='rmsprop', target_class=None, resume_path=None):
        self.config = model_config
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.root_dir = root_dir
        self._initialize_training_params(epochs, train_size, val_size, workers,
                                         batch_size, learning_rate, step_count,
                                         decay_factor, optimizer, target_class)
        self._setup_directories()
        self._initialize_datasets()
        self._setup_optimizer()
        self._initialize_tracking()

        if resume_path:
            self._resume_training(resume_path)

    def _initialize_training_params(self, epochs, train_size, val_size, workers,
                                    batch_size, learning_rate, step_count,
                                    decay_factor, optimizer, target_class):
        self.epochs = epochs
        self.train_size = train_size
        self.val_size = val_size
        self.workers = workers
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.step_schedule = self._calculate_step_schedule(epochs, step_count)
        self.decay_factor = decay_factor
        self.optimizer_type = optimizer
        self.target_class = target_class
        self.current_epoch = 1
        self.best_iou = 0.0
        self.best_epoch = 0
        self.classes = (['background'] + self.config.classes
                        if self.config.background_flag
                        else self.config.classes)

    def _setup_directories(self):
        """Verifies that the expected train/val directory structure exists."""
        self.train_dir = os.path.join(self.root_dir, 'train')
        self.val_dir = os.path.join(self.root_dir, 'val')

        required_subdirs = ['Images', 'Masks']
        for path in [self.train_dir] + ([self.val_dir] if os.path.exists(self.val_dir) else []):
            for subdir in required_subdirs:
                full_path = os.path.join(path, subdir)
                if not os.path.exists(full_path):
                    raise FileNotFoundError(f"Missing directory: {full_path}")

    def _initialize_datasets(self):
        """Sets up training and validation datasets and loaders."""
        self.train_dataset = SegmentationDataset(
            self.train_dir,
            classes=self.classes,
            augmentation=get_training_augmentations(self.train_size, self.train_size),
            preprocessing=self.config.preprocessing)

        if os.path.exists(self.val_dir):
            self.val_dataset = SegmentationDataset(
                self.val_dir,
                classes=self.classes,
                augmentation=get_validation_augmentations(
                    self.val_size or self.train_size,
                    self.val_size or self.train_size,
                    fixed_size=False),
                preprocessing=self.config.preprocessing)
        else:
            # No separate validation split; fall back to validating on the training set.
            self.val_dataset = self.train_dataset

        self.val_loader = DataLoader(self.val_dataset, batch_size=1,
                                     shuffle=False, num_workers=self.workers)
        self.train_loader = DataLoader(self.train_dataset, batch_size=self.batch_size,
                                       shuffle=True, num_workers=self.workers)

    def _setup_optimizer(self):
        """Configures the model optimizer."""
        optimizer_map = {
            'adam': torch.optim.Adam,
            'sgd': lambda params: torch.optim.SGD(params, momentum=0.9),
            'rmsprop': torch.optim.RMSprop
        }
        optimizer_class = optimizer_map.get(self.optimizer_type.lower())
        if not optimizer_class:
            raise ValueError(f"Unsupported optimizer: {self.optimizer_type}")

        self.optimizer = optimizer_class([
            {'params': self.config.model.parameters(), 'lr': self.learning_rate}
        ])

    def _initialize_tracking(self):
        """Sets up training progress tracking and output directories."""
        timestamp = datetime.datetime.now().strftime("%m-%d-%Y_%H%M%S")
        self.output_dir = os.path.join(
            self.root_dir,
            f'model_outputs-{self.config.architecture}[{self.config.encoder}]-{timestamp}')
        os.makedirs(self.output_dir, exist_ok=True)
        self.writer = SummaryWriter(log_dir=self.output_dir)
        self.metrics = {
            'best_epoch': self.best_epoch,
            'best_epoch_iou': self.best_iou,
            'last_epoch': 0,
            'last_epoch_iou': 0.0,
            'last_epoch_lr': self.learning_rate,
            'step_schedule': self.step_schedule,
            'decay_factor': self.decay_factor,
            'target_class': self.target_class or 'overall'
        }

    def _calculate_step_schedule(self, epochs, steps):
        """Calculates the epochs at which the learning rate decays."""
        # Evenly spaced interior points of [0, epochs], e.g. epochs=40, steps=2 -> [13, 26].
        return list(map(int, np.linspace(0, epochs, steps + 2)[1:-1]))

    def train(self):
        """Executes the full training loop."""
        model = self.config.model.to(self.device)
        if torch.cuda.device_count() > 1:
            model = torch.nn.DataParallel(model)
            print(f'Using {torch.cuda.device_count()} GPUs')

        self._save_config()

        for epoch in range(self.current_epoch, self.epochs + 1):
            print(f'\nEpoch {epoch}/{self.epochs}')
            print(f'Learning rate: {self.optimizer.param_groups[0]["lr"]:.3e}')

            train_loss = self._train_epoch(model)
            val_loss, val_metrics = self._validate_epoch(model)

            self._update_tracking(epoch, train_loss, val_loss, val_metrics)
            self._adjust_learning_rate(epoch)
            self._save_checkpoints(model, epoch, val_metrics)

        print(f'\nTraining completed. Best {self.metrics["target_class"]} '
              f'IoU: {self.best_iou:.3f}')
        return model, self.metrics

    def _train_epoch(self, model):
        """Runs a single training epoch; returns the mean loss per sample."""
        model.train()
        total_loss = 0.0
        sample_count = 0

        for batch in tqdm(self.train_loader, desc='Training'):
            images, masks = [x.to(self.device) for x in batch]
            self.optimizer.zero_grad()
            outputs = model(images)
            loss = self.config.loss(outputs, masks)
            loss.backward()
            self.optimizer.step()
            total_loss += loss.item() * len(images)
            sample_count += len(images)

        return total_loss / sample_count

    def _validate_epoch(self, model):
        """Runs a validation pass; returns the mean loss and IoU metrics."""
        model.eval()
        total_loss = 0.0
        predictions = []
        ground_truth = []

        with torch.no_grad():
            for batch in tqdm(self.val_loader, desc='Validation'):
                images, masks = [x.to(self.device) for x in batch]
                outputs = model(images)
                loss = self.config.loss(outputs, masks)
                total_loss += loss.item()

                if self.config.n_classes > 1:
                    # Multiclass: masks are one-hot along the channel axis.
                    predictions.extend([p.cpu().argmax(dim=0) for p in outputs])
                    ground_truth.extend([m.cpu().argmax(dim=0) for m in masks])
                else:
                    # Binary: threshold sigmoid outputs at 0.5.
                    predictions.extend([(torch.sigmoid(p) > 0.5).float().squeeze().cpu()
                                        for p in outputs])
                    ground_truth.extend([m.cpu().squeeze() for m in masks])

        metrics = compute_mean_iou(predictions, ground_truth,
                                   num_classes=len(self.classes),
                                   ignore_index=255)
        return total_loss / len(self.val_loader), metrics

    def _update_tracking(self, epoch, train_loss, val_loss, val_metrics):
        """Updates training metrics and TensorBoard logging."""
        mean_iou = val_metrics['mean_iou']
        print(f"\nLosses - Train: {train_loss:.3f}, Val: {val_loss:.3f}")
        print(f"Mean IoU: {mean_iou:.3f}")

        self.writer.add_scalar('Loss/train', train_loss, epoch)
        self.writer.add_scalar('Loss/val', val_loss, epoch)
        self.writer.add_scalar('IoU/mean', mean_iou, epoch)

        for idx, iou in enumerate(val_metrics['per_category_iou']):
            print(f"{self.classes[idx]} IoU: {iou:.3f}")
            self.writer.add_scalar(f'IoU/{self.classes[idx]}', iou, epoch)

    def _adjust_learning_rate(self, epoch):
        """Decays the learning rate at scheduled epochs."""
        if epoch in self.step_schedule:
            current_lr = self.optimizer.param_groups[0]['lr']
            new_lr = current_lr * self.decay_factor
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = new_lr
            print(f'\nDecreased learning rate: {current_lr:.3e} -> {new_lr:.3e}')

    def _save_checkpoints(self, model, epoch, metrics):
        """Saves model checkpoints and metrics."""
        epoch_iou = (metrics['mean_iou'] if self.target_class is None
                     else metrics['per_category_iou'][self.classes.index(self.target_class)])

        self.metrics.update({
            'last_epoch': epoch,
            'last_epoch_iou': round(float(epoch_iou), 3),
            'last_epoch_lr': self.optimizer.param_groups[0]['lr']
        })

        # Unwrap DataParallel before saving so checkpoints load cleanly on one device.
        model_to_save = model.module if isinstance(model, torch.nn.DataParallel) else model

        if epoch_iou > self.best_iou:
            self.best_iou = epoch_iou
            self.best_epoch = epoch
            self.metrics.update({
                'best_epoch': epoch,
                'best_epoch_iou': round(float(epoch_iou), 3),
                'overall_iou': round(float(metrics['mean_iou']), 3)
            })
            torch.save(model_to_save, os.path.join(self.output_dir, 'best_model.pth'))
            print(f'New best model saved (IoU: {epoch_iou:.3f})')

        torch.save(model_to_save, os.path.join(self.output_dir, 'last_model.pth'))
        with open(os.path.join(self.output_dir, 'metrics.json'), 'w') as f:
            json.dump(self.metrics, f, indent=4)

    def _save_config(self):
        """Saves the training configuration alongside the checkpoints."""
        config = {
            **self.config.config_data,
            'train_size': self.train_size,
            'val_size': self.val_size,
            'epochs': self.epochs,
            'batch_size': self.batch_size,
            'optimizer': self.optimizer_type,
            'workers': self.workers,
            'target_class': self.target_class or 'overall'
        }
        with open(os.path.join(self.output_dir, 'config.json'), 'w') as f:
            json.dump(config, f, indent=4)

    def _resume_training(self, resume_path):
        """Resumes training from a checkpoint directory."""
        if not os.path.exists(resume_path):
            raise FileNotFoundError(f"Resume path not found: {resume_path}")

        required_files = {
            'model': 'last_model.pth',
            'metrics': 'metrics.json',
            'config': 'config.json'
        }
        paths = {k: os.path.join(resume_path, v) for k, v in required_files.items()}
        if not all(os.path.exists(p) for p in paths.values()):
            raise FileNotFoundError("Missing required checkpoint files")

        with open(paths['config']) as f:
            config = json.load(f)
        with open(paths['metrics']) as f:
            metrics = json.load(f)

        self.current_epoch = metrics['last_epoch'] + 1
        self.best_iou = metrics['best_epoch_iou']
        self.best_epoch = metrics['best_epoch']
        self.learning_rate = metrics['last_epoch_lr']

        # Restore the saved weights (the checkpoint stores the full model object)
        # and rebuild the optimizer so it tracks the restored parameters at the
        # resumed learning rate.
        self.config.model = torch.load(paths['model'], map_location=self.device)
        self._setup_optimizer()

        print(f'Resuming training from epoch {self.current_epoch}')