Music_Source_Separation / bytesep /models /lightning_modules.py
akhaliq3
spaces demo
5019931
raw
history blame
5.59 kB
from typing import Any, Callable, Dict
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR
class LitSourceSeparation(pl.LightningModule):
    r"""PyTorch Lightning module that trains a source-separation model.

    Bundles a plain ``nn.Module`` with a batch preprocessor, a loss function
    and the optimizer / learning-rate-schedule configuration.
    """

    def __init__(
        self,
        batch_data_preprocessor,
        model: nn.Module,
        loss_function: Callable,
        optimizer_type: str,
        learning_rate: float,
        lr_lambda: Callable,
    ):
        r"""Pytorch Lightning wrapper of PyTorch model, including forward,
        optimization of model, etc.

        Args:
            batch_data_preprocessor: callable object used for preparing inputs
                and targets for training. It is called with the raw batch dict
                and must return ``(input_dict, target_dict)``. E.g.,
                BasicBatchDataPreprocessor is used for preparing data in
                dictionary into tensor.
            model: nn.Module, called with ``input_dict``; must return a dict
                containing a ``'waveform'`` entry.
            loss_function: callable invoked with keyword arguments ``output``,
                ``target`` and ``mixture``.
            optimizer_type: str, either ``"Adam"`` or ``"AdamW"``.
            learning_rate: float, base learning rate of the optimizer.
            lr_lambda: function passed to ``LambdaLR``; maps a step count to a
                multiplicative factor on ``learning_rate``.
        """
        super().__init__()
        self.batch_data_preprocessor = batch_data_preprocessor
        self.model = model
        self.optimizer_type = optimizer_type
        self.loss_function = loss_function
        self.learning_rate = learning_rate
        self.lr_lambda = lr_lambda

    def training_step(self, batch_data_dict: Dict, batch_idx: int) -> torch.Tensor:
        r"""Forward a mini-batch data to model, calculate loss function, and
        train for one step. A mini-batch data is evenly distributed to multiple
        devices (if there are) for parallel training.

        Args:
            batch_data_dict: e.g. {
                'vocals': (batch_size, channels_num, segment_samples),
                'accompaniment': (batch_size, channels_num, segment_samples),
                'mixture': (batch_size, channels_num, segment_samples)
            }
            batch_idx: int, index of this mini-batch (not used here).

        Returns:
            loss: the value returned by ``loss_function`` for this mini-batch
                (presumably a scalar tensor, which Lightning backpropagates).
        """
        input_dict, target_dict = self.batch_data_preprocessor(batch_data_dict)
        # input_dict: {
        #     'waveform': (batch_size, channels_num, segment_samples),
        #     (if_exist) 'condition': (batch_size, channels_num),
        # }
        # target_dict: {
        #     'waveform': (batch_size, target_sources_num * channels_num, segment_samples),
        # }

        # Forward. Switch the wrapped model to training mode explicitly before
        # the forward pass.
        self.model.train()
        output_dict = self.model(input_dict)
        # output_dict: {
        #     'waveform': (batch_size, target_sources_num * channels_num, segment_samples),
        # }

        outputs = output_dict['waveform']
        # outputs:, e.g, (batch_size, target_sources_num * channels_num, segment_samples)

        # Calculate loss. The input mixture is passed too, so loss functions
        # that need it (e.g. to reconstruct residuals) can use it.
        loss = self.loss_function(
            output=outputs,
            target=target_dict['waveform'],
            mixture=input_dict['waveform'],
        )

        return loss

    def configure_optimizers(self) -> Any:
        r"""Configure optimizer and a per-step ``LambdaLR`` scheduler.

        Returns:
            ``([optimizer], [scheduler_config])`` in the format accepted by
            PyTorch Lightning.

        Raises:
            NotImplementedError: if ``optimizer_type`` is not supported.
        """
        # Both optimizers share identical hyper-parameters. With
        # weight_decay=0.0 the decoupled decay term of AdamW is zero.
        optimizer_classes = {"Adam": optim.Adam, "AdamW": optim.AdamW}

        optimizer_class = optimizer_classes.get(self.optimizer_type)
        if optimizer_class is None:
            raise NotImplementedError(
                f"Unsupported optimizer_type {self.optimizer_type!r}; "
                f"expected one of {sorted(optimizer_classes)}."
            )

        optimizer = optimizer_class(
            self.model.parameters(),
            lr=self.learning_rate,
            betas=(0.9, 0.999),
            eps=1e-08,
            weight_decay=0.0,
            amsgrad=True,
        )

        # 'interval': 'step' tells Lightning to step the scheduler after every
        # optimizer step rather than once per epoch.
        scheduler = {
            'scheduler': LambdaLR(optimizer, self.lr_lambda),
            'interval': 'step',
            'frequency': 1,
        }

        return [optimizer], [scheduler]
def get_model_class(model_type: str) -> type:
    r"""Get model class by name.

    The defining module is imported lazily, so only the dependencies of the
    requested model are loaded.

    Args:
        model_type: str, e.g., 'ResUNet143_DecouplePlusInplaceABN_ISMIR2021'

    Returns:
        The model class itself (not an instance); presumably an nn.Module
        subclass.

    Raises:
        NotImplementedError: if ``model_type`` is not a known model name.
    """
    import importlib

    # model_type -> (module path, class name inside that module).
    registry = {
        'ResUNet143_DecouplePlusInplaceABN_ISMIR2021': (
            'bytesep.models.resunet_ismir2021',
            'ResUNet143_DecouplePlusInplaceABN_ISMIR2021',
        ),
        'UNet': ('bytesep.models.unet', 'UNet'),
        'UNetSubbandTime': ('bytesep.models.unet_subbandtime', 'UNetSubbandTime'),
        'ResUNet143_Subbandtime': (
            'bytesep.models.resunet_subbandtime',
            'ResUNet143_Subbandtime',
        ),
        'ResUNet143_DecouplePlus': ('bytesep.models.resunet', 'ResUNet143_DecouplePlus'),
        'ConditionalUNet': ('bytesep.models.conditional_unet', 'ConditionalUNet'),
        'LevelRNN': ('bytesep.models.levelrnn', 'LevelRNN'),
        'WavUNet': ('bytesep.models.wavunet', 'WavUNet'),
        'WavUNetLevelRNN': ('bytesep.models.wavunet_levelrnn', 'WavUNetLevelRNN'),
        'TTnet': ('bytesep.models.ttnet', 'TTnet'),
        'TTnetNoTransformer': (
            'bytesep.models.ttnet_no_transformer',
            'TTnetNoTransformer',
        ),
    }

    try:
        module_path, class_name = registry[model_type]
    except KeyError:
        raise NotImplementedError(
            f"Unknown model_type {model_type!r}; "
            f"supported: {', '.join(sorted(registry))}"
        ) from None

    module = importlib.import_module(module_path)
    return getattr(module, class_name)