import logging
import os
import time
from typing import Dict, List, Type
import librosa
import musdb
import museval
import numpy as np
import pytorch_lightning as pl
import torch.nn as nn
from pytorch_lightning.utilities import rank_zero_only
from bytesep.callbacks.base_callbacks import SaveCheckpointsCallback
from bytesep.dataset_creation.pack_audios_to_hdf5s.musdb18 import preprocess_audio
from bytesep.inference import Separator
from bytesep.utils import StatisticsContainer, read_yaml
def get_musdb18_callbacks(
config_yaml: str,
workspace: str,
checkpoints_dir: str,
statistics_path: str,
logger: pl.loggers.TensorBoardLogger,
model: nn.Module,
evaluate_device: str,
) -> List[pl.Callback]:
r"""Get MUSDB18 callbacks of a config yaml.
Args:
config_yaml: str
workspace: str
checkpoints_dir: str, directory to save checkpoints
statistics_dir: str, directory to save statistics
logger: pl.loggers.TensorBoardLogger
model: nn.Module
evaluate_device: str
Return:
callbacks: List[pl.Callback]
"""
configs = read_yaml(config_yaml)
task_name = configs['task_name']
evaluation_callback = configs['train']['evaluation_callback']
target_source_types = configs['train']['target_source_types']
input_channels = configs['train']['channels']
evaluation_audios_dir = os.path.join(workspace, "evaluation_audios", task_name)
test_segment_seconds = configs['evaluate']['segment_seconds']
sample_rate = configs['train']['sample_rate']
test_segment_samples = int(test_segment_seconds * sample_rate)
test_batch_size = configs['evaluate']['batch_size']
evaluate_step_frequency = configs['train']['evaluate_step_frequency']
save_step_frequency = configs['train']['save_step_frequency']
# save checkpoint callback
save_checkpoints_callback = SaveCheckpointsCallback(
model=model,
checkpoints_dir=checkpoints_dir,
save_step_frequency=save_step_frequency,
)
    # evaluation callback class
    EvaluationCallback = _get_evaluation_callback_class(evaluation_callback)
# statistics container
statistics_container = StatisticsContainer(statistics_path)
    # evaluation callbacks on the train and test splits
evaluate_train_callback = EvaluationCallback(
dataset_dir=evaluation_audios_dir,
model=model,
target_source_types=target_source_types,
input_channels=input_channels,
sample_rate=sample_rate,
split='train',
segment_samples=test_segment_samples,
batch_size=test_batch_size,
device=evaluate_device,
evaluate_step_frequency=evaluate_step_frequency,
logger=logger,
statistics_container=statistics_container,
)
evaluate_test_callback = EvaluationCallback(
dataset_dir=evaluation_audios_dir,
model=model,
target_source_types=target_source_types,
input_channels=input_channels,
sample_rate=sample_rate,
split='test',
segment_samples=test_segment_samples,
batch_size=test_batch_size,
device=evaluate_device,
evaluate_step_frequency=evaluate_step_frequency,
logger=logger,
statistics_container=statistics_container,
)
# callbacks = [save_checkpoints_callback, evaluate_train_callback, evaluate_test_callback]
callbacks = [save_checkpoints_callback, evaluate_test_callback]
return callbacks
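
# Example (hypothetical paths and settings, not part of this repo): the
# returned callbacks are meant to be passed straight to a Lightning trainer,
# e.g.:
#
#     callbacks = get_musdb18_callbacks(
#         config_yaml="./configs/musdb18_vocals.yaml",
#         workspace="./workspace",
#         checkpoints_dir="./workspace/checkpoints",
#         statistics_path="./workspace/statistics.pkl",
#         logger=pl.loggers.TensorBoardLogger(save_dir="./logs"),
#         model=model,
#         evaluate_device="cuda",
#     )
#     trainer = pl.Trainer(callbacks=callbacks)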
def _get_evaluation_callback_class(evaluation_callback: str) -> Type[pl.Callback]:
    r"""Get the evaluation callback class named in the config yaml."""
    if evaluation_callback == "Musdb18EvaluationCallback":
        return Musdb18EvaluationCallback

    elif evaluation_callback == "Musdb18ConditionalEvaluationCallback":
        return Musdb18ConditionalEvaluationCallback

    else:
        raise NotImplementedError(
            "Unknown evaluation callback: {}".format(evaluation_callback)
        )
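
# The class returned above is selected by the 'evaluation_callback' field
# under 'train' in the config yaml, e.g.:
#
#     train:
#         evaluation_callback: "Musdb18EvaluationCallback"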
class Musdb18EvaluationCallback(pl.Callback):
def __init__(
self,
dataset_dir: str,
model: nn.Module,
        target_source_types: List[str],
input_channels: int,
split: str,
sample_rate: int,
segment_samples: int,
batch_size: int,
device: str,
evaluate_step_frequency: int,
logger: pl.loggers.TensorBoardLogger,
statistics_container: StatisticsContainer,
):
r"""Callback to evaluate every #save_step_frequency steps.
Args:
dataset_dir: str
model: nn.Module
target_source_types: List[str], e.g., ['vocals', 'bass', ...]
input_channels: int
split: 'train' | 'test'
sample_rate: int
segment_samples: int, length of segments to be input to a model, e.g., 44100*30
batch_size, int, e.g., 12
device: str, e.g., 'cuda'
evaluate_step_frequency: int, evaluate every #save_step_frequency steps
logger: object
statistics_container: StatisticsContainer
"""
self.model = model
self.target_source_types = target_source_types
self.input_channels = input_channels
self.sample_rate = sample_rate
self.split = split
self.segment_samples = segment_samples
self.evaluate_step_frequency = evaluate_step_frequency
self.logger = logger
self.statistics_container = statistics_container
self.mono = input_channels == 1
self.resample_type = "kaiser_fast"
self.mus = musdb.DB(root=dataset_dir, subsets=[split])
error_msg = "The directory {} is empty!".format(dataset_dir)
assert len(self.mus) > 0, error_msg
# separator
self.separator = Separator(model, self.segment_samples, batch_size, device)
@rank_zero_only
    def on_batch_end(self, trainer: pl.Trainer, _) -> None:
r"""Evaluate separation SDRs of audio recordings."""
global_step = trainer.global_step
if global_step % self.evaluate_step_frequency == 0:
sdr_dict = {}
logging.info("--- Step {} ---".format(global_step))
logging.info("Total {} pieces for evaluation:".format(len(self.mus.tracks)))
eval_time = time.time()
for track in self.mus.tracks:
audio_name = track.name
# Get waveform of mixture.
mixture = track.audio.T
# (channels_num, audio_samples)
mixture = preprocess_audio(
audio=mixture,
mono=self.mono,
origin_sr=track.rate,
sr=self.sample_rate,
resample_type=self.resample_type,
)
# (channels_num, audio_samples)
target_dict = {}
sdr_dict[audio_name] = {}
# Get waveform of all target source types.
for j, source_type in enumerate(self.target_source_types):
# E.g., ['vocals', 'bass', ...]
audio = track.targets[source_type].audio.T
audio = preprocess_audio(
audio=audio,
mono=self.mono,
origin_sr=track.rate,
sr=self.sample_rate,
resample_type=self.resample_type,
)
# (channels_num, audio_samples)
target_dict[source_type] = audio
# (channels_num, audio_samples)
# Separate.
input_dict = {'waveform': mixture}
sep_wavs = self.separator.separate(input_dict)
# sep_wavs: (target_sources_num * channels_num, audio_samples)
# Post process separation results.
sep_wavs = preprocess_audio(
audio=sep_wavs,
mono=self.mono,
origin_sr=self.sample_rate,
sr=track.rate,
resample_type=self.resample_type,
)
# sep_wavs: (target_sources_num * channels_num, audio_samples)
sep_wavs = librosa.util.fix_length(
sep_wavs, size=mixture.shape[1], axis=1
)
# sep_wavs: (target_sources_num * channels_num, audio_samples)
sep_wav_dict = get_separated_wavs_from_simo_output(
sep_wavs, self.input_channels, self.target_source_types
)
                # sep_wav_dict: dict, e.g., {
                #     'vocals': (channels_num, audio_samples),
                #     'bass': (channels_num, audio_samples),
                #     ...,
                # }
# Evaluate for all target source types.
for source_type in self.target_source_types:
# E.g., ['vocals', 'bass', ...]
# Calculate SDR using museval, input shape should be: (nsrc, nsampl, nchan).
(sdrs, _, _, _) = museval.evaluate(
[target_dict[source_type].T], [sep_wav_dict[source_type].T]
)
sdr = np.nanmedian(sdrs)
sdr_dict[audio_name][source_type] = sdr
logging.info(
"{}, {}, sdr: {:.3f}".format(audio_name, source_type, sdr)
)
logging.info("-----------------------------")
median_sdr_dict = {}
# Calculate median SDRs of all songs.
for source_type in self.target_source_types:
# E.g., ['vocals', 'bass', ...]
median_sdr = np.median(
[
sdr_dict[audio_name][source_type]
for audio_name in sdr_dict.keys()
]
)
median_sdr_dict[source_type] = median_sdr
logging.info(
"Step: {}, {}, Median SDR: {:.3f}".format(
global_step, source_type, median_sdr
)
)
logging.info("Evlauation time: {:.3f}".format(time.time() - eval_time))
statistics = {"sdr_dict": sdr_dict, "median_sdr_dict": median_sdr_dict}
self.statistics_container.append(global_step, statistics, self.split)
self.statistics_container.dump()
def get_separated_wavs_from_simo_output(
    x: np.ndarray, input_channels: int, target_source_types: List[str]
) -> Dict[str, np.ndarray]:
r"""Get separated waveforms of target sources from a single input multiple
output (SIMO) system.
Args:
x: (target_sources_num * channels_num, audio_samples)
input_channels: int
target_source_types: List[str], e.g., ['vocals', 'bass', ...]
Returns:
output_dict: dict, e.g., {
'vocals': (channels_num, audio_samples),
'bass': (channels_num, audio_samples),
...,
}
"""
output_dict = {}
for j, source_type in enumerate(target_source_types):
output_dict[source_type] = x[j * input_channels : (j + 1) * input_channels]
return output_dict
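
# Worked example (hypothetical shapes, for illustration only): with two stereo
# targets, the SIMO output stacks the sources along the channel axis, so
# slicing recovers each one:
#
#     x = np.zeros((2 * 2, 441000))  # ['vocals', 'bass'], input_channels=2
#     out = get_separated_wavs_from_simo_output(x, 2, ['vocals', 'bass'])
#     out['vocals'].shape  # (2, 441000), rows 0:2 of x
#     out['bass'].shape    # (2, 441000), rows 2:4 of x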
class Musdb18ConditionalEvaluationCallback(pl.Callback):
def __init__(
self,
dataset_dir: str,
model: nn.Module,
        target_source_types: List[str],
input_channels: int,
split: str,
sample_rate: int,
segment_samples: int,
batch_size: int,
device: str,
evaluate_step_frequency: int,
logger: pl.loggers.TensorBoardLogger,
statistics_container: StatisticsContainer,
):
r"""Callback to evaluate every #save_step_frequency steps.
Args:
dataset_dir: str
model: nn.Module
target_source_types: List[str], e.g., ['vocals', 'bass', ...]
input_channels: int
split: 'train' | 'test'
sample_rate: int
segment_samples: int, length of segments to be input to a model, e.g., 44100*30
batch_size, int, e.g., 12
device: str, e.g., 'cuda'
evaluate_step_frequency: int, evaluate every #save_step_frequency steps
logger: object
statistics_container: StatisticsContainer
"""
self.model = model
self.target_source_types = target_source_types
self.input_channels = input_channels
self.sample_rate = sample_rate
self.split = split
self.segment_samples = segment_samples
self.evaluate_step_frequency = evaluate_step_frequency
self.logger = logger
self.statistics_container = statistics_container
self.mono = input_channels == 1
self.resample_type = "kaiser_fast"
self.mus = musdb.DB(root=dataset_dir, subsets=[split])
error_msg = "The directory {} is empty!".format(dataset_dir)
assert len(self.mus) > 0, error_msg
# separator
self.separator = Separator(model, self.segment_samples, batch_size, device)
@rank_zero_only
    def on_batch_end(self, trainer: pl.Trainer, _) -> None:
r"""Evaluate separation SDRs of audio recordings."""
global_step = trainer.global_step
if global_step % self.evaluate_step_frequency == 0:
sdr_dict = {}
logging.info("--- Step {} ---".format(global_step))
logging.info("Total {} pieces for evaluation:".format(len(self.mus.tracks)))
eval_time = time.time()
for track in self.mus.tracks:
audio_name = track.name
# Get waveform of mixture.
mixture = track.audio.T
# (channels_num, audio_samples)
mixture = preprocess_audio(
audio=mixture,
mono=self.mono,
origin_sr=track.rate,
sr=self.sample_rate,
resample_type=self.resample_type,
)
# (channels_num, audio_samples)
target_dict = {}
sdr_dict[audio_name] = {}
# Get waveform of all target source types.
for j, source_type in enumerate(self.target_source_types):
# E.g., ['vocals', 'bass', ...]
audio = track.targets[source_type].audio.T
audio = preprocess_audio(
audio=audio,
mono=self.mono,
origin_sr=track.rate,
sr=self.sample_rate,
resample_type=self.resample_type,
)
# (channels_num, audio_samples)
target_dict[source_type] = audio
# (channels_num, audio_samples)
condition = np.zeros(len(self.target_source_types))
condition[j] = 1
input_dict = {'waveform': mixture, 'condition': condition}
sep_wav = self.separator.separate(input_dict)
# sep_wav: (channels_num, audio_samples)
sep_wav = preprocess_audio(
audio=sep_wav,
mono=self.mono,
origin_sr=self.sample_rate,
sr=track.rate,
resample_type=self.resample_type,
)
# sep_wav: (channels_num, audio_samples)
sep_wav = librosa.util.fix_length(
sep_wav, size=mixture.shape[1], axis=1
)
                    # sep_wav: (channels_num, audio_samples)
# Calculate SDR using museval, input shape should be: (nsrc, nsampl, nchan)
(sdrs, _, _, _) = museval.evaluate(
[target_dict[source_type].T], [sep_wav.T]
)
sdr = np.nanmedian(sdrs)
sdr_dict[audio_name][source_type] = sdr
logging.info(
"{}, {}, sdr: {:.3f}".format(audio_name, source_type, sdr)
)
logging.info("-----------------------------")
median_sdr_dict = {}
# Calculate median SDRs of all songs.
for source_type in self.target_source_types:
median_sdr = np.median(
[
sdr_dict[audio_name][source_type]
for audio_name in sdr_dict.keys()
]
)
median_sdr_dict[source_type] = median_sdr
logging.info(
"Step: {}, {}, Median SDR: {:.3f}".format(
global_step, source_type, median_sdr
)
)
logging.info("Evlauation time: {:.3f}".format(time.time() - eval_time))
statistics = {"sdr_dict": sdr_dict, "median_sdr_dict": median_sdr_dict}
self.statistics_container.append(global_step, statistics, self.split)
self.statistics_container.dump()
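
# Example (illustrative): how the one-hot condition vector selects a source in
# Musdb18ConditionalEvaluationCallback. With
#
#     target_source_types = ['vocals', 'bass', 'drums', 'other']
#
# the condition built for 'bass' (j == 1) is np.array([0., 1., 0., 0.]), and
# the separator returns only that source's waveform.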