from datetime import datetime
import time
import os
import sys
import importlib
import json
import random

import logging
import numpy as np
import copy
import contextlib
import shutil
from typing import Any, Callable, Union
import torch
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from mpi4py import MPI
from infinibatch import iterators

from .distributed_trainer import DistributedTrainer
from .utils_trainer import UtilsTrainer
from .utils.misc import *
from .utils.serialization import JSONEncoder, filter_jsonable

logger = logging.getLogger(__name__)


class DefaultTrainer(UtilsTrainer, DistributedTrainer):
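    """
    Default training/evaluation driver.

    The constructor imports the project code base referenced by
    ``opt['base_path']`` and instantiates the pipeline named by
    ``opt['PIPELINE']``; ``train()`` and ``eval()`` then drive that pipeline
    with optional FP16 autocast, gradient accumulation, and multi-rank
    (distributed) execution. Optimizers, LR schedulers, the gradient scaler,
    and checkpoint I/O are expected to be supplied by the ``UtilsTrainer``
    and ``DistributedTrainer`` mixins.
    """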

    def __init__(self, opt):
        """
        Set up the task the model is being trained for.
        """
        super().__init__(opt)
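        # Register the project's base directory as an importable module named
        # 'base_dir' so the pipeline package can be imported by name below.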
        base_name = 'base_dir'
        base_path = os.path.join(self.opt['base_path'], '__init__.py')
        spec = importlib.util.spec_from_file_location(base_name, base_path)
        module = importlib.util.module_from_spec(spec)
        sys.modules[base_name] = module
        spec.loader.exec_module(module)
        logger.info(f"Imported {base_name} at base_path {self.opt['base_path']}")

        pipeline_module = importlib.import_module(f"base_dir.pipeline.{self.opt['PIPELINE']}")
        pipeline_class = getattr(pipeline_module, self.opt['PIPELINE'])
        logger.info(f"Pipeline for training: {self.opt['PIPELINE']}")
        self.pipeline = pipeline_class(self.opt)

    def eval(self):
        logger.info('-----------------------------------------------')
        logger.info("Evaluating model ... ")
        self.mode = "eval"

        self.raw_models = self.pipeline.initialize_model()
        self.model_names = self.raw_models.keys()

        for module_name in self.model_names:
            self.raw_models[module_name].to(self.opt['device'])
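
        # Evaluation requires pretrained weights: WEIGHT must be enabled and
        # RESUME_FROM must point to an existing checkpoint file.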
        model_path = self.opt['RESUME_FROM']
        if self.opt['WEIGHT'] and os.path.isfile(model_path):
            self.load_model(model_path)
        else:
            raise ValueError(f"Model not found: {model_path}")

        results = self._eval_on_set(self.save_folder)
        return results

    def _eval_on_set(self, save_folder):
        logger.info("Evaluation start ...")
        if self.opt['FP16']:
            from torch.cuda.amp import autocast
            with autocast():
                results = self.pipeline.evaluate_model(self, save_folder)
        else:
            results = self.pipeline.evaluate_model(self, save_folder)
        if self.opt['rank'] == 0:
            logger.info(results)
        return results

    def compute_loss(self, forward_func, batch):
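        """
        Run ``forward_func(self, batch)`` and return its loss, under
        ``torch.cuda.amp.autocast`` when FP16 is enabled.
        """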

        def forward(func, trainer, batch):
            if self.opt['FP16']:
                from torch.cuda.amp import autocast
                with autocast():
                    loss = func(trainer, batch)
            else:
                loss = func(trainer, batch)
            return loss

        loss = forward(forward_func, self, batch)
        return loss

    def backward_loss(self, loss, model_names=['default']):

        def backward(loss_tensor):
            if self.opt['FP16']:
                self.grad_scaler.scale(loss_tensor).backward()
            else:
                loss_tensor.backward()
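
        # When accumulating gradients over several batches, scale the loss so
        # the accumulated gradient is an average rather than a sum.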
        if self.grad_acc_steps > 1:
            loss = loss / self.grad_acc_steps

        backward(loss)
        return loss

    def update_model(self, model_name='default'):
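        # Step the optimizer (through the GradScaler when FP16 is enabled,
        # unscaling gradients first), then reset gradients and advance the
        # per-model optimizer-step counter and LR scheduler.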
        if self.opt['FP16']:
            self.grad_scaler.unscale_(self.optimizers[model_name])
            self.grad_scaler.step(self.optimizers[model_name])
        else:
            self.optimizers[model_name].step()

        self.optimizers[model_name].zero_grad()
        self.train_params['optim_steps'][model_name] += 1
        self.lr_schedulers[model_name].step()

    def train_step(self, batch):
        self.grad_acc_batches.append(batch)
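
        # Batches are buffered until a gradient-accumulation boundary is
        # reached; only then do the buffered batches form one effective batch
        # that is run through the pipeline's forward_step.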

        if self.is_gradient_accumulation_boundary():

            for model_name in self.model_names:
                self.models[model_name].train()

            assert len(self.grad_acc_batches) == self.grad_acc_steps

            total_batch_sample = 0
            for batch_index, batch in enumerate(self.grad_acc_batches):

                loss_info, sample_size_info, extra_info = \
                    self.pipeline.forward_step(self,
                                               batch,
                                               self.grad_acc_batches,
                                               batch_index,
                                               is_distributed=(self.opt['world_size'] > 1))

                self.train_loss.update_iter(loss_info)
                total_batch_sample += sample_size_info['num_samples']

            if self.opt['FP16']:
                self.grad_scaler.update()
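
            # In distributed runs, sum the per-rank sample counts so the
            # logged total reflects the whole effective batch.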
            if self.opt['world_size'] > 1:
                total_batch_sample = torch.tensor(total_batch_sample).to(self.opt['device'])
                torch.distributed.all_reduce(total_batch_sample, torch.distributed.ReduceOp.SUM)
                total_batch_sample = total_batch_sample.item()

            self.train_params['total_batch_size'] += total_batch_sample
            self.grad_acc_batches = []

        self.train_params['num_updates'] += 1

    def init_train(self):
        self.mode = "train"
        logger.info('-------------------------------------------------------')
        logger.info("Training on rank: {}".format(self.opt['rank']))

        self.raw_models = self.pipeline.initialize_model()
        self.model_names = list(self.raw_models.keys())

        for module_name in self.model_names:
            self.raw_models[module_name].to(self.opt['device'])

        self.train_dataloaders = self.pipeline.get_dataloaders(self, 'train', is_evaluation=False)
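
        # Bookkeeping for the training loop: progress counters, per-model
        # optimizer-step counts, and the epoch/batch indices used to resume
        # from a checkpoint.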
        self.train_params = {
            "updates_per_epoch": len(self.train_dataloaders),
            "total_batch_size": 0,
            "num_updates": 0,
            "optim_steps": {module_name: 0 for module_name in self.model_names},
            "start_epoch_idx": 0,
            "start_batch_idx": 0,
            "current_epoch_idx": 0,
            "current_batch_idx": 0,
            "resume_epoch_idx": 0,
        }

        self.train_loss = LossMeter()
        self.grad_acc_batches = []

        if self.opt['CUDA']:
            torch.cuda.empty_cache()

        self.create_optimizer_and_scheduler()
        self.models = {model_name: self.raw_models[model_name] for model_name in self.model_names}
        self._initialize_ddp()

        if self.opt.get('WEIGHT', False):
            self.load_weight(self.opt['RESUME_FROM'], must_exist=True)
        if self.opt.get('RESUME', False):
            self.load_checkpoint(self.opt['RESUME_FROM'], must_exist=True)

        if self.opt['rank'] == 0:
            logger.info("***** Running training *****")
            logger.info(f" Num of GPUs = {self.opt['world_size']}")
            logger.info(f" Num Epochs = {self.opt['SOLVER']['MAX_NUM_EPOCHS']}")
            logger.info(f" Num of Mini Batches per Epoch = {self.train_params['updates_per_epoch']}")
            logger.info(f" Total train mini batches (w. parallel, distributed & accumulation) = {self.opt['SOLVER']['MAX_NUM_EPOCHS'] * self.train_params['updates_per_epoch']}")
            logger.info(f" Gradient Accumulation steps = {self.grad_acc_steps}")
            logger.info(f" Total optimization steps = {self.opt['SOLVER']['MAX_NUM_EPOCHS'] * self.train_params['updates_per_epoch'] // self.grad_acc_steps}")

    def train(self):
        """
        Main training loop.
        """
        self.init_train()
        current_optim_steps = self._get_and_validate_current_optim_steps()
        num_epochs = self.opt['SOLVER']['MAX_NUM_EPOCHS']

        if self.opt.get('EVAL_AT_START', False):
            results = self._eval_on_set(self.save_folder)

        train_prev_logged_time = datetime.now()
        for epoch in range(self.train_params['start_epoch_idx'], num_epochs):
            self.train_params['current_epoch_idx'] = epoch
            logger.info(f"Start epoch: {epoch} training.")

            epoch_start_time = datetime.now()
            for batch_idx, batch in enumerate(self.train_dataloaders):
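                # When resuming mid-epoch, fast-forward past the batches that
                # were already consumed before the checkpoint was written.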
                if self.train_params['current_epoch_idx'] == self.train_params['start_epoch_idx']:
                    if batch_idx < self.train_params['start_batch_idx']:
                        continue

                self.train_params['current_batch_idx'] = batch_idx
                prev_optim_steps = current_optim_steps
                prev_total_batch_size = self.train_params['total_batch_size']

                self.prev_optim_steps = prev_optim_steps
                self.train_step(batch)

                current_optim_steps = self._get_and_validate_current_optim_steps()
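
                # Log every LOG_EVERY optimizer steps, plus the first
                # LOG_FIRST steps of the first epoch.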
                if prev_optim_steps != current_optim_steps:
                    log_first = self.opt.get("LOG_FIRST", 10)
                    log_every = self.opt.get("LOG_EVERY", 100)
                    if (current_optim_steps % log_every == 0) or (epoch == 0 and current_optim_steps <= log_first):

                        last_lr = {}
                        for module_name in self.model_names:
                            last_lr[module_name] = self.lr_schedulers[module_name].get_last_lr()[0]

                        train_time_delta = (datetime.now() - train_prev_logged_time).total_seconds()
                        train_prev_logged_time = datetime.now()
                        MB = 1024.0 * 1024.0
                        memory = torch.cuda.max_memory_allocated() / MB

                        if self.opt['rank'] == 0:
                            logger.info(f"epochs[{epoch:6}] optim steps[{current_optim_steps:.0f}] "
                                        f"learning rate[{', '.join([f'{key}: {val:.5e}' for key, val in last_lr.items()])}] "
                                        f"train loss[{', '.join([f'{key}: {obj.val:.5f}/{obj.avg:.5f}' for key, obj in self.train_loss.losses.items()])}] "
                                        f"items per batch[{self.train_params['total_batch_size'] - prev_total_batch_size}] "
                                        f"items per second[{(self.train_params['total_batch_size'] - prev_total_batch_size) / train_time_delta:.2f}] "
                                        f"total items[{self.train_params['total_batch_size']}] "
                                        f"mini batches[{self.train_params['num_updates']:6}] "
                                        f"memory[{memory:.0f}] "
                                        f"epoch remaining[{str((datetime.now() - epoch_start_time) / (batch_idx + 1) * (self.train_params['updates_per_epoch'] - batch_idx - 1)).split('.')[0]}]")
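
                # End of epoch: optionally save a checkpoint, run evaluation,
                # and move on to the next epoch.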
                if batch_idx + 1 == self.train_params['updates_per_epoch']:
                    if self.opt.get('SAVE_CHECKPOINT', True):
                        self.save_checkpoint(self.train_params['num_updates'])
                    results = self._eval_on_set(self.save_folder)
                    break

            logger.info(f"This epoch takes {datetime.now() - epoch_start_time}")
            logger.info(f"PROGRESS: {100.0 * (epoch + 1) / num_epochs:.2f}%")
            logger.info(f"Config files are at {self.opt['conf_files']}")