import logging
import os
import time
from typing import Dict, List

import librosa
import musdb
import museval
import numpy as np
import pytorch_lightning as pl
import torch.nn as nn
from pytorch_lightning.utilities import rank_zero_only

from bytesep.callbacks.base_callbacks import SaveCheckpointsCallback
from bytesep.dataset_creation.pack_audios_to_hdf5s.musdb18 import preprocess_audio
from bytesep.inference import Separator
from bytesep.utils import StatisticsContainer, read_yaml


def get_musdb18_callbacks(
    config_yaml: str,
    workspace: str,
    checkpoints_dir: str,
    statistics_path: str,
    logger: pl.loggers.TensorBoardLogger,
    model: nn.Module,
    evaluate_device: str,
) -> List[pl.Callback]:
    r"""Get MUSDB18 callbacks of a config yaml.

    Args:
        config_yaml: str
        workspace: str
        checkpoints_dir: str, directory to save checkpoints
        statistics_path: str, path to save statistics
        logger: pl.loggers.TensorBoardLogger
        model: nn.Module
        evaluate_device: str

    Returns:
        callbacks: List[pl.Callback]
    """
    configs = read_yaml(config_yaml)
    task_name = configs['task_name']
    evaluation_callback = configs['train']['evaluation_callback']
    target_source_types = configs['train']['target_source_types']
    input_channels = configs['train']['channels']
    evaluation_audios_dir = os.path.join(workspace, "evaluation_audios", task_name)
    test_segment_seconds = configs['evaluate']['segment_seconds']
    sample_rate = configs['train']['sample_rate']
    test_segment_samples = int(test_segment_seconds * sample_rate)
    test_batch_size = configs['evaluate']['batch_size']
    evaluate_step_frequency = configs['train']['evaluate_step_frequency']
    save_step_frequency = configs['train']['save_step_frequency']

    # save checkpoint callback
    save_checkpoints_callback = SaveCheckpointsCallback(
        model=model,
        checkpoints_dir=checkpoints_dir,
        save_step_frequency=save_step_frequency,
    )

    # evaluation callback class
    EvaluationCallback = _get_evaluation_callback_class(evaluation_callback)

    # statistics container
    statistics_container = StatisticsContainer(statistics_path)

    # evaluation callbacks
    evaluate_train_callback = EvaluationCallback(
        dataset_dir=evaluation_audios_dir,
        model=model,
        target_source_types=target_source_types,
        input_channels=input_channels,
        sample_rate=sample_rate,
        split='train',
        segment_samples=test_segment_samples,
        batch_size=test_batch_size,
        device=evaluate_device,
        evaluate_step_frequency=evaluate_step_frequency,
        logger=logger,
        statistics_container=statistics_container,
    )

    evaluate_test_callback = EvaluationCallback(
        dataset_dir=evaluation_audios_dir,
        model=model,
        target_source_types=target_source_types,
        input_channels=input_channels,
        sample_rate=sample_rate,
        split='test',
        segment_samples=test_segment_samples,
        batch_size=test_batch_size,
        device=evaluate_device,
        evaluate_step_frequency=evaluate_step_frequency,
        logger=logger,
        statistics_container=statistics_container,
    )

    # callbacks = [save_checkpoints_callback, evaluate_train_callback, evaluate_test_callback]
    callbacks = [save_checkpoints_callback, evaluate_test_callback]

    return callbacks


def _get_evaluation_callback_class(evaluation_callback: str) -> type:
    r"""Get evaluation callback class."""
    if evaluation_callback == "Musdb18EvaluationCallback":
        return Musdb18EvaluationCallback

    elif evaluation_callback == 'Musdb18ConditionalEvaluationCallback':
        return Musdb18ConditionalEvaluationCallback

    else:
        raise NotImplementedError

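# A minimal usage sketch (hypothetical paths; the actual wiring lives in this
# repository's training entry point). Assuming a config yaml, a logger, and a
# model instance:
#
#     callbacks = get_musdb18_callbacks(
#         config_yaml="./configs/musdb18_vocals.yaml",  # hypothetical path
#         workspace="./workspace",
#         checkpoints_dir="./workspace/checkpoints",
#         statistics_path="./workspace/statistics.pkl",
#         logger=logger,
#         model=model,
#         evaluate_device='cuda',
#     )
#     trainer = pl.Trainer(callbacks=callbacks)
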
class Musdb18EvaluationCallback(pl.Callback):
    def __init__(
        self,
        dataset_dir: str,
        model: nn.Module,
        target_source_types: List[str],
        input_channels: int,
        split: str,
        sample_rate: int,
        segment_samples: int,
        batch_size: int,
        device: str,
        evaluate_step_frequency: int,
        logger: pl.loggers.TensorBoardLogger,
        statistics_container: StatisticsContainer,
    ):
        r"""Callback to evaluate every #evaluate_step_frequency steps.

        Args:
            dataset_dir: str
            model: nn.Module
            target_source_types: List[str], e.g., ['vocals', 'bass', ...]
            input_channels: int
            split: 'train' | 'test'
            sample_rate: int
            segment_samples: int, length of segments to be input to a model, e.g., 44100*30
            batch_size: int, e.g., 12
            device: str, e.g., 'cuda'
            evaluate_step_frequency: int, evaluate every #evaluate_step_frequency steps
            logger: pl.loggers.TensorBoardLogger
            statistics_container: StatisticsContainer
        """
        self.model = model
        self.target_source_types = target_source_types
        self.input_channels = input_channels
        self.sample_rate = sample_rate
        self.split = split
        self.segment_samples = segment_samples
        self.evaluate_step_frequency = evaluate_step_frequency
        self.logger = logger
        self.statistics_container = statistics_container
        self.mono = input_channels == 1
        self.resample_type = "kaiser_fast"

        self.mus = musdb.DB(root=dataset_dir, subsets=[split])

        error_msg = "The directory {} is empty!".format(dataset_dir)
        assert len(self.mus) > 0, error_msg

        # separator
        self.separator = Separator(model, self.segment_samples, batch_size, device)

    @rank_zero_only
    def on_batch_end(self, trainer: pl.Trainer, _) -> None:
        r"""Evaluate separation SDRs of audio recordings."""
        global_step = trainer.global_step

        if global_step % self.evaluate_step_frequency == 0:

            sdr_dict = {}

            logging.info("--- Step {} ---".format(global_step))
            logging.info("Total {} pieces for evaluation:".format(len(self.mus.tracks)))

            eval_time = time.time()

            for track in self.mus.tracks:

                audio_name = track.name

                # Get waveform of mixture.
                mixture = track.audio.T
                # (channels_num, audio_samples)

                mixture = preprocess_audio(
                    audio=mixture,
                    mono=self.mono,
                    origin_sr=track.rate,
                    sr=self.sample_rate,
                    resample_type=self.resample_type,
                )
                # (channels_num, audio_samples)

                target_dict = {}
                sdr_dict[audio_name] = {}

                # Get waveform of all target source types.
                for source_type in self.target_source_types:
                    # E.g., ['vocals', 'bass', ...]

                    audio = track.targets[source_type].audio.T
                    audio = preprocess_audio(
                        audio=audio,
                        mono=self.mono,
                        origin_sr=track.rate,
                        sr=self.sample_rate,
                        resample_type=self.resample_type,
                    )
                    # (channels_num, audio_samples)

                    target_dict[source_type] = audio
                    # (channels_num, audio_samples)

                # Separate.
                input_dict = {'waveform': mixture}

                sep_wavs = self.separator.separate(input_dict)
                # sep_wavs: (target_sources_num * channels_num, audio_samples)

                # Post process separation results.
                sep_wavs = preprocess_audio(
                    audio=sep_wavs,
                    mono=self.mono,
                    origin_sr=self.sample_rate,
                    sr=track.rate,
                    resample_type=self.resample_type,
                )
                # sep_wavs: (target_sources_num * channels_num, audio_samples)

                sep_wavs = librosa.util.fix_length(
                    sep_wavs, size=mixture.shape[1], axis=1
                )
                # sep_wavs: (target_sources_num * channels_num, audio_samples)

                sep_wav_dict = get_separated_wavs_from_simo_output(
                    sep_wavs, self.input_channels, self.target_source_types
                )
                # sep_wav_dict: dict, e.g., {
                #     'vocals': (channels_num, audio_samples),
                #     'bass': (channels_num, audio_samples),
                #     ...,
                # }

                # Evaluate all target source types.
                for source_type in self.target_source_types:
                    # E.g., ['vocals', 'bass', ...]

                    # Calculate SDR using museval, input shape should be: (nsrc, nsampl, nchan).
                    (sdrs, _, _, _) = museval.evaluate(
                        [target_dict[source_type].T], [sep_wav_dict[source_type].T]
                    )

                    sdr = np.nanmedian(sdrs)
                    sdr_dict[audio_name][source_type] = sdr

                    logging.info(
                        "{}, {}, sdr: {:.3f}".format(audio_name, source_type, sdr)
                    )

            logging.info("-----------------------------")

            median_sdr_dict = {}

            # Calculate median SDRs of all songs.
            for source_type in self.target_source_types:
                # E.g., ['vocals', 'bass', ...]

                median_sdr = np.median(
                    [
                        sdr_dict[audio_name][source_type]
                        for audio_name in sdr_dict.keys()
                    ]
                )
                median_sdr_dict[source_type] = median_sdr

                logging.info(
                    "Step: {}, {}, Median SDR: {:.3f}".format(
                        global_step, source_type, median_sdr
                    )
                )

            logging.info("Evaluation time: {:.3f}".format(time.time() - eval_time))

            statistics = {"sdr_dict": sdr_dict, "median_sdr_dict": median_sdr_dict}
            self.statistics_container.append(global_step, statistics, self.split)
            self.statistics_container.dump()

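# For reference, museval.evaluate() takes (nsrc, nsampl, nchan) arrays and
# returns frame-wise (SDR, ISR, SIR, SAR). A minimal sketch with random data
# (illustration only, not part of the pipeline):
#
#     references = np.random.randn(1, 44100, 2)  # 1 source, 1 s of stereo audio
#     estimates = np.random.randn(1, 44100, 2)
#     sdrs, isrs, sirs, sars = museval.evaluate(references, estimates)
#     track_sdr = np.nanmedian(sdrs)  # median over frames, as in the callback above
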
def get_separated_wavs_from_simo_output(x, input_channels, target_source_types) -> Dict:
    r"""Get separated waveforms of target sources from a single input multiple
    output (SIMO) system.

    Args:
        x: (target_sources_num * channels_num, audio_samples)
        input_channels: int
        target_source_types: List[str], e.g., ['vocals', 'bass', ...]

    Returns:
        output_dict: dict, e.g., {
            'vocals': (channels_num, audio_samples),
            'bass': (channels_num, audio_samples),
            ...,
        }
    """
    output_dict = {}

    for j, source_type in enumerate(target_source_types):
        output_dict[source_type] = x[j * input_channels : (j + 1) * input_channels]

    return output_dict

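# E.g., with input_channels=2 and target_source_types=['vocals', 'bass'], x has
# shape (4, audio_samples): rows 0-1 hold the vocals stereo pair and rows 2-3
# the bass stereo pair, so output_dict['bass'] is x[2:4].
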
class Musdb18ConditionalEvaluationCallback(pl.Callback):
    def __init__(
        self,
        dataset_dir: str,
        model: nn.Module,
        target_source_types: List[str],
        input_channels: int,
        split: str,
        sample_rate: int,
        segment_samples: int,
        batch_size: int,
        device: str,
        evaluate_step_frequency: int,
        logger: pl.loggers.TensorBoardLogger,
        statistics_container: StatisticsContainer,
    ):
        r"""Callback to evaluate every #evaluate_step_frequency steps.

        Args:
            dataset_dir: str
            model: nn.Module
            target_source_types: List[str], e.g., ['vocals', 'bass', ...]
            input_channels: int
            split: 'train' | 'test'
            sample_rate: int
            segment_samples: int, length of segments to be input to a model, e.g., 44100*30
            batch_size: int, e.g., 12
            device: str, e.g., 'cuda'
            evaluate_step_frequency: int, evaluate every #evaluate_step_frequency steps
            logger: pl.loggers.TensorBoardLogger
            statistics_container: StatisticsContainer
        """
        self.model = model
        self.target_source_types = target_source_types
        self.input_channels = input_channels
        self.sample_rate = sample_rate
        self.split = split
        self.segment_samples = segment_samples
        self.evaluate_step_frequency = evaluate_step_frequency
        self.logger = logger
        self.statistics_container = statistics_container
        self.mono = input_channels == 1
        self.resample_type = "kaiser_fast"

        self.mus = musdb.DB(root=dataset_dir, subsets=[split])

        error_msg = "The directory {} is empty!".format(dataset_dir)
        assert len(self.mus) > 0, error_msg

        # separator
        self.separator = Separator(model, self.segment_samples, batch_size, device)

    @rank_zero_only
    def on_batch_end(self, trainer: pl.Trainer, _) -> None:
        r"""Evaluate separation SDRs of audio recordings."""
        global_step = trainer.global_step

        if global_step % self.evaluate_step_frequency == 0:

            sdr_dict = {}

            logging.info("--- Step {} ---".format(global_step))
            logging.info("Total {} pieces for evaluation:".format(len(self.mus.tracks)))

            eval_time = time.time()

            for track in self.mus.tracks:

                audio_name = track.name

                # Get waveform of mixture.
                mixture = track.audio.T
                # (channels_num, audio_samples)

                mixture = preprocess_audio(
                    audio=mixture,
                    mono=self.mono,
                    origin_sr=track.rate,
                    sr=self.sample_rate,
                    resample_type=self.resample_type,
                )
                # (channels_num, audio_samples)

                target_dict = {}
                sdr_dict[audio_name] = {}

                # Separate and evaluate each target source type in turn.
                for j, source_type in enumerate(self.target_source_types):
                    # E.g., ['vocals', 'bass', ...]

                    audio = track.targets[source_type].audio.T
                    audio = preprocess_audio(
                        audio=audio,
                        mono=self.mono,
                        origin_sr=track.rate,
                        sr=self.sample_rate,
                        resample_type=self.resample_type,
                    )
                    # (channels_num, audio_samples)

                    target_dict[source_type] = audio
                    # (channels_num, audio_samples)

                    # One-hot condition that selects the j-th target source.
                    condition = np.zeros(len(self.target_source_types))
                    condition[j] = 1

                    input_dict = {'waveform': mixture, 'condition': condition}

                    sep_wav = self.separator.separate(input_dict)
                    # sep_wav: (channels_num, audio_samples)

                    sep_wav = preprocess_audio(
                        audio=sep_wav,
                        mono=self.mono,
                        origin_sr=self.sample_rate,
                        sr=track.rate,
                        resample_type=self.resample_type,
                    )
                    # sep_wav: (channels_num, audio_samples)

                    sep_wav = librosa.util.fix_length(
                        sep_wav, size=mixture.shape[1], axis=1
                    )
                    # sep_wav: (channels_num, audio_samples)

                    # Calculate SDR using museval, input shape should be: (nsrc, nsampl, nchan).
                    (sdrs, _, _, _) = museval.evaluate(
                        [target_dict[source_type].T], [sep_wav.T]
                    )

                    sdr = np.nanmedian(sdrs)
                    sdr_dict[audio_name][source_type] = sdr

                    logging.info(
                        "{}, {}, sdr: {:.3f}".format(audio_name, source_type, sdr)
                    )

            logging.info("-----------------------------")

            median_sdr_dict = {}

            # Calculate median SDRs of all songs.
            for source_type in self.target_source_types:

                median_sdr = np.median(
                    [
                        sdr_dict[audio_name][source_type]
                        for audio_name in sdr_dict.keys()
                    ]
                )
                median_sdr_dict[source_type] = median_sdr

                logging.info(
                    "Step: {}, {}, Median SDR: {:.3f}".format(
                        global_step, source_type, median_sdr
                    )
                )

            logging.info("Evaluation time: {:.3f}".format(time.time() - eval_time))

            statistics = {"sdr_dict": sdr_dict, "median_sdr_dict": median_sdr_dict}
            self.statistics_container.append(global_step, statistics, self.split)
            self.statistics_container.dump()

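# A minimal smoke test for the pure-numpy helper above (a sketch with dummy
# shapes, for illustration only; not part of the training pipeline):
if __name__ == "__main__":
    dummy_output = np.arange(20).reshape(4, 5)
    # 2 sources x 2 channels, 5 samples each
    wav_dict = get_separated_wavs_from_simo_output(
        dummy_output, input_channels=2, target_source_types=['vocals', 'bass']
    )
    assert wav_dict['vocals'].shape == (2, 5)
    assert np.array_equal(wav_dict['bass'], dummy_output[2:4])

    # The conditional callback instead selects sources with a one-hot vector:
    condition = np.zeros(2)
    condition[0] = 1  # selects 'vocals'
    print(list(wav_dict.keys()), condition)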