import os
import json
import datetime

import numpy as np
import torch
import wandb
from tqdm import tqdm
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from segmentation_models_pytorch.base.modules import Activation

from SemanticModel.data_loader import SegmentationDataset
from SemanticModel.metrics import compute_mean_iou
from SemanticModel.image_preprocessing import get_training_augmentations, get_validation_augmentations
from SemanticModel.utilities import list_images, validate_dimensions

class ModelTrainer:
    """Coordinates training, validation, learning-rate scheduling, checkpointing,
    and metric logging for a SemanticModel segmentation model."""

    def __init__(self, model_config, root_dir, epochs=40, train_size=1024,
                 val_size=None, workers=2, batch_size=2, learning_rate=1e-4,
                 step_count=2, decay_factor=0.8, wandb_config=None,
                 optimizer='rmsprop', target_class=None, resume_path=None):
        self.config = model_config
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.root_dir = root_dir
        self._initialize_training_params(epochs, train_size, val_size, workers,
                                         batch_size, learning_rate, step_count,
                                         decay_factor, optimizer, target_class)
        self._setup_directories()
        self._initialize_datasets()
        self._setup_optimizer()
        self._initialize_tracking()

        if resume_path:
            self._resume_training(resume_path)

    def _initialize_training_params(self, epochs, train_size, val_size, workers,
                                    batch_size, learning_rate, step_count,
                                    decay_factor, optimizer, target_class):
        """Stores training hyperparameters and derived state."""
        self.epochs = epochs
        self.train_size = train_size
        self.val_size = val_size
        self.workers = workers
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.step_schedule = self._calculate_step_schedule(epochs, step_count)
        self.decay_factor = decay_factor
        self.optimizer_type = optimizer
        self.target_class = target_class
        self.current_epoch = 1
        self.best_iou = 0.0
        self.best_epoch = 0
        self.classes = (['background'] + self.config.classes
                        if self.config.background_flag else self.config.classes)

    def _setup_directories(self):
        """Verifies the expected train/val directory structure."""
        self.train_dir = os.path.join(self.root_dir, 'train')
        self.val_dir = os.path.join(self.root_dir, 'val')

        required_subdirs = ['Images', 'Masks']
        for path in [self.train_dir] + ([self.val_dir] if os.path.exists(self.val_dir) else []):
            for subdir in required_subdirs:
                full_path = os.path.join(path, subdir)
                if not os.path.exists(full_path):
                    raise FileNotFoundError(f"Missing directory: {full_path}")

    def _initialize_datasets(self):
        """Sets up training and validation datasets and loaders."""
        self.train_dataset = SegmentationDataset(
            self.train_dir,
            classes=self.classes,
            augmentation=get_training_augmentations(self.train_size, self.train_size),
            preprocessing=self.config.preprocessing
        )

        if os.path.exists(self.val_dir):
            self.val_dataset = SegmentationDataset(
                self.val_dir,
                classes=self.classes,
                augmentation=get_validation_augmentations(
                    self.val_size or self.train_size,
                    self.val_size or self.train_size,
                    fixed_size=False
                ),
                preprocessing=self.config.preprocessing
            )
        else:
            # No separate validation split; fall back to validating on the training set.
            self.val_dataset = self.train_dataset

        self.val_loader = DataLoader(
            self.val_dataset,
            batch_size=1,
            shuffle=False,
            num_workers=self.workers
        )

        self.train_loader = DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.workers
        )

    def _setup_optimizer(self):
        """Configures the model optimizer."""
        optimizer_map = {
            'adam': torch.optim.Adam,
            'sgd': lambda params: torch.optim.SGD(params, momentum=0.9),
            'rmsprop': torch.optim.RMSprop
        }
        optimizer_class = optimizer_map.get(self.optimizer_type.lower())
        if not optimizer_class:
            raise ValueError(f"Unsupported optimizer: {self.optimizer_type}")

        self.optimizer = optimizer_class([{'params': self.config.model.parameters(),
                                           'lr': self.learning_rate}])

    def _initialize_tracking(self):
        """Sets up the output directory, TensorBoard writer, and metric bookkeeping."""
        timestamp = datetime.datetime.now().strftime("%m-%d-%Y_%H%M%S")
        self.output_dir = os.path.join(
            self.root_dir,
            f'model_outputs-{self.config.architecture}[{self.config.encoder}]-{timestamp}'
        )
        os.makedirs(self.output_dir, exist_ok=True)

        self.writer = SummaryWriter(log_dir=self.output_dir)
        self.metrics = {
            'best_epoch': self.best_epoch,
            'best_epoch_iou': self.best_iou,
            'last_epoch': 0,
            'last_epoch_iou': 0.0,
            'last_epoch_lr': self.learning_rate,
            'step_schedule': self.step_schedule,
            'decay_factor': self.decay_factor,
            'target_class': self.target_class or 'overall'
        }

    def _calculate_step_schedule(self, epochs, steps):
        """Calculates evenly spaced learning-rate decay epochs, e.g. epochs=40, steps=2 -> [13, 26]."""
        return list(map(int, np.linspace(0, epochs, steps + 2)[1:-1]))

    def train(self):
        """Executes the training loop."""
        model = self.config.model.to(self.device)
        if torch.cuda.device_count() > 1:
            model = torch.nn.DataParallel(model)
            print(f'Using {torch.cuda.device_count()} GPUs')

        self._save_config()

        for epoch in range(self.current_epoch, self.epochs + 1):
            print(f'\nEpoch {epoch}/{self.epochs}')
            print(f'Learning rate: {self.optimizer.param_groups[0]["lr"]:.3e}')

            train_loss = self._train_epoch(model)
            val_loss, val_metrics = self._validate_epoch(model)

            self._update_tracking(epoch, train_loss, val_loss, val_metrics)
            self._adjust_learning_rate(epoch)
            self._save_checkpoints(model, epoch, val_metrics)

        print(f'\nTraining completed. Best {self.metrics["target_class"]} IoU: {self.best_iou:.3f}')
        return model, self.metrics

    def _train_epoch(self, model):
        """Executes a single training epoch."""
        model.train()
        total_loss = 0
        sample_count = 0

        for batch in tqdm(self.train_loader, desc='Training'):
            images, masks = [x.to(self.device) for x in batch]
            self.optimizer.zero_grad()

            outputs = model(images)
            loss = self.config.loss(outputs, masks)
            loss.backward()
            self.optimizer.step()

            total_loss += loss.item() * len(images)
            sample_count += len(images)

        return total_loss / sample_count

    def _validate_epoch(self, model):
        """Executes a validation pass and computes IoU metrics."""
        model.eval()
        total_loss = 0
        predictions = []
        ground_truth = []

        with torch.no_grad():
            for batch in tqdm(self.val_loader, desc='Validation'):
                images, masks = [x.to(self.device) for x in batch]
                outputs = model(images)
                loss = self.config.loss(outputs, masks)

                total_loss += loss.item()

                if self.config.n_classes > 1:
                    # Multi-class: masks are one-hot encoded, so argmax both predictions and targets.
                    predictions.extend([p.cpu().argmax(dim=0) for p in outputs])
                    ground_truth.extend([m.cpu().argmax(dim=0) for m in masks])
                else:
                    # Binary: threshold sigmoid probabilities at 0.5.
                    predictions.extend([(torch.sigmoid(p) > 0.5).float().squeeze().cpu()
                                        for p in outputs])
                    ground_truth.extend([m.cpu().squeeze() for m in masks])

        metrics = compute_mean_iou(
            predictions,
            ground_truth,
            num_classes=len(self.classes),
            ignore_index=255
        )

        return total_loss / len(self.val_loader), metrics

    def _update_tracking(self, epoch, train_loss, val_loss, val_metrics):
        """Updates training metrics and logging."""
        mean_iou = val_metrics['mean_iou']
        print(f"\nLosses - Train: {train_loss:.3f}, Val: {val_loss:.3f}")
        print(f"Mean IoU: {mean_iou:.3f}")

        self.writer.add_scalar('Loss/train', train_loss, epoch)
        self.writer.add_scalar('Loss/val', val_loss, epoch)
        self.writer.add_scalar('IoU/mean', mean_iou, epoch)

        for idx, iou in enumerate(val_metrics['per_category_iou']):
            print(f"{self.classes[idx]} IoU: {iou:.3f}")
            self.writer.add_scalar(f'IoU/{self.classes[idx]}', iou, epoch)

    def _adjust_learning_rate(self, epoch):
        """Adjusts learning rate according to schedule."""
        if epoch in self.step_schedule:
            current_lr = self.optimizer.param_groups[0]['lr']
            new_lr = current_lr * self.decay_factor
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = new_lr
            print(f'\nDecreased learning rate: {current_lr:.3e} -> {new_lr:.3e}')

    def _save_checkpoints(self, model, epoch, metrics):
        """Saves model checkpoints and metrics."""
        epoch_iou = (metrics['mean_iou'] if self.target_class is None
                     else metrics['per_category_iou'][self.classes.index(self.target_class)])

        self.metrics.update({
            'last_epoch': epoch,
            'last_epoch_iou': round(float(epoch_iou), 3),
            'last_epoch_lr': self.optimizer.param_groups[0]['lr']
        })

        if epoch_iou > self.best_iou:
            self.best_iou = epoch_iou
            self.best_epoch = epoch
            self.metrics.update({
                'best_epoch': epoch,
                'best_epoch_iou': round(float(epoch_iou), 3),
                'overall_iou': round(float(metrics['mean_iou']), 3)
            })
            torch.save(model, os.path.join(self.output_dir, 'best_model.pth'))
            print(f'New best model saved (IoU: {epoch_iou:.3f})')

        torch.save(model, os.path.join(self.output_dir, 'last_model.pth'))
        with open(os.path.join(self.output_dir, 'metrics.json'), 'w') as f:
            json.dump(self.metrics, f, indent=4)

    def _save_config(self):
        """Saves training configuration."""
        config = {
            **self.config.config_data,
            'train_size': self.train_size,
            'val_size': self.val_size,
            'epochs': self.epochs,
            'batch_size': self.batch_size,
            'optimizer': self.optimizer_type,
            'workers': self.workers,
            'target_class': self.target_class or 'overall'
        }

        with open(os.path.join(self.output_dir, 'config.json'), 'w') as f:
            json.dump(config, f, indent=4)

    def _resume_training(self, resume_path):
        """Resumes training from a previous output directory."""
        if not os.path.exists(resume_path):
            raise FileNotFoundError(f"Resume path not found: {resume_path}")

        required_files = {
            'model': 'last_model.pth',
            'metrics': 'metrics.json',
            'config': 'config.json'
        }

        paths = {k: os.path.join(resume_path, v) for k, v in required_files.items()}
        if not all(os.path.exists(p) for p in paths.values()):
            raise FileNotFoundError("Missing required checkpoint files")

        with open(paths['config']) as f:
            config = json.load(f)
        with open(paths['metrics']) as f:
            metrics = json.load(f)

        # Checkpoints store the full model object; restore it and unwrap DataParallel if present.
        checkpoint = torch.load(paths['model'], map_location=self.device)
        if isinstance(checkpoint, torch.nn.DataParallel):
            checkpoint = checkpoint.module
        self.config.model = checkpoint

        self.current_epoch = metrics['last_epoch'] + 1
        self.best_iou = metrics['best_epoch_iou']
        self.best_epoch = metrics['best_epoch']
        self.learning_rate = metrics['last_epoch_lr']

        # The optimizer was created before resuming, so re-apply the saved learning rate.
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.learning_rate

        print(f'Resuming training from epoch {self.current_epoch}')
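

# --- Example usage (illustrative sketch, not part of the original module) ---
# ModelTrainer only relies on `model_config` exposing: model, classes, background_flag,
# preprocessing, loss, n_classes, architecture, encoder, and config_data. How that config
# object is built is project-specific; `build_model_config` and its import path below are
# hypothetical placeholders for whichever factory your project provides.
if __name__ == '__main__':
    from SemanticModel.model_core import build_model_config  # hypothetical helper

    config = build_model_config(architecture='Unet', encoder='resnet34',
                                classes=['road', 'building'])  # hypothetical signature
    trainer = ModelTrainer(
        model_config=config,
        root_dir='/path/to/dataset',  # expects train/Images and train/Masks (val/ optional)
        epochs=40,
        batch_size=2,
        learning_rate=1e-4,
        optimizer='adam',
    )
    model, metrics = trainer.train()
    print(metrics)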