from typing import Any, Callable, Dict

import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR


class LitSourceSeparation(pl.LightningModule):
    def __init__(
        self,
        batch_data_preprocessor,
        model: nn.Module,
        loss_function: Callable,
        optimizer_type: str,
        learning_rate: float,
        lr_lambda: Callable,
    ):
        r"""PyTorch Lightning wrapper of a PyTorch model, handling the forward
        pass, optimization of the model, etc.

        Args:
            batch_data_preprocessor: object, prepares inputs and targets for
                training. E.g., BasicBatchDataPreprocessor converts a batch
                data dictionary into input and target tensors.
            model: nn.Module
            loss_function: function
            optimizer_type: str, e.g., 'Adam' or 'AdamW'
            learning_rate: float
            lr_lambda: function, maps the global step to a learning rate
                multiplier for LambdaLR
        """
        super().__init__()

        self.batch_data_preprocessor = batch_data_preprocessor
        self.model = model
        self.optimizer_type = optimizer_type
        self.loss_function = loss_function
        self.learning_rate = learning_rate
        self.lr_lambda = lr_lambda

    def training_step(self, batch_data_dict: Dict, batch_idx: int) -> torch.Tensor:
        r"""Forward a mini-batch through the model, calculate the loss, and
        train for one step. A mini-batch is evenly distributed to multiple
        devices (if there are any) for parallel training.

        Args:
            batch_data_dict: e.g. {
                'vocals': (batch_size, channels_num, segment_samples),
                'accompaniment': (batch_size, channels_num, segment_samples),
                'mixture': (batch_size, channels_num, segment_samples)
            }
            batch_idx: int

        Returns:
            loss: torch.Tensor, scalar loss of this mini-batch
        """
        input_dict, target_dict = self.batch_data_preprocessor(batch_data_dict)
        # input_dict: {
        #     'waveform': (batch_size, channels_num, segment_samples),
        #     (if exists) 'condition': (batch_size, channels_num),
        # }
        # target_dict: {
        #     'waveform': (batch_size, target_sources_num * channels_num, segment_samples),
        # }

        # Forward.
        self.model.train()
        output_dict = self.model(input_dict)
        # output_dict: {
        #     'waveform': (batch_size, target_sources_num * channels_num, segment_samples),
        # }

        outputs = output_dict['waveform']
        # outputs: (batch_size, target_sources_num * channels_num, segment_samples)

        # Calculate loss.
        loss = self.loss_function(
            output=outputs,
            target=target_dict['waveform'],
            mixture=input_dict['waveform'],
        )

        return loss

    def configure_optimizers(self) -> Any:
        r"""Configure the optimizer and the learning rate scheduler."""
        if self.optimizer_type == "Adam":
            optimizer = optim.Adam(
                self.model.parameters(),
                lr=self.learning_rate,
                betas=(0.9, 0.999),
                eps=1e-08,
                weight_decay=0.0,
                amsgrad=True,
            )

        elif self.optimizer_type == "AdamW":
            optimizer = optim.AdamW(
                self.model.parameters(),
                lr=self.learning_rate,
                betas=(0.9, 0.999),
                eps=1e-08,
                weight_decay=0.0,
                amsgrad=True,
            )

        else:
            raise NotImplementedError("Optimizer type {} is not implemented!".format(self.optimizer_type))

        scheduler = {
            'scheduler': LambdaLR(optimizer, self.lr_lambda),
            'interval': 'step',
            'frequency': 1,
        }

        return [optimizer], [scheduler]
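

# Illustrative helper, not part of the original module: the scheduler dict
# returned by configure_optimizers() uses 'interval': 'step', so LambdaLR is
# stepped once per optimizer step and `lr_lambda` maps the global step to a
# multiplier on `learning_rate`. A linear warm-up schedule (the function name
# and the warm_up_steps default are assumptions for illustration) could look
# like the sketch below.
def example_warmup_lr_lambda(step: int, warm_up_steps: int = 1000) -> float:
    r"""Hypothetical LambdaLR multiplier: ramp linearly from 0 to 1 over
    `warm_up_steps` optimizer steps, then keep the base learning rate."""
    return min(1.0, step / warm_up_steps)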


def get_model_class(model_type: str):
    r"""Get the model class corresponding to a model type string.

    Args:
        model_type: str, e.g., 'ResUNet143_DecouplePlusInplaceABN_ISMIR2021'

    Returns:
        the model class (a subclass of nn.Module), not an instance
    """
    if model_type == 'ResUNet143_DecouplePlusInplaceABN_ISMIR2021':
        from bytesep.models.resunet_ismir2021 import (
            ResUNet143_DecouplePlusInplaceABN_ISMIR2021,
        )
        return ResUNet143_DecouplePlusInplaceABN_ISMIR2021

    elif model_type == 'UNet':
        from bytesep.models.unet import UNet
        return UNet

    elif model_type == 'UNetSubbandTime':
        from bytesep.models.unet_subbandtime import UNetSubbandTime
        return UNetSubbandTime

    elif model_type == 'ResUNet143_Subbandtime':
        from bytesep.models.resunet_subbandtime import ResUNet143_Subbandtime
        return ResUNet143_Subbandtime

    elif model_type == 'ResUNet143_DecouplePlus':
        from bytesep.models.resunet import ResUNet143_DecouplePlus
        return ResUNet143_DecouplePlus

    elif model_type == 'ConditionalUNet':
        from bytesep.models.conditional_unet import ConditionalUNet
        return ConditionalUNet

    elif model_type == 'LevelRNN':
        from bytesep.models.levelrnn import LevelRNN
        return LevelRNN

    elif model_type == 'WavUNet':
        from bytesep.models.wavunet import WavUNet
        return WavUNet

    elif model_type == 'WavUNetLevelRNN':
        from bytesep.models.wavunet_levelrnn import WavUNetLevelRNN
        return WavUNetLevelRNN

    elif model_type == 'TTnet':
        from bytesep.models.ttnet import TTnet
        return TTnet

    elif model_type == 'TTnetNoTransformer':
        from bytesep.models.ttnet_no_transformer import TTnetNoTransformer
        return TTnetNoTransformer

    else:
        raise NotImplementedError("Model type {} is not implemented!".format(model_type))
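

# -----------------------------------------------------------------------------
# Usage sketch (an illustration, not part of the original module). It wires a
# toy identity model, a dummy L1 loss, and a minimal preprocessor into
# LitSourceSeparation to show the expected interfaces; in the real pipeline the
# model class comes from get_model_class() and batches come from bytesep's data
# loaders. All names prefixed with an underscore below are assumptions.
# -----------------------------------------------------------------------------
if __name__ == '__main__':

    class _IdentityModel(nn.Module):
        r"""Hypothetical stand-in model: returns the input mixture unchanged."""

        def forward(self, input_dict: Dict) -> Dict:
            return {'waveform': input_dict['waveform']}

    def _l1_loss(output: torch.Tensor, target: torch.Tensor, mixture: torch.Tensor) -> torch.Tensor:
        r"""Hypothetical L1 loss with the (output, target, mixture) keyword
        signature that LitSourceSeparation.training_step expects."""
        return torch.mean(torch.abs(output - target))

    def _preprocessor(batch_data_dict: Dict):
        r"""Hypothetical preprocessor: split a batch dict into input and target
        dicts, mirroring the role of BasicBatchDataPreprocessor."""
        input_dict = {'waveform': batch_data_dict['mixture']}
        target_dict = {'waveform': batch_data_dict['vocals']}
        return input_dict, target_dict

    lit_model = LitSourceSeparation(
        batch_data_preprocessor=_preprocessor,
        model=_IdentityModel(),
        loss_function=_l1_loss,
        optimizer_type='Adam',
        learning_rate=1e-3,
        lr_lambda=lambda step: 1.0,  # constant learning rate for this toy example
    )

    # One fake mini-batch: (batch_size, channels_num, segment_samples).
    batch_data_dict = {
        'mixture': torch.randn(2, 1, 44100),
        'vocals': torch.randn(2, 1, 44100),
    }
    loss = lit_model.training_step(batch_data_dict, batch_idx=0)
    print(loss)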