# models/audiosep.py
import random
from typing import Any, Callable, Dict

import lightning.pytorch as pl
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR


class AudioSep(pl.LightningModule):
    def __init__(
        self,
        ss_model: nn.Module,
        waveform_mixer: nn.Module,
        query_encoder: nn.Module,
        loss_function: Callable,
        optimizer_type: str,
        learning_rate: float,
        lr_lambda_func: Callable[[int], float],
        use_text_ratio: float = 1.0,
    ):
r"""Pytorch Lightning wrapper of PyTorch model, including forward,
optimization of model, etc.
Args:
ss_model: nn.Module
anchor_segment_detector: nn.Module
loss_function: function or object
learning_rate: float
lr_lambda: function
"""
        super().__init__()
        self.ss_model = ss_model
        self.waveform_mixer = waveform_mixer
        self.query_encoder = query_encoder
        self.query_encoder_type = self.query_encoder.encoder_type
        self.use_text_ratio = use_text_ratio
        self.loss_function = loss_function
        self.optimizer_type = optimizer_type
        self.learning_rate = learning_rate
        self.lr_lambda_func = lr_lambda_func

    def forward(self, x):
        # Not used; the separation model is invoked directly in training_step.
        pass

    def training_step(self, batch_data_dict, batch_idx):
        r"""Forward a mini-batch to the model, calculate the loss, and train
        for one step. A mini-batch is evenly distributed to multiple devices
        (if there are any) for parallel training.

        Args:
            batch_data_dict: e.g.,
                {'audio_text': {
                    'text': ['a sound of dog', ...],
                    'waveform': (batch_size, 1, samples),
                }}
            batch_idx: int

        Returns:
            loss: torch.Tensor, scalar loss of this mini-batch
        """
        # [important] fix random seeds across devices, so that the random
        # mixing below stays synchronized between parallel workers
        random.seed(batch_idx)

        batch_audio_text_dict = batch_data_dict['audio_text']
        batch_text = batch_audio_text_dict['text']
        batch_audio = batch_audio_text_dict['waveform']
        device = batch_audio.device

        # Mix single-source waveforms into training mixtures; `segments`
        # holds the corresponding target sources.
        mixtures, segments = self.waveform_mixer(
            waveforms=batch_audio
        )
        # Calculate the condition embedding for the audio-text data.
        if self.query_encoder_type == 'CLAP':
            conditions = self.query_encoder.get_query_embed(
                modality='hybird',  # (sic) the spelling the CLAP encoder expects
                text=batch_text,
                audio=segments.squeeze(1),
                use_text_ratio=self.use_text_ratio,
            )
        else:
            raise NotImplementedError(
                f'Unsupported query encoder type: {self.query_encoder_type}')

        input_dict = {
            'mixture': mixtures,
            'condition': conditions,
        }

        target_dict = {
            'segment': segments.squeeze(1),  # (batch_size, segment_samples)
        }
        self.ss_model.train()
        sep_segment = self.ss_model(input_dict)['waveform']
        # (batch_size, 1, segment_samples) -> (batch_size, segment_samples);
        # squeeze(1) rather than squeeze(), so a batch of size 1 keeps its dim
        sep_segment = sep_segment.squeeze(1)

        output_dict = {
            'segment': sep_segment,
        }

        # Calculate loss.
        loss = self.loss_function(output_dict, target_dict)
        self.log_dict({'train_loss': loss})

        return loss
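
    # A minimal sketch of a compatible `loss_function` (an L1 loss on the
    # separated waveform, shown as an assumption for illustration, not
    # necessarily the objective used by the project's training recipe):
    #
    #     def l1_wav(output_dict, target_dict):
    #         return torch.mean(
    #             torch.abs(output_dict['segment'] - target_dict['segment']))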

    def test_step(self, batch, batch_idx):
        pass

    def configure_optimizers(self) -> Dict[str, Any]:
        r"""Configure the optimizer and the step-wise LR scheduler."""
        if self.optimizer_type == 'AdamW':
            optimizer = optim.AdamW(
                params=self.ss_model.parameters(),
                lr=self.learning_rate,
                betas=(0.9, 0.999),
                eps=1e-08,
                weight_decay=0.0,
                amsgrad=True,
            )
        else:
            raise NotImplementedError(
                f'Unsupported optimizer type: {self.optimizer_type}')

        scheduler = LambdaLR(optimizer, self.lr_lambda_func)

        output_dict = {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'interval': 'step',
                'frequency': 1,
            },
        }

        return output_dict
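
    # A minimal sketch of an `lr_lambda_func` that works with the step-wise
    # LambdaLR scheduler above (linear warm-up; the 10,000-step warm-up length
    # is an assumed value, not the project's training recipe):
    #
    #     def lr_lambda_func(step):
    #         warm_up_steps = 10000
    #         return min(1.0, step / warm_up_steps)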


def get_model_class(model_type):
    r"""Look up a separation model class by name."""
    if model_type == 'ResUNet30':
        from models.resunet import ResUNet30
        return ResUNet30
    else:
        raise NotImplementedError(
            f'Unsupported model type: {model_type}')
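
# Example wiring (a sketch: `waveform_mixer`, `query_encoder`, `loss_function`,
# `lr_lambda_func`, and `train_dataloader` stand in for objects built by the
# training pipeline, and the ResUNet30 arguments and Trainer settings are
# illustrative assumptions):
#
#     SsModel = get_model_class('ResUNet30')
#     ss_model = SsModel(input_channels=1, output_channels=1, condition_size=512)
#     pl_model = AudioSep(
#         ss_model=ss_model,
#         waveform_mixer=waveform_mixer,
#         query_encoder=query_encoder,
#         loss_function=loss_function,
#         optimizer_type='AdamW',
#         learning_rate=1e-4,
#         lr_lambda_func=lr_lambda_func,
#     )
#     trainer = pl.Trainer(max_steps=1_000_000, accelerator='auto')
#     trainer.fit(pl_model, train_dataloaders=train_dataloader)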