import os import sys import time import math import yaml import torch import random import numpy as np from tqdm import tqdm from pprint import pprint from torch.utils import data import torch.distributed as dist from torch.cuda.amp import autocast, GradScaler from tensorboardX import SummaryWriter from dataset import load_dataset from loss import get_loss from model import load_model from optimizer import get_optimizer from scheduler import get_scheduler from trainer import AbstractTrainer, LEGAL_METRIC from trainer.utils import exp_recons_loss, MLLoss, reduce_tensor, center_print from trainer.utils import MODELS_PATH, AccMeter, AUCMeter, AverageMeter, Logger, Timer class ExpMultiGpuTrainer(AbstractTrainer): def __init__(self, config, stage="Train"): super(ExpMultiGpuTrainer, self).__init__(config, stage) np.random.seed(2021) def _mprint(self, content=""): if self.local_rank == 0: print(content) def _initiated_settings(self, model_cfg=None, data_cfg=None, config_cfg=None): self.local_rank = config_cfg["local_rank"] def _train_settings(self, model_cfg, data_cfg, config_cfg): # debug mode: no log dir, no train_val operation. self.debug = config_cfg["debug"] self._mprint(f"Using debug mode: {self.debug}.") self._mprint("*" * 20) self.eval_metric = config_cfg["metric"] if self.eval_metric not in LEGAL_METRIC: raise ValueError(f"Evaluation metric must be in {LEGAL_METRIC}, but found " f"{self.eval_metric}.") if self.eval_metric == LEGAL_METRIC[-1]: self.best_metric = 1.0e8 # distribution dist.init_process_group(config_cfg["distribute"]["backend"]) # load training dataset train_dataset = data_cfg["file"] branch = data_cfg["train_branch"] name = data_cfg["name"] with open(train_dataset, "r") as f: options = yaml.load(f, Loader=yaml.FullLoader) train_options = options[branch] self.train_set = load_dataset(name)(train_options) # define training sampler self.train_sampler = data.distributed.DistributedSampler(self.train_set) # wrapped with data loader self.train_loader = data.DataLoader(self.train_set, shuffle=False, sampler=self.train_sampler, num_workers=data_cfg.get("num_workers", 4), batch_size=data_cfg["train_batch_size"]) if self.local_rank == 0: # load validation dataset val_options = options[data_cfg["val_branch"]] self.val_set = load_dataset(name)(val_options) # wrapped with data loader self.val_loader = data.DataLoader(self.val_set, shuffle=True, num_workers=data_cfg.get("num_workers", 4), batch_size=data_cfg["val_batch_size"]) self.resume = config_cfg.get("resume", False) if not self.debug: time_format = "%Y-%m-%d...%H.%M.%S" run_id = time.strftime(time_format, time.localtime(time.time())) self.run_id = config_cfg.get("id", run_id) self.dir = os.path.join("runs", self.model_name, self.run_id) if self.local_rank == 0: if not self.resume: if os.path.exists(self.dir): raise ValueError("Error: given id '%s' already exists." % self.run_id) os.makedirs(self.dir, exist_ok=True) print(f"Writing config file to file directory: {self.dir}.") yaml.dump({"config": self.config, "train_data": train_options, "val_data": val_options}, open(os.path.join(self.dir, 'train_config.yml'), 'w')) # copy the script for the training model model_file = MODELS_PATH[self.model_name] os.system("cp " + model_file + " " + self.dir) else: print(f"Resuming the history in file directory: {self.dir}.") print(f"Logging directory: {self.dir}.") # redirect the std out stream sys.stdout = Logger(os.path.join(self.dir, 'records.txt')) center_print('Train configurations begins.') pprint(self.config) pprint(train_options) pprint(val_options) center_print('Train configurations ends.') # load model self.num_classes = model_cfg["num_classes"] self.device = "cuda:" + str(self.local_rank) self.model = load_model(self.model_name)(**model_cfg) self.model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.model).to(self.device) self._mprint(f"Using SyncBatchNorm.") self.model = torch.nn.parallel.DistributedDataParallel( self.model, device_ids=[self.local_rank], find_unused_parameters=True) # load optimizer optim_cfg = config_cfg.get("optimizer", None) optim_name = optim_cfg.pop("name") self.optimizer = get_optimizer(optim_name)(self.model.parameters(), **optim_cfg) # load scheduler self.scheduler = get_scheduler(self.optimizer, config_cfg.get("scheduler", None)) # load loss self.loss_criterion = get_loss(config_cfg.get("loss", None), device=self.device) # total number of steps (or epoch) to train self.num_steps = train_options["num_steps"] self.num_epoch = math.ceil(self.num_steps / len(self.train_loader)) # the number of steps to write down a log self.log_steps = train_options["log_steps"] # the number of steps to validate on val dataset once self.val_steps = train_options["val_steps"] # balance coefficients self.lambda_1 = config_cfg["lambda_1"] self.lambda_2 = config_cfg["lambda_2"] self.warmup_step = config_cfg.get('warmup_step', 0) self.contra_loss = MLLoss() self.acc_meter = AccMeter() self.loss_meter = AverageMeter() self.recons_loss_meter = AverageMeter() self.contra_loss_meter = AverageMeter() if self.resume and self.local_rank == 0: self._load_ckpt(best=config_cfg.get("resume_best", False), train=True) def _test_settings(self, model_cfg, data_cfg, config_cfg): # Not used. raise NotImplementedError("The function is not intended to be used here.") def _load_ckpt(self, best=False, train=False): # Not used. raise NotImplementedError("The function is not intended to be used here.") def _save_ckpt(self, step, best=False): save_dir = os.path.join(self.dir, f"best_model_{step}.bin" if best else "latest_model.bin") torch.save({ "step": step, "best_step": self.best_step, "best_metric": self.best_metric, "eval_metric": self.eval_metric, "model": self.model.module.state_dict(), "optimizer": self.optimizer.state_dict(), "scheduler": self.scheduler.state_dict(), }, save_dir) def train(self): try: timer = Timer() grad_scalar = GradScaler(2 ** 10) if self.local_rank == 0: writer = None if self.debug else SummaryWriter(log_dir=self.dir) center_print("Training begins......") else: writer = None start_epoch = self.start_step // len(self.train_loader) + 1 for epoch_idx in range(start_epoch, self.num_epoch + 1): # set sampler self.train_sampler.set_epoch(epoch_idx) # reset meter self.acc_meter.reset() self.loss_meter.reset() self.recons_loss_meter.reset() self.contra_loss_meter.reset() self.optimizer.step() train_generator = enumerate(self.train_loader, 1) # wrap train generator with tqdm for process 0 if self.local_rank == 0: train_generator = tqdm(train_generator, position=0, leave=True) for batch_idx, train_data in train_generator: global_step = (epoch_idx - 1) * len(self.train_loader) + batch_idx self.model.train() I, Y = train_data I = self.train_loader.dataset.load_item(I) in_I, Y = self.to_device((I, Y)) # warm-up lr if self.warmup_step != 0 and global_step <= self.warmup_step: lr = self.config['config']['optimizer']['lr'] * float(global_step) / self.warmup_step for param_group in self.optimizer.param_groups: param_group['lr'] = lr self.optimizer.zero_grad() with autocast(): Y_pre = self.model(in_I) # for BCE Setting: if self.num_classes == 1: Y_pre = Y_pre.squeeze() loss = self.loss_criterion(Y_pre, Y.float()) Y_pre = torch.sigmoid(Y_pre) else: loss = self.loss_criterion(Y_pre, Y) # flood loss = (loss - 0.04).abs() + 0.04 recons_loss = exp_recons_loss(self.model.module.loss_inputs['recons'], (in_I, Y)) contra_loss = self.contra_loss(self.model.module.loss_inputs['contra'], Y) loss += self.lambda_1 * recons_loss + self.lambda_2 * contra_loss grad_scalar.scale(loss).backward() grad_scalar.step(self.optimizer) grad_scalar.update() if self.warmup_step == 0 or global_step > self.warmup_step: self.scheduler.step() self.acc_meter.update(Y_pre, Y, self.num_classes == 1) self.loss_meter.update(reduce_tensor(loss).item()) self.recons_loss_meter.update(reduce_tensor(recons_loss).item()) self.contra_loss_meter.update(reduce_tensor(contra_loss).item()) iter_acc = reduce_tensor(self.acc_meter.mean_acc()).item() if self.local_rank == 0: if global_step % self.log_steps == 0 and writer is not None: writer.add_scalar("train/Acc", iter_acc, global_step) writer.add_scalar("train/Loss", self.loss_meter.avg, global_step) writer.add_scalar("train/Recons_Loss", self.recons_loss_meter.avg if self.lambda_1 != 0 else 0., global_step) writer.add_scalar("train/Contra_Loss", self.contra_loss_meter.avg, global_step) writer.add_scalar("train/LR", self.scheduler.get_last_lr()[0], global_step) # log training step train_generator.set_description( "Train Epoch %d (%d/%d), Global Step %d, Loss %.4f, Recons %.4f, con %.4f, " "ACC %.4f, LR %.6f" % ( epoch_idx, batch_idx, len(self.train_loader), global_step, self.loss_meter.avg, self.recons_loss_meter.avg, self.contra_loss_meter.avg, iter_acc, self.scheduler.get_last_lr()[0]) ) # validating process if global_step % self.val_steps == 0 and not self.debug: print() self.validate(epoch_idx, global_step, timer, writer) # when num_steps has been set and the training process will # be stopped earlier than the specified num_epochs, then stop. if self.num_steps is not None and global_step == self.num_steps: if writer is not None: writer.close() if self.local_rank == 0: print() center_print("Training process ends.") dist.destroy_process_group() return # close the tqdm bar when one epoch ends if self.local_rank == 0: train_generator.close() print() # training ends with integer epochs if self.local_rank == 0: if writer is not None: writer.close() center_print("Training process ends.") dist.destroy_process_group() except Exception as e: dist.destroy_process_group() raise e def validate(self, epoch, step, timer, writer): v_idx = random.randint(1, len(self.val_loader) + 1) categories = self.val_loader.dataset.categories self.model.eval() with torch.no_grad(): acc = AccMeter() auc = AUCMeter() loss_meter = AverageMeter() cur_acc = 0.0 # Higher is better cur_auc = 0.0 # Higher is better cur_loss = 1e8 # Lower is better val_generator = tqdm(enumerate(self.val_loader, 1), position=0, leave=True) for val_idx, val_data in val_generator: I, Y = val_data I = self.val_loader.dataset.load_item(I) in_I, Y = self.to_device((I, Y)) Y_pre = self.model(in_I) # for BCE Setting: if self.num_classes == 1: Y_pre = Y_pre.squeeze() loss = self.loss_criterion(Y_pre, Y.float()) Y_pre = torch.sigmoid(Y_pre) else: loss = self.loss_criterion(Y_pre, Y) acc.update(Y_pre, Y, self.num_classes == 1) auc.update(Y_pre, Y, self.num_classes == 1) loss_meter.update(loss.item()) cur_acc = acc.mean_acc() cur_loss = loss_meter.avg val_generator.set_description( "Eval Epoch %d (%d/%d), Global Step %d, Loss %.4f, ACC %.4f" % ( epoch, val_idx, len(self.val_loader), step, cur_loss, cur_acc) ) if val_idx == v_idx or val_idx == 1: sample_recons = list() for _ in self.model.module.loss_inputs['recons']: sample_recons.append(_[:4].to("cpu")) # show images images = I[:4] images = torch.cat([images, *sample_recons], dim=0) pred = Y_pre[:4] gt = Y[:4] figure = self.plot_figure(images, pred, gt, 4, categories, show=False) cur_auc = auc.mean_auc() print("Eval Epoch %d, Loss %.4f, ACC %.4f, AUC %.4f" % (epoch, cur_loss, cur_acc, cur_auc)) if writer is not None: writer.add_scalar("val/Loss", cur_loss, step) writer.add_scalar("val/Acc", cur_acc, step) writer.add_scalar("val/AUC", cur_auc, step) writer.add_figure("val/Figures", figure, step) # record the best acc and the corresponding step if self.eval_metric == 'Acc' and cur_acc >= self.best_metric: self.best_metric = cur_acc self.best_step = step self._save_ckpt(step, best=True) elif self.eval_metric == 'AUC' and cur_auc >= self.best_metric: self.best_metric = cur_auc self.best_step = step self._save_ckpt(step, best=True) elif self.eval_metric == 'LogLoss' and cur_loss <= self.best_metric: self.best_metric = cur_loss self.best_step = step self._save_ckpt(step, best=True) print("Best Step %d, Best %s %.4f, Running Time: %s, Estimated Time: %s" % ( self.best_step, self.eval_metric, self.best_metric, timer.measure(), timer.measure(step / self.num_steps) )) self._save_ckpt(step, best=False) def test(self): # Not used. raise NotImplementedError("The function is not intended to be used here.")