from typing import Any, Callable, Dict
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR


class LitSourceSeparation(pl.LightningModule):
    def __init__(
        self,
        batch_data_preprocessor,
        model: nn.Module,
        loss_function: Callable,
        optimizer_type: str,
        learning_rate: float,
        lr_lambda: Callable,
    ):
        r"""PyTorch Lightning wrapper of a PyTorch model, handling the
        forward pass, loss calculation, and optimization.

        Args:
            batch_data_preprocessor: callable that prepares inputs and
                targets for training, e.g., BasicBatchDataPreprocessor
                converts a batch dictionary into input and target tensors
                (see the sketch after this method).
            model: nn.Module
            loss_function: Callable
            optimizer_type: str, e.g., 'Adam' or 'AdamW'
            learning_rate: float
            lr_lambda: Callable, learning rate lambda passed to LambdaLR
        """
        super().__init__()
        self.batch_data_preprocessor = batch_data_preprocessor
        self.model = model
        self.optimizer_type = optimizer_type
        self.loss_function = loss_function
        self.learning_rate = learning_rate
        self.lr_lambda = lr_lambda
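
    # A minimal sketch of the batch_data_preprocessor interface assumed by
    # this class. BasicBatchDataPreprocessor itself is defined elsewhere in
    # bytesep; the body below is illustrative only, not the actual
    # implementation.
    #
    #     class BasicBatchDataPreprocessor:
    #         def __call__(self, batch_data_dict: Dict):
    #             input_dict = {'waveform': batch_data_dict['mixture']}
    #             target_dict = {'waveform': batch_data_dict['vocals']}
    #             return input_dict, target_dict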

    def training_step(self, batch_data_dict: Dict, batch_idx: int) -> torch.Tensor:
        r"""Forward a mini-batch of data through the model, calculate the
        loss, and train for one step. A mini-batch is evenly distributed
        across multiple devices (if any) for parallel training.

        Args:
            batch_data_dict: e.g. {
                'vocals': (batch_size, channels_num, segment_samples),
                'accompaniment': (batch_size, channels_num, segment_samples),
                'mixture': (batch_size, channels_num, segment_samples)
            }
            batch_idx: int

        Returns:
            loss: torch.Tensor, loss of this mini-batch
        """
        input_dict, target_dict = self.batch_data_preprocessor(batch_data_dict)
        # input_dict: {
        #     'waveform': (batch_size, channels_num, segment_samples),
        #     (if_exist) 'condition': (batch_size, channels_num),
        # }
        # target_dict: {
        #     'waveform': (batch_size, target_sources_num * channels_num, segment_samples),
        # }

        # Forward.
        self.model.train()
        output_dict = self.model(input_dict)
        # output_dict: {
        #     'waveform': (batch_size, target_sources_num * channels_num, segment_samples),
        # }

        outputs = output_dict['waveform']
        # outputs, e.g., (batch_size, target_sources_num * channels_num, segment_samples)

        # Calculate loss.
        loss = self.loss_function(
            output=outputs,
            target=target_dict['waveform'],
            mixture=input_dict['waveform'],
        )

        return loss
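
    # The loss_function is expected to accept keyword arguments `output`,
    # `target`, and `mixture` and to return a scalar tensor. A minimal
    # sketch (an L1 waveform loss that ignores `mixture`; illustrative
    # only, not necessarily the loss used in bytesep):
    #
    #     def l1_wav(output, target, mixture, **kwargs):
    #         return torch.mean(torch.abs(output - target))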

    def configure_optimizers(self) -> Any:
        r"""Configure the optimizer and learning rate scheduler."""
        if self.optimizer_type == "Adam":
            optimizer = optim.Adam(
                self.model.parameters(),
                lr=self.learning_rate,
                betas=(0.9, 0.999),
                eps=1e-08,
                weight_decay=0.0,
                amsgrad=True,
            )

        elif self.optimizer_type == "AdamW":
            optimizer = optim.AdamW(
                self.model.parameters(),
                lr=self.learning_rate,
                betas=(0.9, 0.999),
                eps=1e-08,
                weight_decay=0.0,
                amsgrad=True,
            )

        else:
            raise NotImplementedError("Optimizer type {} is not supported.".format(self.optimizer_type))

        scheduler = {
            'scheduler': LambdaLR(optimizer, self.lr_lambda),
            'interval': 'step',
            'frequency': 1,
        }

        return [optimizer], [scheduler]
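
# The LambdaLR scheduler above is stepped once per training step
# ('interval': 'step'), so lr_lambda receives the number of scheduler steps
# taken so far. A minimal sketch of a compatible lr_lambda (linear warm-up
# to a constant rate; the warm-up length is an assumption, not taken from
# the bytesep training configs):
#
#     def lr_lambda(step, warm_up_steps=1000):
#         return min(1.0, step / warm_up_steps)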


def get_model_class(model_type: str):
    r"""Get the model class corresponding to a model type string.

    Args:
        model_type: str, e.g., 'ResUNet143_DecouplePlusInplaceABN_ISMIR2021'

    Returns:
        nn.Module subclass (the class itself, not an instance)
    """
    if model_type == 'ResUNet143_DecouplePlusInplaceABN_ISMIR2021':
        from bytesep.models.resunet_ismir2021 import (
            ResUNet143_DecouplePlusInplaceABN_ISMIR2021,
        )
        return ResUNet143_DecouplePlusInplaceABN_ISMIR2021

    elif model_type == 'UNet':
        from bytesep.models.unet import UNet
        return UNet

    elif model_type == 'UNetSubbandTime':
        from bytesep.models.unet_subbandtime import UNetSubbandTime
        return UNetSubbandTime

    elif model_type == 'ResUNet143_Subbandtime':
        from bytesep.models.resunet_subbandtime import ResUNet143_Subbandtime
        return ResUNet143_Subbandtime

    elif model_type == 'ResUNet143_DecouplePlus':
        from bytesep.models.resunet import ResUNet143_DecouplePlus
        return ResUNet143_DecouplePlus

    elif model_type == 'ConditionalUNet':
        from bytesep.models.conditional_unet import ConditionalUNet
        return ConditionalUNet

    elif model_type == 'LevelRNN':
        from bytesep.models.levelrnn import LevelRNN
        return LevelRNN

    elif model_type == 'WavUNet':
        from bytesep.models.wavunet import WavUNet
        return WavUNet

    elif model_type == 'WavUNetLevelRNN':
        from bytesep.models.wavunet_levelrnn import WavUNetLevelRNN
        return WavUNetLevelRNN

    elif model_type == 'TTnet':
        from bytesep.models.ttnet import TTnet
        return TTnet

    elif model_type == 'TTnetNoTransformer':
        from bytesep.models.ttnet_no_transformer import TTnetNoTransformer
        return TTnetNoTransformer

    else:
        raise NotImplementedError("Model type {} is not supported.".format(model_type))
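

# Example usage (a minimal sketch; the model keyword arguments, data loader,
# preprocessor, loss function, and lr_lambda below are placeholders for
# illustration, not values taken from the bytesep training configs):
#
#     Model = get_model_class(model_type='UNet')
#     model = Model(input_channels=2, target_sources_num=1)  # hypothetical kwargs
#
#     lit_model = LitSourceSeparation(
#         batch_data_preprocessor=batch_data_preprocessor,
#         model=model,
#         loss_function=loss_function,
#         optimizer_type='Adam',
#         learning_rate=1e-3,
#         lr_lambda=lr_lambda,
#     )
#
#     trainer = pl.Trainer(max_steps=100000)
#     trainer.fit(lit_model, train_loader)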