from typing import Any, Callable, Dict

import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR


class LitSourceSeparation(pl.LightningModule):
    def __init__(
        self,
        batch_data_preprocessor,
        model: nn.Module,
        loss_function: Callable,
        optimizer_type: str,
        learning_rate: float,
        lr_lambda: Callable,
    ):
        r"""PyTorch Lightning wrapper of a PyTorch model, handling the forward
        pass, optimization, etc.

        Args:
            batch_data_preprocessor: object, prepares inputs and targets for
                training. E.g., BasicBatchDataPreprocessor converts a batch
                data dictionary into input and target tensors.
            model: nn.Module
            loss_function: Callable
            optimizer_type: str, e.g., 'Adam'
            learning_rate: float
            lr_lambda: Callable
        """
        super().__init__()

        self.batch_data_preprocessor = batch_data_preprocessor
        self.model = model
        self.optimizer_type = optimizer_type
        self.loss_function = loss_function
        self.learning_rate = learning_rate
        self.lr_lambda = lr_lambda

    def training_step(self, batch_data_dict: Dict, batch_idx: int) -> torch.Tensor:
        r"""Forward a mini-batch through the model, calculate the loss, and
        train for one step. A mini-batch is evenly distributed across multiple
        devices (if any) for parallel training.

        Args:
            batch_data_dict: e.g. {
                'vocals': (batch_size, channels_num, segment_samples),
                'accompaniment': (batch_size, channels_num, segment_samples),
                'mixture': (batch_size, channels_num, segment_samples)
            }
            batch_idx: int

        Returns:
            loss: torch.Tensor, scalar loss of this mini-batch
        """
        input_dict, target_dict = self.batch_data_preprocessor(batch_data_dict)
        # input_dict: {
        #     'waveform': (batch_size, channels_num, segment_samples),
        #     (optional) 'condition': (batch_size, channels_num),
        # }
        # target_dict: {
        #     'waveform': (batch_size, target_sources_num * channels_num, segment_samples),
        # }

        # Forward.
        self.model.train()

        output_dict = self.model(input_dict)
        # output_dict: {
        #     'waveform': (batch_size, target_sources_num * channels_num, segment_samples),
        # }

        outputs = output_dict['waveform']
        # outputs: e.g., (batch_size, target_sources_num * channels_num, segment_samples)

        # Calculate loss.
        loss = self.loss_function(
            output=outputs,
            target=target_dict['waveform'],
            mixture=input_dict['waveform'],
        )

        return loss

    def configure_optimizers(self) -> Any:
        r"""Configure optimizer and learning rate scheduler."""
        if self.optimizer_type == "Adam":
            optimizer = optim.Adam(
                self.model.parameters(),
                lr=self.learning_rate,
                betas=(0.9, 0.999),
                eps=1e-08,
                weight_decay=0.0,
                amsgrad=True,
            )
        elif self.optimizer_type == "AdamW":
            optimizer = optim.AdamW(
                self.model.parameters(),
                lr=self.learning_rate,
                betas=(0.9, 0.999),
                eps=1e-08,
                weight_decay=0.0,
                amsgrad=True,
            )
        else:
            raise NotImplementedError(
                "Optimizer type {} is not supported!".format(self.optimizer_type)
            )

        scheduler = {
            'scheduler': LambdaLR(optimizer, self.lr_lambda),
            'interval': 'step',
            'frequency': 1,
        }

        return [optimizer], [scheduler]
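
# `configure_optimizers` above passes `lr_lambda` directly to LambdaLR with a
# per-step interval, so the callable receives the number of optimizer steps
# taken so far and returns a multiplier on `learning_rate`. A minimal sketch
# of such a callable (the name and the warm-up/decay constants below are
# illustrative assumptions, not values shipped with this module):
def example_lr_lambda(step: int, warm_up_steps: int = 1000, reduce_lr_steps: int = 15000) -> float:
    r"""Illustrative lr_lambda: linearly warm up for `warm_up_steps` steps,
    then decay the multiplier by 0.9 every `reduce_lr_steps` steps."""
    if step < warm_up_steps:
        return step / warm_up_steps
    else:
        return 0.9 ** (step // reduce_lr_steps)
# In practice, bind the constants first (e.g., with functools.partial) so the
# callable takes a single `step` argument, as LambdaLR expects.
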

def get_model_class(model_type):
    r"""Get model class.

    Args:
        model_type: str, e.g., 'ResUNet143_DecouplePlusInplaceABN_ISMIR2021'

    Returns:
        nn.Module class
    """
    if model_type == 'ResUNet143_DecouplePlusInplaceABN_ISMIR2021':
        from bytesep.models.resunet_ismir2021 import (
            ResUNet143_DecouplePlusInplaceABN_ISMIR2021,
        )

        return ResUNet143_DecouplePlusInplaceABN_ISMIR2021

    elif model_type == 'UNet':
        from bytesep.models.unet import UNet

        return UNet

    elif model_type == 'UNetSubbandTime':
        from bytesep.models.unet_subbandtime import UNetSubbandTime

        return UNetSubbandTime

    elif model_type == 'ResUNet143_Subbandtime':
        from bytesep.models.resunet_subbandtime import ResUNet143_Subbandtime

        return ResUNet143_Subbandtime

    elif model_type == 'ResUNet143_DecouplePlus':
        from bytesep.models.resunet import ResUNet143_DecouplePlus

        return ResUNet143_DecouplePlus

    elif model_type == 'ConditionalUNet':
        from bytesep.models.conditional_unet import ConditionalUNet

        return ConditionalUNet

    elif model_type == 'LevelRNN':
        from bytesep.models.levelrnn import LevelRNN

        return LevelRNN

    elif model_type == 'WavUNet':
        from bytesep.models.wavunet import WavUNet

        return WavUNet

    elif model_type == 'WavUNetLevelRNN':
        from bytesep.models.wavunet_levelrnn import WavUNetLevelRNN

        return WavUNetLevelRNN

    elif model_type == 'TTnet':
        from bytesep.models.ttnet import TTnet

        return TTnet

    elif model_type == 'TTnetNoTransformer':
        from bytesep.models.ttnet_no_transformer import TTnetNoTransformer

        return TTnetNoTransformer

    else:
        raise NotImplementedError(
            "Model type {} is not supported!".format(model_type)
        )
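
# Illustrative usage sketch (not part of this module): wiring `get_model_class`
# and `LitSourceSeparation` together. The constructor arguments, preprocessor,
# loss function, and trainer arguments below are assumptions for demonstration.
#
#     Model = get_model_class(model_type='UNet')
#     model = Model(...)  # constructor arguments depend on the chosen model
#
#     lit_model = LitSourceSeparation(
#         batch_data_preprocessor=batch_data_preprocessor,  # e.g., BasicBatchDataPreprocessor
#         model=model,
#         loss_function=loss_function,  # e.g., an L1 loss on waveforms
#         optimizer_type='Adam',
#         learning_rate=1e-3,
#         lr_lambda=example_lr_lambda,
#     )
#
#     trainer = pl.Trainer(gpus=1)
#     trainer.fit(lit_model, train_dataloader)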