from typing import Any, Callable, Dict

import random

import lightning.pytorch as pl
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR


class AudioSep(pl.LightningModule):
    def __init__(
        self,
        ss_model: nn.Module,
        waveform_mixer: Callable,
        query_encoder: nn.Module,
        loss_function: Callable,
        optimizer_type: str,
        learning_rate: float,
        lr_lambda_func: Callable,
        use_text_ratio: float = 1.0,
    ):
        r"""PyTorch Lightning wrapper of a source separation model, covering
        the training step, optimization, and logging.

        Args:
            ss_model: nn.Module, the source separation model
            waveform_mixer: callable that mixes waveforms into (mixtures, segments) pairs
            query_encoder: nn.Module that encodes text/audio queries into condition embeddings
            loss_function: callable taking (output_dict, target_dict) and returning a scalar loss
            optimizer_type: str, e.g. "AdamW"
            learning_rate: float
            lr_lambda_func: callable mapping a global step index to an LR multiplier
            use_text_ratio: float, ratio of examples conditioned on text (vs. audio) embeddings
        """
        super().__init__()
        self.ss_model = ss_model
        self.waveform_mixer = waveform_mixer
        self.query_encoder = query_encoder
        self.query_encoder_type = self.query_encoder.encoder_type
        self.use_text_ratio = use_text_ratio
        self.loss_function = loss_function
        self.optimizer_type = optimizer_type
        self.learning_rate = learning_rate
        self.lr_lambda_func = lr_lambda_func

    def forward(self, x):
        # Inference goes through `ss_model` directly; this wrapper only
        # implements the training loop.
        pass

    def training_step(self, batch_data_dict, batch_idx):
        r"""Forward a mini-batch through the model, calculate the loss, and
        train for one step. A mini-batch is evenly distributed across devices
        (if there are several) for parallel training.

        Args:
            batch_data_dict: e.g.
                'audio_text': {
                    'text': ['a sound of dog', ...],
                    'waveform': (batch_size, 1, samples),
                }
            batch_idx: int

        Returns:
            loss: torch.Tensor, scalar loss of this mini-batch
        """
        # [important] Fix the random seed across devices so that every device
        # draws the same mixing decisions for this batch.
        random.seed(batch_idx)

        batch_audio_text_dict = batch_data_dict['audio_text']

        batch_text = batch_audio_text_dict['text']
        batch_audio = batch_audio_text_dict['waveform']

        mixtures, segments = self.waveform_mixer(waveforms=batch_audio)

        # Calculate the condition embedding for the audio-text data.
        if self.query_encoder_type == 'CLAP':
            conditions = self.query_encoder.get_query_embed(
                modality='hybird',  # spelling follows the encoder's API
                text=batch_text,
                audio=segments.squeeze(1),
                use_text_ratio=self.use_text_ratio,
            )
        else:
            raise NotImplementedError(
                f'Unsupported query encoder type: {self.query_encoder_type}')

        input_dict = {
            'mixture': mixtures,  # passed through as returned by the mixer
            'condition': conditions,
        }

        target_dict = {
            'segment': segments.squeeze(1),
        }

        self.ss_model.train()
        # (batch_size, 1, segment_samples) -> (batch_size, segment_samples)
        sep_segment = self.ss_model(input_dict)['waveform'].squeeze(1)

        output_dict = {
            'segment': sep_segment,
        }

        # Calculate loss.
        loss = self.loss_function(output_dict, target_dict)

        self.log_dict({"train_loss": loss})

        return loss

    def test_step(self, batch, batch_idx):
        pass

    def configure_optimizers(self) -> Dict[str, Any]:
        r"""Configure the optimizer and the step-wise LR scheduler."""
        if self.optimizer_type == "AdamW":
            optimizer = optim.AdamW(
                params=self.ss_model.parameters(),
                lr=self.learning_rate,
                betas=(0.9, 0.999),
                eps=1e-08,
                weight_decay=0.0,
                amsgrad=True,
            )
        else:
            raise NotImplementedError(
                f'Unsupported optimizer type: {self.optimizer_type}')

        scheduler = LambdaLR(optimizer, self.lr_lambda_func)

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                'scheduler': scheduler,
                'interval': 'step',
                'frequency': 1,
            },
        }


def get_model_class(model_type):
    if model_type == 'ResUNet30':
        from models.resunet import ResUNet30
        return ResUNet30
    else:
        raise NotImplementedError(f'Unsupported model type: {model_type}')
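
# ---------------------------------------------------------------------------
# Illustrative helpers (sketches, not part of the training code above).
#
# `configure_optimizers` wraps `lr_lambda_func` in a LambdaLR that is stepped
# once per optimizer step, so the function maps a global step index to a
# multiplier on `learning_rate`. The linear warm-up below is an assumed
# policy for illustration; the real schedule is supplied by the caller.

def get_lr_lambda(warm_up_steps: int = 10000) -> Callable[[int], float]:
    """Return an LR-multiplier function: linear warm-up, then constant."""
    def lr_lambda(step: int) -> float:
        if step < warm_up_steps:
            return step / warm_up_steps
        return 1.0
    return lr_lambda


# `training_step` calls `loss_function(output_dict, target_dict)`, where both
# dicts carry a 'segment' waveform tensor. The L1 waveform loss below matches
# that calling convention; the choice of L1 is an assumption for illustration.

def l1_wav_loss(output_dict: Dict[str, torch.Tensor],
                target_dict: Dict[str, torch.Tensor]) -> torch.Tensor:
    """Mean absolute error between separated and target waveforms."""
    return torch.mean(torch.abs(output_dict['segment'] - target_dict['segment']))


# A typical wiring would then look like the sketch below (the mixer and query
# encoder are project-specific objects and are left elided here):
#
#     SSModel = get_model_class('ResUNet30')
#     model = AudioSep(
#         ss_model=SSModel(...),       # constructor args per models.resunet
#         waveform_mixer=...,          # mixes waveforms into mixture/segment pairs
#         query_encoder=...,           # e.g. a CLAP-based encoder with .encoder_type
#         loss_function=l1_wav_loss,
#         optimizer_type='AdamW',
#         learning_rate=1e-3,
#         lr_lambda_func=get_lr_lambda(warm_up_steps=10000),
#     )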