from datetime import datetime |
import time |
import os |
import sys |
import importlib |
import json |
import random |
import logging |
import numpy as np |
import copy |
import contextlib |
import shutil |
from typing import Any, Callable, Union |
import torch |
import torch.nn as nn |
import torch.optim as optim |
import torch.optim.lr_scheduler as lr_scheduler |
from torch.utils.data import DataLoader |
from torch.utils.data.distributed import DistributedSampler |
from mpi4py import MPI |
from infinibatch import iterators |
from .distributed_trainer import DistributedTrainer |
from .utils_trainer import UtilsTrainer |
from .utils.misc import * |
from .utils.serialization import JSONEncoder, filter_jsonable |
# Module-level logger named after this module, so its records can be filtered per file.
logger = logging.getLogger(name=__name__)
class DefaultTrainer(UtilsTrainer, DistributedTrainer):
    """Trainer that drives the epoch/batch loop and delegates task logic to a pipeline.

    The pipeline class is imported dynamically as
    ``base_dir.pipeline.<opt['PIPELINE']>``, where ``base_dir`` is the package
    rooted at ``opt['base_path']``. This class handles gradient accumulation,
    optional FP16 autocast / gradient scaling, optimizer and LR-scheduler
    stepping, checkpointing and evaluation.

    Helpers such as ``create_optimizer_and_scheduler``, ``_initialize_ddp``,
    ``load_weight``/``load_checkpoint``/``load_model``, ``save_checkpoint``,
    ``is_gradient_accumulation_boundary`` and ``_get_and_validate_current_optim_steps``
    are not defined here. They, plus attributes like ``save_folder``,
    ``grad_acc_steps``, ``grad_scaler``, ``optimizers`` and ``lr_schedulers``,
    are presumably provided by the base classes (NOTE(review): confirm).
    """

    def __init__(self, opt):
        """Set up the task the model is being trained for.

        Args:
            opt: configuration mapping, also forwarded to the base classes. Keys
                read here:
                ``base_path``: directory containing the project's ``__init__.py``.
                ``PIPELINE``: name of both the pipeline module under
                ``base_dir.pipeline`` and the class that module defines.
        """
        # `import importlib` alone does not guarantee that the `util` submodule
        # is loaded; import it explicitly rather than relying on a side effect.
        import importlib.util

        super().__init__(opt)

        # Register the project directory as an importable package named
        # `base_dir`, so the pipeline can be imported as `base_dir.pipeline.*`.
        base_name = 'base_dir'
        base_path = os.path.join(self.opt['base_path'], '__init__.py')
        spec = importlib.util.spec_from_file_location(base_name, base_path)
        module = importlib.util.module_from_spec(spec)
        sys.modules[base_name] = module
        spec.loader.exec_module(module)
        logger.info(f"Imported {base_name} at base_path {self.opt['base_path']}")

        pipeline_name = self.opt['PIPELINE']
        pipeline_module = importlib.import_module(f"base_dir.pipeline.{pipeline_name}")
        pipeline_class = getattr(pipeline_module, pipeline_name)
        logger.info(f"Pipeline for training: {self.opt['PIPELINE']}")
        self.pipeline = pipeline_class(self.opt)

    def _fp16_autocast(self):
        """Return a CUDA autocast context if ``opt['FP16']`` is set, else a no-op context."""
        if self.opt['FP16']:
            from torch.cuda.amp import autocast
            return autocast()
        return contextlib.nullcontext()

    def eval(self):
        """Build the models, load weights from ``opt['RESUME_FROM']``, and evaluate them.

        Returns:
            The value returned by ``self.pipeline.evaluate_model``.

        Raises:
            ValueError: if ``opt['WEIGHT']`` is falsy, or if ``opt['RESUME_FROM']``
                is missing or is not an existing file.
        """
        logger.info('-----------------------------------------------')
        logger.info("Evaluating model ... ")
        self.mode = "eval"

        self.raw_models = self.pipeline.initialize_model()
        self.model_names = list(self.raw_models.keys())
        for module_name in self.model_names:
            self.raw_models[module_name].to(self.opt['device'])

        # Resolve the path before validating it, so the error message can always
        # report it (it must not be referenced while still unassigned).
        model_path = self.opt.get('RESUME_FROM')
        if not (self.opt.get('WEIGHT', False) and model_path and os.path.isfile(model_path)):
            raise ValueError(f"Model not found: {model_path}")
        self.load_model(model_path)

        return self._eval_on_set(self.save_folder)

    def _eval_on_set(self, save_folder):
        """Run ``pipeline.evaluate_model`` (under autocast when FP16) and log results on rank 0.

        Args:
            save_folder: forwarded unchanged to ``pipeline.evaluate_model``.

        Returns:
            The evaluation results returned by the pipeline.
        """
        logger.info("Evaluation start ...")
        with self._fp16_autocast():
            results = self.pipeline.evaluate_model(self, save_folder)
        if self.opt['rank'] == 0:
            logger.info(results)
        return results

    def compute_loss(self, forward_func, batch):
        """Return ``forward_func(self, batch)``, run under autocast when FP16 is enabled."""
        with self._fp16_autocast():
            return forward_func(self, batch)

    def backward_loss(self, loss, model_names=None):
        """Backpropagate ``loss``, accounting for gradient accumulation and FP16.

        Args:
            loss: loss tensor. It must be a scalar, because ``.backward()`` is
                called without a gradient argument.
            model_names: unused. It is kept only so existing call sites keep
                working; ``None`` replaces the former shared mutable default.

        Returns:
            The loss that was backpropagated. It is divided by ``grad_acc_steps``
            when accumulating over several batches, and is returned before any
            FP16 gradient scaling.
        """
        if self.grad_acc_steps > 1:
            loss = loss / self.grad_acc_steps

        if self.opt['FP16']:
            self.grad_scaler.scale(loss).backward()
        else:
            loss.backward()
        return loss

    def update_model(self, model_name='default'):
        """Apply one optimizer step for ``model_name`` and advance its LR scheduler.

        With FP16 the step goes through ``self.grad_scaler``, which may skip the
        step when gradients contain inf/NaN. Gradients are zeroed, the step
        counter is incremented and the scheduler is stepped in either case.
        """
        optimizer = self.optimizers[model_name]
        if self.opt['FP16']:
            self.grad_scaler.unscale_(optimizer)
            self.grad_scaler.step(optimizer)
        else:
            optimizer.step()

        optimizer.zero_grad()
        self.train_params['optim_steps'][model_name] += 1
        self.lr_schedulers[model_name].step()

    def train_step(self, batch):
        """Queue ``batch``; at an accumulation boundary, process the whole queue.

        At a gradient-accumulation boundary:
        - every queued batch is passed to ``pipeline.forward_step``;
        - losses are recorded in ``self.train_loss``;
        - the effective batch's sample count (summed across ranks when
          ``world_size > 1``) is added to ``train_params['total_batch_size']``.

        ``train_params['num_updates']`` is incremented on every call.

        Raises:
            RuntimeError: if the number of queued batches at a boundary differs
                from ``grad_acc_steps``.
        """
        self.grad_acc_batches.append(batch)

        if self.is_gradient_accumulation_boundary():
            for model_name in self.model_names:
                self.models[model_name].train()

            # Explicit check instead of `assert`, which is stripped under `python -O`.
            if len(self.grad_acc_batches) != self.grad_acc_steps:
                raise RuntimeError(
                    f"Expected {self.grad_acc_steps} accumulated batches, "
                    f"got {len(self.grad_acc_batches)}"
                )

            is_distributed = self.opt['world_size'] > 1
            total_batch_sample = 0
            for batch_index, acc_batch in enumerate(self.grad_acc_batches):
                loss_info, sample_size_info, _extra_info = self.pipeline.forward_step(
                    self,
                    acc_batch,
                    self.grad_acc_batches,
                    batch_index,
                    is_distributed=is_distributed,
                )
                self.train_loss.update_iter(loss_info)
                total_batch_sample += sample_size_info['num_samples']

            if self.opt['FP16']:
                # Update the GradScaler once per effective (accumulated) batch.
                self.grad_scaler.update()

            if is_distributed:
                # Sum the per-rank sample counts across all ranks.
                sample_tensor = torch.tensor(total_batch_sample, device=self.opt['device'])
                torch.distributed.all_reduce(sample_tensor, torch.distributed.ReduceOp.SUM)
                total_batch_sample = sample_tensor.item()

            self.train_params['total_batch_size'] += total_batch_sample
            self.grad_acc_batches = []

        # Counts every mini-batch, not only the ones that trigger an update.
        self.train_params['num_updates'] += 1

    def init_train(self):
        """Prepare models, dataloaders, optimizers and bookkeeping for :meth:`train`.

        When ``opt['WEIGHT']`` / ``opt['RESUME']`` are set, this also loads
        weights or a checkpoint from ``opt['RESUME_FROM']``.
        """
        self.mode = "train"
        logger.info('-------------------------------------------------------')
        logger.info(f"Training on rank: {self.opt['rank']}")

        self.raw_models = self.pipeline.initialize_model()
        self.model_names = list(self.raw_models.keys())
        for module_name in self.model_names:
            self.raw_models[module_name].to(self.opt['device'])

        self.train_dataloaders = self.pipeline.get_dataloaders(self, 'train', is_evaluation=False)
        self.train_params = {
            "updates_per_epoch": len(self.train_dataloaders),
            "total_batch_size": 0,
            "num_updates": 0,
            "optim_steps": {module_name: 0 for module_name in self.model_names},
            "start_epoch_idx": 0,
            "start_batch_idx": 0,
            "current_epoch_idx": 0,
            "current_batch_idx": 0,
            "resume_epoch_idx": 0,
        }

        self.train_loss = LossMeter()
        self.grad_acc_batches = []

        if self.opt['CUDA']:
            torch.cuda.empty_cache()

        self.create_optimizer_and_scheduler()
        self.models = {model_name: self.raw_models[model_name] for model_name in self.model_names}
        self._initialize_ddp()

        # NOTE(review): load_checkpoint presumably restores start_epoch_idx /
        # start_batch_idx used by train() to resume; confirm in the base class.
        if self.opt.get('WEIGHT', False):
            self.load_weight(self.opt['RESUME_FROM'], must_exist=True)
        if self.opt.get('RESUME', False):
            self.load_checkpoint(self.opt['RESUME_FROM'], must_exist=True)

        if self.opt['rank'] == 0:
            logger.info("***** Running training *****")
            logger.info(f" Num of GPUs = {self.opt['world_size']}")
            logger.info(f" Num Epochs = {self.opt['SOLVER']['MAX_NUM_EPOCHS']}")
            logger.info(f" Num of Mini Batches per Epoch = {self.train_params['updates_per_epoch']}")
            # NOTE(review): this value is MAX_NUM_EPOCHS * updates_per_epoch, i.e.
            # the total number of mini-batches in the run, not a batch size. The
            # label may be misleading; confirm intent.
            logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {self.opt['SOLVER']['MAX_NUM_EPOCHS'] * self.train_params['updates_per_epoch']}")
            logger.info(f" Gradient Accumulation steps = {self.grad_acc_steps}")
            logger.info(f" Total optimization steps = {self.opt['SOLVER']['MAX_NUM_EPOCHS'] * self.train_params['updates_per_epoch'] // self.grad_acc_steps}")

    def _log_train_progress(self, epoch, batch_idx, current_optim_steps,
                            prev_total_batch_size, train_time_delta, epoch_start_time):
        """Log one progress line on rank 0: LR, losses, throughput, memory and epoch ETA."""
        if self.opt['rank'] != 0:
            return

        last_lr = {name: self.lr_schedulers[name].get_last_lr()[0] for name in self.model_names}
        MB = 1024.0 * 1024.0
        memory = torch.cuda.max_memory_allocated() / MB

        items_in_batch = self.train_params['total_batch_size'] - prev_total_batch_size
        # Guard against a zero-length interval (e.g. coarse clock resolution).
        items_per_second = items_in_batch / train_time_delta if train_time_delta > 0 else 0.0

        # Linear extrapolation of the time spent so far in this epoch.
        updates_per_epoch = self.train_params['updates_per_epoch']
        epoch_remaining = ((datetime.now() - epoch_start_time) / (batch_idx + 1)
                           * (updates_per_epoch - batch_idx - 1))

        logger.info(f"epochs[{epoch:6}] optim steps[{current_optim_steps:.0f}] "
                    f"learning rate[{', '.join([f'{key}: {val:.5e}' for key, val in last_lr.items()])}] "
                    f"train loss[{', '.join([f'{key}: {obj.val:.5f}/{obj.avg:.5f}' for key, obj in self.train_loss.losses.items()])}] "
                    f"items per batch[{items_in_batch}] "
                    f"items per second[{items_per_second:.2f}] "
                    f"total items[{self.train_params['total_batch_size']}] "
                    f"mini batches[{self.train_params['num_updates']:6}] "
                    f"memory[{memory:.0f}] "
                    f"epoch remaining[{str(epoch_remaining).split('.')[0]}]")

    def train(self):
        """Run training from ``start_epoch_idx`` up to ``opt['SOLVER']['MAX_NUM_EPOCHS']``.

        In the epoch equal to ``train_params['start_epoch_idx']``, batches with
        an index below ``start_batch_idx`` are skipped (resumption).

        Progress is logged after an optimizer step when either:
        - the step count is a multiple of ``LOG_EVERY`` (default 100), or
        - it is one of the first ``LOG_FIRST`` (default 10) steps of epoch 0.

        At the last batch of each epoch, a checkpoint is saved (unless
        ``SAVE_CHECKPOINT`` is false) and evaluation runs.
        """
        self.init_train()
        current_optim_steps = self._get_and_validate_current_optim_steps()
        num_epochs = self.opt['SOLVER']['MAX_NUM_EPOCHS']
        log_first = self.opt.get("LOG_FIRST", 10)
        log_every = self.opt.get("LOG_EVERY", 100)

        if self.opt.get('EVAL_AT_START', False):
            self._eval_on_set(self.save_folder)

        train_prev_logged_time = datetime.now()
        for epoch in range(self.train_params['start_epoch_idx'], num_epochs):
            self.train_params['current_epoch_idx'] = epoch
            logger.info(f"Start epoch: {epoch} training.")

            epoch_start_time = datetime.now()
            for batch_idx, batch in enumerate(self.train_dataloaders):
                # When resuming, skip batches already consumed in the resumed epoch.
                resuming_epoch = (self.train_params['current_epoch_idx']
                                  == self.train_params['start_epoch_idx'])
                if resuming_epoch and batch_idx < self.train_params['start_batch_idx']:
                    continue

                self.train_params['current_batch_idx'] = batch_idx
                prev_optim_steps = current_optim_steps
                prev_total_batch_size = self.train_params['total_batch_size']

                self.prev_optim_steps = prev_optim_steps
                self.train_step(batch)
                current_optim_steps = self._get_and_validate_current_optim_steps()

                optimizer_stepped = prev_optim_steps != current_optim_steps
                should_log = (current_optim_steps % log_every == 0
                              or (epoch == 0 and current_optim_steps <= log_first))
                if optimizer_stepped and should_log:
                    now = datetime.now()
                    train_time_delta = (now - train_prev_logged_time).total_seconds()
                    train_prev_logged_time = now
                    self._log_train_progress(epoch, batch_idx, current_optim_steps,
                                             prev_total_batch_size, train_time_delta,
                                             epoch_start_time)

                # End of epoch: optionally checkpoint, then evaluate.
                if batch_idx + 1 == self.train_params['updates_per_epoch']:
                    if self.opt.get('SAVE_CHECKPOINT', True):
                        self.save_checkpoint(self.train_params['num_updates'])
                    self._eval_on_set(self.save_folder)
                    break

            logger.info(f"This epoch takes {datetime.now() - epoch_start_time}")
            logger.info(f"PROGRESS: {100.0 * (epoch + 1) / num_epochs:.2f}%")
            logger.info(f"Config files are at {self.opt['conf_files']}")