|
from typing import Any, Callable, Dict |
|
import random |
|
import lightning.pytorch as pl |
|
import torch |
|
import torch.nn as nn |
|
import torch.optim as optim |
|
from torch.optim.lr_scheduler import LambdaLR |
|
|
|
|
|
class AudioSep(pl.LightningModule):
    r"""PyTorch Lightning wrapper for a query-conditioned source separation
    model: mixes training waveforms, encodes the (text/audio) query into a
    condition vector, runs the separation model, and computes the loss.
    """

    def __init__(
        self,
        ss_model: nn.Module,
        waveform_mixer,
        query_encoder,
        loss_function,
        optimizer_type: str,
        learning_rate: float,
        lr_lambda_func,
        use_text_ratio: float = 1.0,
    ):
        r"""Pytorch Lightning wrapper of PyTorch model, including forward,
        optimization of model, etc.

        Args:
            ss_model: nn.Module, the source separation network; called as
                ``ss_model(input_dict)`` and expected to return a dict with a
                ``'waveform'`` entry.
            waveform_mixer: callable mapping ``waveforms`` of shape
                (batch_size, 1, samples) to ``(mixtures, segments)``.
            query_encoder: module exposing ``encoder_type`` (str) and
                ``get_query_embed(...)``; only the 'CLAP' type is supported.
            loss_function: function or object taking (output_dict, target_dict).
            optimizer_type: str, currently only "AdamW" is implemented.
            learning_rate: float, base learning rate.
            lr_lambda_func: function passed to ``LambdaLR`` as the multiplier
                schedule (stepped per training step, see configure_optimizers).
            use_text_ratio: float in [0, 1], fraction of queries encoded from
                text (vs. audio) in the CLAP hybrid mode.
        """
        super().__init__()
        self.ss_model = ss_model
        self.waveform_mixer = waveform_mixer
        self.query_encoder = query_encoder
        self.query_encoder_type = self.query_encoder.encoder_type
        self.use_text_ratio = use_text_ratio
        self.loss_function = loss_function
        self.optimizer_type = optimizer_type
        self.learning_rate = learning_rate
        self.lr_lambda_func = lr_lambda_func

    def forward(self, x):
        # Inference is done through self.ss_model directly; this wrapper only
        # implements the training loop.
        pass

    def training_step(self, batch_data_dict, batch_idx):
        r"""Forward a mini-batch data to model, calculate loss function, and
        train for one step. A mini-batch data is evenly distributed to multiple
        devices (if there are) for parallel training.

        Args:
            batch_data_dict: e.g.
                'audio_text': {
                    'text': ['a sound of dog', ...]
                    'waveform': (batch_size, 1, samples)
                }
            batch_idx: int

        Returns:
            loss: torch.Tensor, scalar loss of this mini-batch

        Raises:
            NotImplementedError: if the query encoder type is not 'CLAP'.
        """
        # NOTE(review): reseeding the global RNG every step makes the mixing
        # deterministic per batch index (identical across epochs) — presumably
        # intentional for reproducible mixtures; confirm before changing.
        random.seed(batch_idx)

        batch_audio_text_dict = batch_data_dict['audio_text']

        batch_text = batch_audio_text_dict['text']
        batch_audio = batch_audio_text_dict['waveform']

        # Mix waveforms into mixtures and keep the target source segments.
        mixtures, segments = self.waveform_mixer(
            waveforms=batch_audio
        )

        if self.query_encoder_type == 'CLAP':
            conditions = self.query_encoder.get_query_embed(
                # 'hybird' (sic) is the exact modality key the CLAP query
                # encoder expects — do not "correct" the spelling.
                modality='hybird',
                text=batch_text,
                audio=segments.squeeze(1),
                use_text_ratio=self.use_text_ratio,
            )
        else:
            # Previously this fell through and crashed with an opaque
            # NameError on `conditions`; fail explicitly instead.
            raise NotImplementedError(
                f"Unsupported query encoder type: {self.query_encoder_type}"
            )

        input_dict = {
            # Original code did `mixtures[:, None, :].squeeze(1)`, which is an
            # identity round-trip (insert dim 1, then remove it).
            'mixture': mixtures,
            'condition': conditions,
        }

        target_dict = {
            'segment': segments.squeeze(1),
        }

        self.ss_model.train()
        sep_segment = self.ss_model(input_dict)['waveform']
        # NOTE(review): squeeze() drops ALL size-1 dims — with batch_size == 1
        # this also collapses the batch dimension; verify against the loss.
        sep_segment = sep_segment.squeeze()

        output_dict = {
            'segment': sep_segment,
        }

        loss = self.loss_function(output_dict, target_dict)

        self.log_dict({"train_loss": loss})

        return loss

    def test_step(self, batch, batch_idx):
        # Evaluation is performed outside the Lightning loop.
        pass

    def configure_optimizers(self):
        r"""Configure the optimizer and per-step LR scheduler.

        Returns:
            dict with 'optimizer' and a Lightning 'lr_scheduler' config that
            steps the LambdaLR schedule every training step.

        Raises:
            NotImplementedError: if ``optimizer_type`` is not "AdamW".
        """
        if self.optimizer_type == "AdamW":
            optimizer = optim.AdamW(
                params=self.ss_model.parameters(),
                lr=self.learning_rate,
                betas=(0.9, 0.999),
                eps=1e-08,
                weight_decay=0.0,
                amsgrad=True,
            )
        else:
            raise NotImplementedError(
                f"Unsupported optimizer type: {self.optimizer_type}"
            )

        scheduler = LambdaLR(optimizer, self.lr_lambda_func)

        output_dict = {
            "optimizer": optimizer,
            "lr_scheduler": {
                'scheduler': scheduler,
                # Step the schedule every optimizer step, not every epoch.
                'interval': 'step',
                'frequency': 1,
            }
        }

        return output_dict
|
|
|
|
|
def get_model_class(model_type):
    """Resolve a separation-model type name to its class.

    Args:
        model_type: str, name of the model architecture (e.g. 'ResUNet30').

    Returns:
        The model class (not an instance).

    Raises:
        NotImplementedError: if ``model_type`` is not recognized.
    """
    if model_type == 'ResUNet30':
        # Imported lazily so unrelated code paths don't pull in the model
        # module (and its heavyweight dependencies).
        from models.resunet import ResUNet30
        return ResUNet30

    # Include the offending value — the original bare raise gave no context.
    raise NotImplementedError(f"Unknown model_type: {model_type}")
|
|