import logging
import os
import random
import shutil
import sys

import matplotlib
import numpy as np
import torch.distributed as dist
import torch.utils.data
from pytorch_lightning.loggers import TensorBoardLogger
from torch import nn

import utils
from utils.hparams import hparams, set_hparams
from utils.pl_utils import LatestModelCheckpoint, BaseTrainer, data_loader, DDP

# Non-interactive backend so figures can be rendered without a display.
matplotlib.use('Agg')

# Tensor sharing strategy for worker processes; overridable via TORCH_SHARE_STRATEGY.
torch.multiprocessing.set_sharing_strategy(os.getenv('TORCH_SHARE_STRATEGY', 'file_system'))

log_format = '%(asctime)s %(message)s'
logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                    format=log_format, datefmt='%m/%d %I:%M:%S %p')


class BaseTask(nn.Module):
    """Base class for training tasks.

    Provided here:
        1. *load_ckpt*: load a checkpoint into one of the task's sub-modules;
        2. *training_step*: run one training step, accumulate and log the losses;
        3. *optimizer_step*: step the optimizer, zero the gradients and step the scheduler;
        4. *start*: load the training config, build the trainer (checkpointing and
           TensorBoard logging), copy the speaker map, then train or test;
        5. *configure_ddp* and *init_ddp_connection*: set up distributed training.

    Subclasses should define:
        1. *build_model*, *build_optimizer*, *build_scheduler*: how to build the model,
           the optimizer and the learning-rate scheduler;
        2. *_training_step*: one training step of the model;
        3. *validation_step* and *_validation_end*: run validation on a batch and
           aggregate the validation outputs.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Trainer-facing state; presumably read/updated by BaseTrainer (utils.pl_utils).
        self.current_epoch = 0
        self.global_step = 0
        self.loaded_optimizer_states_dict = {}
        self.trainer = None
        self.logger = None
        self.on_gpu = False
        self.use_dp = False
        self.use_ddp = False
        self.example_input_array = None

        # Batch-size limits. An eval limit of -1 means "reuse the training limit";
        # the resolved value is written back into hparams so other readers see it.
        self.max_tokens = hparams['max_tokens']
        self.max_sentences = hparams['max_sentences']
        self.max_eval_tokens = hparams['max_eval_tokens']
        if self.max_eval_tokens == -1:
            hparams['max_eval_tokens'] = self.max_eval_tokens = self.max_tokens
        self.max_eval_sentences = hparams['max_eval_sentences']
        if self.max_eval_sentences == -1:
            hparams['max_eval_sentences'] = self.max_eval_sentences = self.max_sentences

        self.model = None
        self.training_losses_meter = None

    ###########
    # Training, validation and testing
    ###########
    def build_model(self):
        raise NotImplementedError

    def load_ckpt(self, ckpt_base_dir, current_model_name=None, model_name='model', force=True, strict=True):
        """Load a checkpoint from *ckpt_base_dir* into the attribute named *current_model_name*.

        :param ckpt_base_dir: directory that holds the checkpoint(s).
        :param current_model_name: attribute of this task to load into; defaults to *model_name*.
        :param model_name: fallback attribute name when *current_model_name* is None.
        :param force, strict: forwarded unchanged to ``utils.load_ckpt``.
        """
        if current_model_name is None:
            current_model_name = model_name
        # getattr (not nn.Module.__getattr__) also finds plain, non-Module attributes.
        utils.load_ckpt(getattr(self, current_model_name), ckpt_base_dir, current_model_name, force, strict)

    def on_epoch_start(self):
        # Fresh running averages for every epoch; more keys are added lazily in training_step.
        self.training_losses_meter = {'total_loss': utils.AvgrageMeter()}

    def _training_step(self, sample, batch_idx, optimizer_idx):
        """
        :param sample:
        :param batch_idx:
        :return: total loss: torch.Tensor, loss_log: dict
        """
        raise NotImplementedError

    def training_step(self, sample, batch_idx, optimizer_idx=-1):
        """Run ``_training_step``, update the epoch loss meters and build the log dicts.

        :return: ``{'loss': None}`` when ``_training_step`` returned None (step skipped);
            otherwise a dict with the loss tensor, progress-bar values and
            TensorBoard values (keys prefixed with ``tr/``).
        """
        loss_ret = self._training_step(sample, batch_idx, optimizer_idx)
        self.opt_idx = optimizer_idx
        if loss_ret is None:
            return {'loss': None}
        total_loss, log_outputs = loss_ret
        log_outputs = utils.tensors_to_scalars(log_outputs)
        for k, v in log_outputs.items():
            if k not in self.training_losses_meter:
                self.training_losses_meter[k] = utils.AvgrageMeter()
            # NaN values are kept in the log but excluded from the epoch averages.
            if not np.isnan(v):
                self.training_losses_meter[k].update(v)
        self.training_losses_meter['total_loss'].update(total_loss.item())

        # Best-effort learning-rate logging: the scheduler may be missing/None or may
        # not expose get_lr(). Only ordinary errors are suppressed (not KeyboardInterrupt).
        try:
            log_outputs['lr'] = self.scheduler.get_lr()
            if isinstance(log_outputs['lr'], list):
                log_outputs['lr'] = log_outputs['lr'][0]
        except Exception:
            pass

        progress_bar_log = log_outputs
        tb_log = {f'tr/{k}': v for k, v in log_outputs.items()}
        return {
            'loss': total_loss,
            'progress_bar': progress_bar_log,
            'log': tb_log
        }

    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx):
        """Apply the accumulated gradients, clear them and advance the scheduler.

        The scheduler is stepped with the number of optimizer updates
        (``global_step // accumulate_grad_batches``).
        """
        optimizer.step()
        optimizer.zero_grad()
        # getattr: the scheduler only exists once configure_optimizers() has run.
        scheduler = getattr(self, 'scheduler', None)
        if scheduler is not None:
            scheduler.step(self.global_step // hparams['accumulate_grad_batches'])

    def on_epoch_end(self):
        # Print the per-epoch averages of every tracked loss, rounded to 4 decimals.
        loss_outputs = {k: round(v.avg, 4) for k, v in self.training_losses_meter.items()}
        print(f"\n==============\n "
              f"Epoch {self.current_epoch} ended. Steps: {self.global_step}. {loss_outputs}"
              f"\n==============\n")

    def validation_step(self, sample, batch_idx):
        """
        :param sample:
        :param batch_idx:
        :return: output: dict
        """
        raise NotImplementedError

    def _validation_end(self, outputs):
        """
        :param outputs:
        :return: loss_output: dict
        """
        raise NotImplementedError

    def validation_end(self, outputs):
        """Aggregate validation outputs via ``_validation_end``, print and log them.

        ``_validation_end`` must return a dict containing ``'total_loss'``; it is
        reported as ``val_loss``, which the checkpoint callback in ``start`` monitors.
        """
        loss_output = self._validation_end(outputs)
        print(f"\n==============\n "
              f"valid results: {loss_output}"
              f"\n==============\n")
        return {
            'log': {f'val/{k}': v for k, v in loss_output.items()},
            'val_loss': loss_output['total_loss']
        }

    def build_scheduler(self, optimizer):
        raise NotImplementedError

    def build_optimizer(self, model):
        raise NotImplementedError

    def configure_optimizers(self):
        # Builds the optimizer for self.model and keeps the scheduler on the task,
        # where training_step and optimizer_step use it.
        optm = self.build_optimizer(self.model)
        self.scheduler = self.build_scheduler(optm)
        return [optm]

    def test_start(self):
        pass

    def test_step(self, sample, batch_idx):
        # Testing reuses the validation step by default.
        return self.validation_step(sample, batch_idx)

    def test_end(self, outputs):
        # Testing reuses the validation aggregation by default.
        return self.validation_end(outputs)

    ###########
    # Running configuration
    ###########
    @classmethod
    def start(cls):
        """Entry point: parse hparams, build the trainer, then train or test (``hparams['infer']``)."""
        set_hparams()
        # The port is drawn before seeding, presumably so concurrent runs don't share it.
        os.environ['MASTER_PORT'] = str(random.randint(15000, 30000))
        random.seed(hparams['seed'])
        np.random.seed(hparams['seed'])
        task = cls()
        work_dir = hparams['work_dir']
        trainer = BaseTrainer(
            checkpoint_callback=LatestModelCheckpoint(
                filepath=work_dir,
                verbose=True,
                monitor='val_loss',
                mode='min',
                num_ckpt_keep=hparams['num_ckpt_keep'],
                save_best=hparams['save_best'],
                # A huge period effectively disables periodic checkpointing.
                period=1 if hparams['save_ckpt'] else 100000
            ),
            logger=TensorBoardLogger(
                save_dir=work_dir,
                name='lightning_logs',
                # NOTE: 'lastest' (sic) is kept so existing log directories keep being used.
                version='lastest'
            ),
            gradient_clip_val=hparams['clip_grad_norm'],
            val_check_interval=hparams['val_check_interval'],
            row_log_interval=hparams['log_interval'],
            max_updates=hparams['max_updates'],
            num_sanity_val_steps=hparams['num_sanity_val_steps'] if not hparams['validate'] else 10000,
            accumulate_grad_batches=hparams['accumulate_grad_batches'])
        if not hparams['infer']:  # train
            # Copy spk_map.json to work dir
            spk_map = os.path.join(work_dir, 'spk_map.json')
            spk_map_orig = os.path.join(hparams['binary_data_dir'], 'spk_map.json')
            if not os.path.exists(spk_map) and os.path.exists(spk_map_orig):
                shutil.copy(spk_map_orig, spk_map)
                print(f"| Copied spk map to {spk_map}.")
            trainer.checkpoint_callback.task = task
            trainer.fit(task)
        else:
            trainer.test(task)

    @staticmethod
    def configure_ddp(model, device_ids):
        """Wrap *model* in DDP; silence stdout/stderr on non-zero ranks unless debugging."""
        model = DDP(
            model,
            device_ids=device_ids,
            find_unused_parameters=True
        )
        if dist.get_rank() != 0 and not hparams['debug']:
            # Intentionally left open for the lifetime of the worker process.
            sys.stdout = open(os.devnull, "w")
            sys.stderr = open(os.devnull, "w")
        random.seed(hparams['seed'])
        np.random.seed(hparams['seed'])
        return model

    def training_end(self, *args, **kwargs):
        # No-op hook; always returns None.
        return None

    def init_ddp_connection(self, proc_rank, world_size):
        """Re-load hparams in this process and join the NCCL process group."""
        set_hparams(print_hparams=False)
        # Keep a MASTER_PORT already in the environment (start() sets a random one);
        # otherwise fall back to a fixed default.
        os.environ.setdefault('MASTER_PORT', str(12910))
        # figure out the root node addr
        root_node = self.trainer.resolve_root_node_address('127.0.0.2')
        os.environ['MASTER_ADDR'] = root_node
        dist.init_process_group('nccl', rank=proc_rank, world_size=world_size)

    @data_loader
    def train_dataloader(self):
        return None

    @data_loader
    def test_dataloader(self):
        return None

    @data_loader
    def val_dataloader(self):
        return None

    # Trainer hooks; no-ops by default, overridable by subclasses.
    def on_load_checkpoint(self, checkpoint):
        pass

    def on_save_checkpoint(self, checkpoint):
        pass

    def on_sanity_check_start(self):
        pass

    def on_train_start(self):
        pass

    def on_train_end(self):
        pass

    def on_batch_start(self, batch):
        pass

    def on_batch_end(self):
        pass

    def on_pre_performance_check(self):
        pass

    def on_post_performance_check(self):
        pass

    def on_before_zero_grad(self, optimizer):
        pass

    def on_after_backward(self):
        pass

    @staticmethod
    def backward(loss, optimizer):
        loss.backward()

    def grad_norm(self, norm_type):
        """Return per-parameter and total gradient p-norms, rounded to 3 decimals.

        Parameters that are frozen or received no gradient are skipped. If no
        parameter has a gradient, the total norm is reported as 0.0.

        :param norm_type: the p of the p-norm (e.g. 2).
        :return: dict with keys ``grad_{p}_norm_{name}`` and ``grad_{p}_norm_total``.
        """
        results = {}
        total_norm = 0.
        for name, p in self.named_parameters():
            if not p.requires_grad or p.grad is None:
                continue
            # Tensor.norm(p) already is the p-norm; report it as-is.
            param_norm = p.grad.data.norm(norm_type).item()
            total_norm += param_norm ** norm_type
            results['grad_{}_norm_{}'.format(norm_type, name)] = round(param_norm, 3)
        total_norm = total_norm ** (1. / norm_type)
        results['grad_{}_norm_total'.format(norm_type)] = round(total_norm, 3)
        return results