# SemanticSegmentationModel/semantic-segmentation/SemanticModel/.ipynb_checkpoints/training-checkpoint.py
import os
import json
import torch
import wandb
import datetime
import numpy as np
from tqdm import tqdm
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from segmentation_models_pytorch.base.modules import Activation

from SemanticModel.data_loader import SegmentationDataset
from SemanticModel.metrics import compute_mean_iou
from SemanticModel.image_preprocessing import get_training_augmentations, get_validation_augmentations
from SemanticModel.utilities import list_images, validate_dimensions

class ModelTrainer:
    """Training loop wrapper for semantic segmentation models."""

    def __init__(self, model_config, root_dir, epochs=40, train_size=1024,
                 val_size=None, workers=2, batch_size=2, learning_rate=1e-4,
                 step_count=2, decay_factor=0.8, wandb_config=None,
                 optimizer='rmsprop', target_class=None, resume_path=None):
        # Note: wandb_config is accepted for compatibility but is not used by
        # this training loop.
        self.config = model_config
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.root_dir = root_dir
        self._initialize_training_params(epochs, train_size, val_size, workers,
                                         batch_size, learning_rate, step_count,
                                         decay_factor, optimizer, target_class)
        self._setup_directories()
        self._initialize_datasets()
        self._setup_optimizer()
        self._initialize_tracking()

        if resume_path:
            self._resume_training(resume_path)
    def _initialize_training_params(self, epochs, train_size, val_size, workers,
                                    batch_size, learning_rate, step_count,
                                    decay_factor, optimizer, target_class):
        self.epochs = epochs
        self.train_size = train_size
        self.val_size = val_size
        self.workers = workers
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.step_schedule = self._calculate_step_schedule(epochs, step_count)
        self.decay_factor = decay_factor
        self.optimizer_type = optimizer
        self.target_class = target_class
        self.current_epoch = 1
        self.best_iou = 0.0
        self.best_epoch = 0
        self.classes = (['background'] + self.config.classes
                        if self.config.background_flag else self.config.classes)
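    # For example, with config.classes = ['road', 'building'] (illustrative
    # values) and background_flag set, self.classes becomes
    # ['background', 'road', 'building'].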
    def _setup_directories(self):
        """Verifies that the required training/validation directories exist."""
        self.train_dir = os.path.join(self.root_dir, 'train')
        self.val_dir = os.path.join(self.root_dir, 'val')
        required_subdirs = ['Images', 'Masks']
        for path in [self.train_dir] + ([self.val_dir] if os.path.exists(self.val_dir) else []):
            for subdir in required_subdirs:
                full_path = os.path.join(path, subdir)
                if not os.path.exists(full_path):
                    raise FileNotFoundError(f"Missing directory: {full_path}")
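    # Expected layout under root_dir, as implied by the checks above (the
    # val/ split is optional and falls back to the training set when absent):
    #
    #   root_dir/
    #       train/Images/  train/Masks/
    #       val/Images/    val/Masks/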
    def _initialize_datasets(self):
        """Sets up training and validation datasets and loaders."""
        self.train_dataset = SegmentationDataset(
            self.train_dir,
            classes=self.classes,
            augmentation=get_training_augmentations(self.train_size, self.train_size),
            preprocessing=self.config.preprocessing
        )

        if os.path.exists(self.val_dir):
            self.val_dataset = SegmentationDataset(
                self.val_dir,
                classes=self.classes,
                augmentation=get_validation_augmentations(
                    self.val_size or self.train_size,
                    self.val_size or self.train_size,
                    fixed_size=False
                ),
                preprocessing=self.config.preprocessing
            )
        else:
            # No separate validation split: fall back to validating on the
            # training set, so the reported IoU will be optimistic.
            self.val_dataset = self.train_dataset

        self.val_loader = DataLoader(
            self.val_dataset,
            batch_size=1,
            shuffle=False,
            num_workers=self.workers
        )
        self.train_loader = DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.workers
        )
    def _setup_optimizer(self):
        """Configures the model optimizer."""
        optimizer_map = {
            'adam': torch.optim.Adam,
            'sgd': lambda params: torch.optim.SGD(params, momentum=0.9),
            'rmsprop': torch.optim.RMSprop
        }
        optimizer_class = optimizer_map.get(self.optimizer_type.lower())
        if not optimizer_class:
            raise ValueError(f"Unsupported optimizer: {self.optimizer_type}")

        self.optimizer = optimizer_class([{'params': self.config.model.parameters(),
                                           'lr': self.learning_rate}])
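    # The 'sgd' entry uses a lambda so an extra constructor argument (momentum)
    # can be supplied; further optimizers, e.g. an 'adamw' entry mapping to
    # torch.optim.AdamW, could be registered the same way (illustrative, not
    # part of the original mapping).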
    def _initialize_tracking(self):
        """Sets up training progress tracking."""
        timestamp = datetime.datetime.now().strftime("%m-%d-%Y_%H%M%S")
        self.output_dir = os.path.join(
            self.root_dir,
            f'model_outputs-{self.config.architecture}[{self.config.encoder}]-{timestamp}'
        )
        os.makedirs(self.output_dir, exist_ok=True)
        self.writer = SummaryWriter(log_dir=self.output_dir)
        self.metrics = {
            'best_epoch': self.best_epoch,
            'best_epoch_iou': self.best_iou,
            'last_epoch': 0,
            'last_epoch_iou': 0.0,
            'last_epoch_lr': self.learning_rate,
            'step_schedule': self.step_schedule,
            'decay_factor': self.decay_factor,
            'target_class': self.target_class or 'overall'
        }
    def _calculate_step_schedule(self, epochs, steps):
        """Calculates the learning rate step schedule."""
        return list(map(int, np.linspace(0, epochs, steps + 2)[1:-1]))
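    # Worked example: with the defaults epochs=40 and step_count=2, this is
    # np.linspace(0, 40, 4)[1:-1] -> [13.33, 26.67], truncated to [13, 26],
    # so the learning rate is decayed after epochs 13 and 26.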
    def train(self):
        """Executes the training loop."""
        model = self.config.model.to(self.device)
        if torch.cuda.device_count() > 1:
            model = torch.nn.DataParallel(model)
            print(f'Using {torch.cuda.device_count()} GPUs')

        self._save_config()
        for epoch in range(self.current_epoch, self.epochs + 1):
            print(f'\nEpoch {epoch}/{self.epochs}')
            print(f'Learning rate: {self.optimizer.param_groups[0]["lr"]:.3e}')

            train_loss = self._train_epoch(model)
            val_loss, val_metrics = self._validate_epoch(model)

            self._update_tracking(epoch, train_loss, val_loss, val_metrics)
            self._adjust_learning_rate(epoch)
            self._save_checkpoints(model, epoch, val_metrics)

        print(f'\nTraining completed. Best {self.metrics["target_class"]} IoU: {self.best_iou:.3f}')
        return model, self.metrics
    def _train_epoch(self, model):
        """Executes a single training epoch."""
        model.train()
        total_loss = 0
        sample_count = 0

        for batch in tqdm(self.train_loader, desc='Training'):
            images, masks = [x.to(self.device) for x in batch]
            self.optimizer.zero_grad()
            outputs = model(images)
            loss = self.config.loss(outputs, masks)
            loss.backward()
            self.optimizer.step()

            total_loss += loss.item() * len(images)
            sample_count += len(images)

        return total_loss / sample_count
    def _validate_epoch(self, model):
        """Executes a validation pass."""
        model.eval()
        total_loss = 0
        predictions = []
        ground_truth = []

        with torch.no_grad():
            for batch in tqdm(self.val_loader, desc='Validation'):
                images, masks = [x.to(self.device) for x in batch]
                outputs = model(images)
                loss = self.config.loss(outputs, masks)
                total_loss += loss.item()

                if self.config.n_classes > 1:
                    # Multi-class: per-sample outputs and masks are (C, H, W);
                    # argmax over the channel axis recovers integer label maps.
                    predictions.extend([p.cpu().argmax(dim=0) for p in outputs])
                    ground_truth.extend([m.cpu().argmax(dim=0) for m in masks])
                else:
                    # Binary: threshold the sigmoid output at 0.5.
                    predictions.extend([(torch.sigmoid(p) > 0.5).float().squeeze().cpu()
                                        for p in outputs])
                    ground_truth.extend([m.cpu().squeeze() for m in masks])

        metrics = compute_mean_iou(
            predictions,
            ground_truth,
            num_classes=len(self.classes),
            ignore_index=255
        )
        return total_loss / len(self.val_loader), metrics
    def _update_tracking(self, epoch, train_loss, val_loss, val_metrics):
        """Updates training metrics and logging."""
        mean_iou = val_metrics['mean_iou']
        print(f"\nLosses - Train: {train_loss:.3f}, Val: {val_loss:.3f}")
        print(f"Mean IoU: {mean_iou:.3f}")

        self.writer.add_scalar('Loss/train', train_loss, epoch)
        self.writer.add_scalar('Loss/val', val_loss, epoch)
        self.writer.add_scalar('IoU/mean', mean_iou, epoch)

        for idx, iou in enumerate(val_metrics['per_category_iou']):
            print(f"{self.classes[idx]} IoU: {iou:.3f}")
            self.writer.add_scalar(f'IoU/{self.classes[idx]}', iou, epoch)
    def _adjust_learning_rate(self, epoch):
        """Adjusts the learning rate according to the step schedule."""
        if epoch in self.step_schedule:
            current_lr = self.optimizer.param_groups[0]['lr']
            new_lr = current_lr * self.decay_factor
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = new_lr
            print(f'\nDecreased learning rate: {current_lr:.3e} -> {new_lr:.3e}')
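    # With the defaults (learning_rate=1e-4, decay_factor=0.8, schedule
    # [13, 26]), the learning rate drops to 8.0e-05 after epoch 13 and to
    # 6.4e-05 after epoch 26.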
    def _save_checkpoints(self, model, epoch, metrics):
        """Saves model checkpoints and metrics."""
        epoch_iou = (metrics['mean_iou'] if self.target_class is None
                     else metrics['per_category_iou'][self.classes.index(self.target_class)])

        self.metrics.update({
            'last_epoch': epoch,
            'last_epoch_iou': round(float(epoch_iou), 3),
            'last_epoch_lr': self.optimizer.param_groups[0]['lr']
        })

        if epoch_iou > self.best_iou:
            self.best_iou = epoch_iou
            self.best_epoch = epoch
            self.metrics.update({
                'best_epoch': epoch,
                'best_epoch_iou': round(float(epoch_iou), 3),
                'overall_iou': round(float(metrics['mean_iou']), 3)
            })
            # torch.save(model, ...) pickles the full module, including any
            # DataParallel wrapper applied in train().
            torch.save(model, os.path.join(self.output_dir, 'best_model.pth'))
            print(f'New best model saved (IoU: {epoch_iou:.3f})')

        torch.save(model, os.path.join(self.output_dir, 'last_model.pth'))
        with open(os.path.join(self.output_dir, 'metrics.json'), 'w') as f:
            json.dump(self.metrics, f, indent=4)
    def _save_config(self):
        """Saves the training configuration."""
        config = {
            **self.config.config_data,
            'train_size': self.train_size,
            'val_size': self.val_size,
            'epochs': self.epochs,
            'batch_size': self.batch_size,
            'optimizer': self.optimizer_type,
            'workers': self.workers,
            'target_class': self.target_class or 'overall'
        }
        with open(os.path.join(self.output_dir, 'config.json'), 'w') as f:
            json.dump(config, f, indent=4)
    def _resume_training(self, resume_path):
        """Resumes training from a previous output directory."""
        if not os.path.exists(resume_path):
            raise FileNotFoundError(f"Resume path not found: {resume_path}")

        required_files = {
            'model': 'last_model.pth',
            'metrics': 'metrics.json',
            'config': 'config.json'
        }
        paths = {k: os.path.join(resume_path, v) for k, v in required_files.items()}
        if not all(os.path.exists(p) for p in paths.values()):
            raise FileNotFoundError("Missing required checkpoint files")

        with open(paths['config']) as f:
            config = json.load(f)
        with open(paths['metrics']) as f:
            metrics = json.load(f)

        # Restore the saved weights in place so the optimizer, which was built
        # from self.config.model's parameters in __init__, keeps updating the
        # same tensors. last_model.pth holds a pickled module, possibly wrapped
        # in DataParallel.
        checkpoint = torch.load(paths['model'], map_location=self.device)
        if isinstance(checkpoint, torch.nn.DataParallel):
            checkpoint = checkpoint.module
        self.config.model.load_state_dict(checkpoint.state_dict())

        self.current_epoch = metrics['last_epoch'] + 1
        self.best_iou = metrics['best_epoch_iou']
        self.best_epoch = metrics['best_epoch']
        self.learning_rate = metrics['last_epoch_lr']
        # Apply the resumed learning rate to the already-constructed optimizer.
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.learning_rate

        print(f'Resuming training from epoch {self.current_epoch}')
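
# Minimal usage sketch (illustrative only): `model_config` stands for the
# project's model-configuration object, which this trainer expects to expose
# .model, .loss, .classes, .n_classes, .preprocessing, .background_flag,
# .architecture, .encoder and .config_data; root_dir must contain the
# train/Images, train/Masks (and optionally val/...) folders checked above.
#
#     trainer = ModelTrainer(model_config, root_dir='/path/to/dataset',
#                            epochs=40, batch_size=4, optimizer='adam')
#     model, metrics = trainer.train()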