# Uploaded by obichimav ("Upload 42 files", commit 8e5d8c7).
import os
import json
import torch
import wandb
import datetime
import numpy as np
from tqdm import tqdm
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from segmentation_models_pytorch.base.modules import Activation
from SemanticModel.data_loader import SegmentationDataset
from SemanticModel.metrics import compute_mean_iou
from SemanticModel.image_preprocessing import get_training_augmentations, get_validation_augmentations
from SemanticModel.utilities import list_images, validate_dimensions
class ModelTrainer:
    """Train a semantic segmentation model and track its validation IoU.

    ``model_config`` must expose (as used below): ``model``, ``classes``,
    ``background_flag``, ``preprocessing``, ``loss``, ``n_classes``,
    ``architecture``, ``encoder`` and ``config_data``.

    ``root_dir`` must contain ``train/Images`` and ``train/Masks``. A ``val``
    directory with the same layout is optional; without it, validation runs on
    the training dataset. Each run writes TensorBoard logs, ``config.json``,
    ``metrics.json``, ``best_model.pth`` and ``last_model.pth`` into a new
    timestamped ``model_outputs-...`` directory under ``root_dir``.
    """

    def __init__(self, model_config, root_dir, epochs=40, train_size=1024,
                 val_size=None, workers=2, batch_size=2, learning_rate=1e-4,
                 step_count=2, decay_factor=0.8, wandb_config=None,
                 optimizer='rmsprop', target_class=None, resume_path=None):
        """Build datasets, optimizer and output directory; optionally resume.

        Args:
            model_config: model/config bundle (see class docstring).
            root_dir: dataset root containing ``train`` (and optionally ``val``).
            epochs: last epoch number to train (epochs are numbered from 1).
            train_size: passed as both dimensions to ``get_training_augmentations``.
            val_size: size for the validation augmentations; ``None`` reuses
                ``train_size``.
            workers: ``num_workers`` for every DataLoader.
            batch_size: training batch size (validation always uses 1).
            learning_rate: initial learning rate.
            step_count: number of evenly spaced learning-rate decay epochs.
            decay_factor: multiplier applied to the learning rate at each decay epoch.
            wandb_config: accepted and stored, but not used by this class.
            optimizer: ``'adam'``, ``'sgd'`` or ``'rmsprop'`` (case-insensitive).
            target_class: class name whose IoU selects the best model; ``None``
                uses the mean IoU.
            resume_path: output directory of a previous run to resume from.

        Raises:
            FileNotFoundError: a required data or checkpoint file is missing.
            ValueError: unsupported optimizer or unknown ``target_class``.
        """
        self.config = model_config
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.root_dir = root_dir
        # Kept so callers can read it back; nothing in this class consumes it.
        self.wandb_config = wandb_config
        self._initialize_training_params(epochs, train_size, val_size, workers,
                                         batch_size, learning_rate, step_count,
                                         decay_factor, optimizer, target_class)
        self._setup_directories()
        self._initialize_datasets()
        self._setup_optimizer()
        self._initialize_tracking()
        if resume_path:
            # Must run after the optimizer and metrics exist, because it updates both.
            self._resume_training(resume_path)

    def _initialize_training_params(self, epochs, train_size, val_size, workers,
                                    batch_size, learning_rate, step_count,
                                    decay_factor, optimizer, target_class):
        """Store hyper-parameters, reset progress counters and build the class list.

        Raises:
            ValueError: ``target_class`` is given but is not in the class list.
        """
        self.epochs = epochs
        self.train_size = train_size
        self.val_size = val_size
        self.workers = workers
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.step_schedule = self._calculate_step_schedule(epochs, step_count)
        self.decay_factor = decay_factor
        self.optimizer_type = optimizer
        self.target_class = target_class
        self.current_epoch = 1
        self.best_iou = 0.0
        self.best_epoch = 0
        # Build a fresh list so self.classes never aliases model_config.classes.
        base_classes = list(self.config.classes)
        if self.config.background_flag:
            self.classes = ['background'] + base_classes
        else:
            self.classes = base_classes
        # Fail fast; otherwise an unknown name only surfaces through
        # list.index() in _save_checkpoints, after a whole epoch has run.
        if target_class is not None and target_class not in self.classes:
            raise ValueError(
                f"target_class {target_class!r} is not one of {self.classes}")

    def _setup_directories(self):
        """Resolve the train/val paths and check that their Images/Masks folders exist.

        The ``val`` directory is optional and is only checked when present.
        Nothing is created here.

        Raises:
            FileNotFoundError: a required ``Images`` or ``Masks`` folder is missing.
        """
        self.train_dir = os.path.join(self.root_dir, 'train')
        self.val_dir = os.path.join(self.root_dir, 'val')
        split_dirs = [self.train_dir]
        if os.path.exists(self.val_dir):
            split_dirs.append(self.val_dir)
        for split_dir in split_dirs:
            for subdir in ('Images', 'Masks'):
                full_path = os.path.join(split_dir, subdir)
                if not os.path.exists(full_path):
                    raise FileNotFoundError(f"Missing directory: {full_path}")

    def _initialize_datasets(self):
        """Create the train/val datasets and their DataLoaders."""
        self.train_dataset = SegmentationDataset(
            self.train_dir,
            classes=self.classes,
            augmentation=get_training_augmentations(self.train_size, self.train_size),
            preprocessing=self.config.preprocessing
        )
        if os.path.exists(self.val_dir):
            val_size = self.val_size or self.train_size
            self.val_dataset = SegmentationDataset(
                self.val_dir,
                classes=self.classes,
                augmentation=get_validation_augmentations(
                    val_size,
                    val_size,
                    fixed_size=False
                ),
                preprocessing=self.config.preprocessing
            )
        else:
            # No separate split: validate on the training dataset object, so
            # validation samples go through the *training* augmentations.
            self.val_dataset = self.train_dataset
        # Validation always uses one sample per batch.
        self.val_loader = DataLoader(
            self.val_dataset,
            batch_size=1,
            shuffle=False,
            num_workers=self.workers
        )
        self.train_loader = DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.workers
        )

    def _setup_optimizer(self):
        """Create ``self.optimizer`` over the model parameters at ``self.learning_rate``.

        Raises:
            ValueError: ``self.optimizer_type`` is not adam, sgd or rmsprop.
        """
        optimizer_map = {
            'adam': torch.optim.Adam,
            # SGD gets momentum 0.9; its lr comes from the param group below.
            'sgd': lambda params: torch.optim.SGD(params, momentum=0.9),
            'rmsprop': torch.optim.RMSprop,
        }
        # str() so that a None optimizer reports "Unsupported optimizer"
        # instead of raising AttributeError.
        optimizer_class = optimizer_map.get(str(self.optimizer_type).lower())
        if optimizer_class is None:
            raise ValueError(f"Unsupported optimizer: {self.optimizer_type}")
        self.optimizer = optimizer_class([{'params': self.config.model.parameters(),
                                           'lr': self.learning_rate}])

    def _initialize_tracking(self):
        """Create the timestamped output directory, TensorBoard writer and metrics dict."""
        timestamp = datetime.datetime.now().strftime("%m-%d-%Y_%H%M%S")
        self.output_dir = os.path.join(
            self.root_dir,
            f'model_outputs-{self.config.architecture}[{self.config.encoder}]-{timestamp}'
        )
        os.makedirs(self.output_dir, exist_ok=True)
        self.writer = SummaryWriter(log_dir=self.output_dir)
        # Mirrors metrics.json; rewritten after every epoch by _save_checkpoints.
        self.metrics = {
            'best_epoch': self.best_epoch,
            'best_epoch_iou': self.best_iou,
            'last_epoch': 0,
            'last_epoch_iou': 0.0,
            'last_epoch_lr': self.learning_rate,
            'step_schedule': self.step_schedule,
            'decay_factor': self.decay_factor,
            'target_class': self.target_class or 'overall'
        }

    def _calculate_step_schedule(self, epochs, steps):
        """Return the epochs after which the learning rate is decayed.

        These are the ``steps`` interior points of an even split of
        ``[0, epochs]``, truncated to int. For example, ``epochs=40, steps=2``
        gives ``[13, 26]``. For very small ``epochs`` the result can contain
        0 or duplicates.
        """
        return list(map(int, np.linspace(0, epochs, steps + 2)[1:-1]))

    def train(self):
        """Run the training loop from ``self.current_epoch`` through ``self.epochs``.

        Returns:
            ``(model, metrics)``. ``model`` is wrapped in ``DataParallel`` when
            more than one GPU is visible.
        """
        model = self.config.model.to(self.device)
        if torch.cuda.device_count() > 1:
            model = torch.nn.DataParallel(model)
            print(f'Using {torch.cuda.device_count()} GPUs')
        self._save_config()
        try:
            for epoch in range(self.current_epoch, self.epochs + 1):
                print(f'\nEpoch {epoch}/{self.epochs}')
                print(f'Learning rate: {self.optimizer.param_groups[0]["lr"]:.3e}')
                train_loss = self._train_epoch(model)
                val_loss, val_metrics = self._validate_epoch(model)
                self._update_tracking(epoch, train_loss, val_loss, val_metrics)
                # Decay happens before saving, so 'last_epoch_lr' in metrics.json
                # is the rate the *next* epoch will use (what resume needs).
                self._adjust_learning_rate(epoch)
                self._save_checkpoints(model, epoch, val_metrics)
        finally:
            # Flush pending TensorBoard events and release the event file,
            # even if training is interrupted.
            self.writer.close()
        print(f'\nTraining completed. Best {self.metrics["target_class"]} IoU: {self.best_iou:.3f}')
        return model, self.metrics

    def _train_epoch(self, model):
        """Run one optimisation pass over the training loader.

        Returns:
            Mean training loss per sample for the epoch.
        """
        model.train()
        total_loss = 0
        sample_count = 0
        for batch in tqdm(self.train_loader, desc='Training'):
            images, masks = [x.to(self.device) for x in batch]
            self.optimizer.zero_grad()
            outputs = model(images)
            loss = self.config.loss(outputs, masks)
            loss.backward()
            self.optimizer.step()
            # Weight by batch size so a smaller final batch is averaged correctly.
            total_loss += loss.item() * len(images)
            sample_count += len(images)
        return total_loss / sample_count

    def _validate_epoch(self, model):
        """Evaluate the model on the validation loader without gradients.

        Returns:
            ``(mean loss per validation batch, metrics)``, where ``metrics``
            comes from ``compute_mean_iou`` and is read elsewhere for its
            ``'mean_iou'`` and ``'per_category_iou'`` entries.
        """
        model.eval()
        total_loss = 0
        predictions = []
        ground_truth = []
        with torch.no_grad():
            for batch in tqdm(self.val_loader, desc='Validation'):
                images, masks = [x.to(self.device) for x in batch]
                outputs = model(images)
                loss = self.config.loss(outputs, masks)
                total_loss += loss.item()
                if self.config.n_classes > 1:
                    # Per-sample argmax over dim 0. This presumably is the class
                    # channel of both logits and (one-hot) masks; TODO confirm.
                    predictions.extend([p.cpu().argmax(dim=0) for p in outputs])
                    ground_truth.extend([m.cpu().argmax(dim=0) for m in masks])
                else:
                    # Single-channel output: threshold the sigmoid at 0.5.
                    predictions.extend([(torch.sigmoid(p) > 0.5).float().squeeze().cpu()
                                        for p in outputs])
                    ground_truth.extend([m.cpu().squeeze() for m in masks])
        metrics = compute_mean_iou(
            predictions,
            ground_truth,
            num_classes=len(self.classes),
            ignore_index=255
        )
        return total_loss / len(self.val_loader), metrics

    def _update_tracking(self, epoch, train_loss, val_loss, val_metrics):
        """Print epoch losses/IoUs and log them to TensorBoard."""
        mean_iou = val_metrics['mean_iou']
        print(f"\nLosses - Train: {train_loss:.3f}, Val: {val_loss:.3f}")
        print(f"Mean IoU: {mean_iou:.3f}")
        self.writer.add_scalar('Loss/train', train_loss, epoch)
        self.writer.add_scalar('Loss/val', val_loss, epoch)
        self.writer.add_scalar('IoU/mean', mean_iou, epoch)
        for idx, iou in enumerate(val_metrics['per_category_iou']):
            print(f"{self.classes[idx]} IoU: {iou:.3f}")
            self.writer.add_scalar(f'IoU/{self.classes[idx]}', iou, epoch)

    def _adjust_learning_rate(self, epoch):
        """Multiply every param group's lr by ``decay_factor`` on scheduled epochs."""
        if epoch in self.step_schedule:
            current_lr = self.optimizer.param_groups[0]['lr']
            new_lr = current_lr * self.decay_factor
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = new_lr
            print(f'\nDecreased learning rate: {current_lr:.3e} -> {new_lr:.3e}')

    def _save_checkpoints(self, model, epoch, metrics):
        """Record this epoch's metrics, save best/last checkpoints and rewrite metrics.json.

        The tracked score is the mean IoU, or the IoU of ``target_class`` when
        one is set. Checkpoints are full pickled modules via ``torch.save(model)``,
        which includes the ``DataParallel`` wrapper when one is in use.
        """
        epoch_iou = (metrics['mean_iou'] if self.target_class is None
                     else metrics['per_category_iou'][self.classes.index(self.target_class)])
        self.metrics.update({
            'last_epoch': epoch,
            'last_epoch_iou': round(float(epoch_iou), 3),
            'last_epoch_lr': self.optimizer.param_groups[0]['lr']
        })
        if epoch_iou > self.best_iou:
            self.best_iou = epoch_iou
            self.best_epoch = epoch
            self.metrics.update({
                'best_epoch': epoch,
                'best_epoch_iou': round(float(epoch_iou), 3),
                'overall_iou': round(float(metrics['mean_iou']), 3)
            })
            torch.save(model, os.path.join(self.output_dir, 'best_model.pth'))
            print(f'New best model saved (IoU: {epoch_iou:.3f})')
        torch.save(model, os.path.join(self.output_dir, 'last_model.pth'))
        with open(os.path.join(self.output_dir, 'metrics.json'), 'w') as f:
            json.dump(self.metrics, f, indent=4)

    def _save_config(self):
        """Write ``config_data`` plus the training hyper-parameters to config.json."""
        config = {
            **self.config.config_data,
            'train_size': self.train_size,
            'val_size': self.val_size,
            'epochs': self.epochs,
            'batch_size': self.batch_size,
            'optimizer': self.optimizer_type,
            'workers': self.workers,
            'target_class': self.target_class or 'overall'
        }
        with open(os.path.join(self.output_dir, 'config.json'), 'w') as f:
            json.dump(config, f, indent=4)

    def _load_model_weights(self, model_path):
        """Copy weights from a ``last_model.pth`` checkpoint into ``self.config.model``.

        Accepts a pickled module (optionally wrapped in ``DataParallel``, as
        ``_save_checkpoints`` writes it) or a plain state dict. The weights are
        copied into the existing parameters, so the optimizer's parameter
        references stay valid.
        """
        # torch.save(model) pickles the whole module, so loading it requires full
        # unpickling. Only resume from checkpoints you trust.
        try:
            saved = torch.load(model_path, map_location=self.device, weights_only=False)
        except TypeError:  # torch < 1.13 has no ``weights_only`` argument
            saved = torch.load(model_path, map_location=self.device)
        if isinstance(saved, torch.nn.DataParallel):
            saved = saved.module
        state_dict = saved.state_dict() if isinstance(saved, torch.nn.Module) else saved
        self.config.model.load_state_dict(state_dict)

    def _resume_training(self, resume_path):
        """Restore model weights and training progress from a previous run.

        Loads weights from ``last_model.pth``. Restores the next epoch number,
        best-IoU bookkeeping and learning rate from ``metrics.json``, and applies
        that learning rate to the already-built optimizer. ``config.json`` must
        exist but is not read. Optimizer internal state (e.g. momentum) is not
        saved by this class and therefore starts fresh.

        Raises:
            FileNotFoundError: ``resume_path`` or a required checkpoint file is missing.
        """
        if not os.path.exists(resume_path):
            raise FileNotFoundError(f"Resume path not found: {resume_path}")
        required_files = {
            'model': 'last_model.pth',
            'metrics': 'metrics.json',
            'config': 'config.json'
        }
        paths = {k: os.path.join(resume_path, v) for k, v in required_files.items()}
        missing = [p for p in paths.values() if not os.path.exists(p)]
        if missing:
            raise FileNotFoundError(f"Missing required checkpoint files: {missing}")

        self._load_model_weights(paths['model'])

        with open(paths['metrics']) as f:
            metrics = json.load(f)
        self.current_epoch = metrics['last_epoch'] + 1
        self.best_iou = metrics['best_epoch_iou']
        self.best_epoch = metrics['best_epoch']
        self.learning_rate = metrics['last_epoch_lr']
        # The optimizer was created in __init__ with the constructor's learning
        # rate, so the resumed rate has to be pushed into it explicitly.
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.learning_rate
        # self.metrics was initialised before resuming; carry over the restored
        # bookkeeping so the new metrics.json does not report best_epoch 0.
        self.metrics.update({
            'best_epoch': self.best_epoch,
            'best_epoch_iou': self.best_iou,
            'last_epoch_lr': self.learning_rate,
        })
        for key in ('last_epoch', 'last_epoch_iou', 'overall_iou'):
            if key in metrics:
                self.metrics[key] = metrics[key]
        print(f'Resuming training from epoch {self.current_epoch}')