diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..926a8762f402b00d20b0c2396c2a4caf56655779 --- /dev/null +++ b/LICENSE @@ -0,0 +1,13 @@ +Copyright 2021 ByteDance + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. \ No newline at end of file diff --git a/bytesep/__init__.py b/bytesep/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4d2ec7c5efc3fbf7a79935c044345530663296d3 --- /dev/null +++ b/bytesep/__init__.py @@ -0,0 +1 @@ +from bytesep.inference import Separator diff --git a/bytesep/callbacks/__init__.py b/bytesep/callbacks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e70c6c2a4fa8fcabfdb78502907d431b07158edc --- /dev/null +++ b/bytesep/callbacks/__init__.py @@ -0,0 +1,76 @@ +from typing import List + +import pytorch_lightning as pl +import torch.nn as nn + + +def get_callbacks( + task_name: str, + config_yaml: str, + workspace: str, + checkpoints_dir: str, + statistics_path: str, + logger: pl.loggers.TensorBoardLogger, + model: nn.Module, + evaluate_device: str, +) -> List[pl.Callback]: + r"""Get callbacks of a task and config yaml file. + + Args: + task_name: str + config_yaml: str + dataset_dir: str + workspace: str, containing useful files such as audios for evaluation + checkpoints_dir: str, directory to save checkpoints + statistics_dir: str, directory to save statistics + logger: pl.loggers.TensorBoardLogger + model: nn.Module + evaluate_device: str + + Return: + callbacks: List[pl.Callback] + """ + if task_name == 'musdb18': + + from bytesep.callbacks.musdb18 import get_musdb18_callbacks + + return get_musdb18_callbacks( + config_yaml=config_yaml, + workspace=workspace, + checkpoints_dir=checkpoints_dir, + statistics_path=statistics_path, + logger=logger, + model=model, + evaluate_device=evaluate_device, + ) + + elif task_name == 'voicebank-demand': + + from bytesep.callbacks.voicebank_demand import get_voicebank_demand_callbacks + + return get_voicebank_demand_callbacks( + config_yaml=config_yaml, + workspace=workspace, + checkpoints_dir=checkpoints_dir, + statistics_path=statistics_path, + logger=logger, + model=model, + evaluate_device=evaluate_device, + ) + + elif task_name in ['vctk-musdb18', 'violin-piano', 'piano-symphony']: + + from bytesep.callbacks.instruments_callbacks import get_instruments_callbacks + + return get_instruments_callbacks( + config_yaml=config_yaml, + workspace=workspace, + checkpoints_dir=checkpoints_dir, + statistics_path=statistics_path, + logger=logger, + model=model, + evaluate_device=evaluate_device, + ) + + else: + raise NotImplementedError diff --git a/bytesep/callbacks/base_callbacks.py b/bytesep/callbacks/base_callbacks.py new file mode 100644 index 0000000000000000000000000000000000000000..ef62dd591f1516aa41e2ba347cc3aaa558854f8d --- /dev/null +++ b/bytesep/callbacks/base_callbacks.py @@ -0,0 +1,44 @@ +import logging +import os +from typing import NoReturn + +import pytorch_lightning as pl +import torch +import torch.nn as nn +from 
pytorch_lightning.utilities import rank_zero_only + + +class SaveCheckpointsCallback(pl.Callback): + def __init__( + self, + model: nn.Module, + checkpoints_dir: str, + save_step_frequency: int, + ): + r"""Callback to save checkpoints every #save_step_frequency steps. + + Args: + model: nn.Module + checkpoints_dir: str, directory to save checkpoints + save_step_frequency: int + """ + self.model = model + self.checkpoints_dir = checkpoints_dir + self.save_step_frequency = save_step_frequency + os.makedirs(self.checkpoints_dir, exist_ok=True) + + @rank_zero_only + def on_batch_end(self, trainer: pl.Trainer, _) -> NoReturn: + r"""Save checkpoint.""" + global_step = trainer.global_step + + if global_step % self.save_step_frequency == 0: + + checkpoint_path = os.path.join( + self.checkpoints_dir, "step={}.pth".format(global_step) + ) + + checkpoint = {'step': global_step, 'model': self.model.state_dict()} + + torch.save(checkpoint, checkpoint_path) + logging.info("Save checkpoint to {}".format(checkpoint_path)) diff --git a/bytesep/callbacks/instruments_callbacks.py b/bytesep/callbacks/instruments_callbacks.py new file mode 100644 index 0000000000000000000000000000000000000000..dc8a1d133ac4a9253c207cb2d6607fb96d392607 --- /dev/null +++ b/bytesep/callbacks/instruments_callbacks.py @@ -0,0 +1,200 @@ +import logging +import os +import time +from typing import List, NoReturn + +import librosa +import numpy as np +import pytorch_lightning as pl +import torch.nn as nn +from pytorch_lightning.utilities import rank_zero_only + +from bytesep.callbacks.base_callbacks import SaveCheckpointsCallback +from bytesep.inference import Separator +from bytesep.utils import StatisticsContainer, calculate_sdr, read_yaml + + +def get_instruments_callbacks( + config_yaml: str, + workspace: str, + checkpoints_dir: str, + statistics_path: str, + logger: pl.loggers.TensorBoardLogger, + model: nn.Module, + evaluate_device: str, +) -> List[pl.Callback]: + """Get Voicebank-Demand callbacks of a config yaml. 
+ + Args: + config_yaml: str + workspace: str + checkpoints_dir: str, directory to save checkpoints + statistics_dir: str, directory to save statistics + logger: pl.loggers.TensorBoardLogger + model: nn.Module + evaluate_device: str + + Return: + callbacks: List[pl.Callback] + """ + configs = read_yaml(config_yaml) + task_name = configs['task_name'] + target_source_types = configs['train']['target_source_types'] + input_channels = configs['train']['channels'] + mono = True if input_channels == 1 else False + test_audios_dir = os.path.join(workspace, "evaluation_audios", task_name, "test") + sample_rate = configs['train']['sample_rate'] + evaluate_step_frequency = configs['train']['evaluate_step_frequency'] + save_step_frequency = configs['train']['save_step_frequency'] + test_batch_size = configs['evaluate']['batch_size'] + test_segment_seconds = configs['evaluate']['segment_seconds'] + + test_segment_samples = int(test_segment_seconds * sample_rate) + assert len(target_source_types) == 1 + target_source_type = target_source_types[0] + + # save checkpoint callback + save_checkpoints_callback = SaveCheckpointsCallback( + model=model, + checkpoints_dir=checkpoints_dir, + save_step_frequency=save_step_frequency, + ) + + # statistics container + statistics_container = StatisticsContainer(statistics_path) + + # evaluation callback + evaluate_test_callback = EvaluationCallback( + model=model, + target_source_type=target_source_type, + input_channels=input_channels, + sample_rate=sample_rate, + mono=mono, + evaluation_audios_dir=test_audios_dir, + segment_samples=test_segment_samples, + batch_size=test_batch_size, + device=evaluate_device, + evaluate_step_frequency=evaluate_step_frequency, + logger=logger, + statistics_container=statistics_container, + ) + + callbacks = [save_checkpoints_callback, evaluate_test_callback] + # callbacks = [save_checkpoints_callback] + + return callbacks + + +class EvaluationCallback(pl.Callback): + def __init__( + self, + model: nn.Module, + input_channels: int, + evaluation_audios_dir: str, + target_source_type: str, + sample_rate: int, + mono: bool, + segment_samples: int, + batch_size: int, + device: str, + evaluate_step_frequency: int, + logger: pl.loggers.TensorBoardLogger, + statistics_container: StatisticsContainer, + ): + r"""Callback to evaluate every #save_step_frequency steps. + + Args: + model: nn.Module + input_channels: int + evaluation_audios_dir: str, directory containing audios for evaluation + target_source_type: str, e.g., 'violin' + sample_rate: int + mono: bool + segment_samples: int, length of segments to be input to a model, e.g., 44100*30 + batch_size, int, e.g., 12 + device: str, e.g., 'cuda' + evaluate_step_frequency: int, evaluate every #save_step_frequency steps + logger: pl.loggers.TensorBoardLogger + statistics_container: StatisticsContainer + """ + self.model = model + self.target_source_type = target_source_type + self.sample_rate = sample_rate + self.mono = mono + self.segment_samples = segment_samples + self.evaluate_step_frequency = evaluate_step_frequency + self.logger = logger + self.statistics_container = statistics_container + + self.evaluation_audios_dir = evaluation_audios_dir + + # separator + self.separator = Separator(model, self.segment_samples, batch_size, device) + + @rank_zero_only + def on_batch_end(self, trainer: pl.Trainer, _) -> NoReturn: + r"""Evaluate losses on a few mini-batches. Losses are only used for + observing training, and are not final F1 metrics. 
+ """ + + global_step = trainer.global_step + + if global_step % self.evaluate_step_frequency == 0: + + mixture_audios_dir = os.path.join(self.evaluation_audios_dir, 'mixture') + clean_audios_dir = os.path.join( + self.evaluation_audios_dir, self.target_source_type + ) + + audio_names = sorted(os.listdir(mixture_audios_dir)) + + error_str = "Directory {} does not contain audios for evaluation!".format( + self.evaluation_audios_dir + ) + assert len(audio_names) > 0, error_str + + logging.info("--- Step {} ---".format(global_step)) + logging.info("Total {} pieces for evaluation:".format(len(audio_names))) + + eval_time = time.time() + + sdrs = [] + + for n, audio_name in enumerate(audio_names): + + # Load audio. + mixture_path = os.path.join(mixture_audios_dir, audio_name) + clean_path = os.path.join(clean_audios_dir, audio_name) + + mixture, origin_fs = librosa.core.load( + mixture_path, sr=self.sample_rate, mono=self.mono + ) + + # Target + clean, origin_fs = librosa.core.load( + clean_path, sr=self.sample_rate, mono=self.mono + ) + + if mixture.ndim == 1: + mixture = mixture[None, :] + # (channels_num, audio_length) + + input_dict = {'waveform': mixture} + + # separate + sep_wav = self.separator.separate(input_dict) + # (channels_num, audio_length) + + sdr = calculate_sdr(ref=clean, est=sep_wav) + + print("{} SDR: {:.3f}".format(audio_name, sdr)) + sdrs.append(sdr) + + logging.info("-----------------------------") + logging.info('Avg SDR: {:.3f}'.format(np.mean(sdrs))) + + logging.info("Evlauation time: {:.3f}".format(time.time() - eval_time)) + + statistics = {"sdr": np.mean(sdrs)} + self.statistics_container.append(global_step, statistics, 'test') + self.statistics_container.dump() diff --git a/bytesep/callbacks/musdb18.py b/bytesep/callbacks/musdb18.py new file mode 100644 index 0000000000000000000000000000000000000000..37a8a65b6005efa5671d05044593d3805c289897 --- /dev/null +++ b/bytesep/callbacks/musdb18.py @@ -0,0 +1,485 @@ +import logging +import os +import time +from typing import Dict, List, NoReturn + +import librosa +import musdb +import museval +import numpy as np +import pytorch_lightning as pl +import torch.nn as nn +from pytorch_lightning.utilities import rank_zero_only + +from bytesep.callbacks.base_callbacks import SaveCheckpointsCallback +from bytesep.dataset_creation.pack_audios_to_hdf5s.musdb18 import preprocess_audio +from bytesep.inference import Separator +from bytesep.utils import StatisticsContainer, read_yaml + + +def get_musdb18_callbacks( + config_yaml: str, + workspace: str, + checkpoints_dir: str, + statistics_path: str, + logger: pl.loggers.TensorBoardLogger, + model: nn.Module, + evaluate_device: str, +) -> List[pl.Callback]: + r"""Get MUSDB18 callbacks of a config yaml. 
+ + Args: + config_yaml: str + workspace: str + checkpoints_dir: str, directory to save checkpoints + statistics_dir: str, directory to save statistics + logger: pl.loggers.TensorBoardLogger + model: nn.Module + evaluate_device: str + + Return: + callbacks: List[pl.Callback] + """ + configs = read_yaml(config_yaml) + task_name = configs['task_name'] + evaluation_callback = configs['train']['evaluation_callback'] + target_source_types = configs['train']['target_source_types'] + input_channels = configs['train']['channels'] + evaluation_audios_dir = os.path.join(workspace, "evaluation_audios", task_name) + test_segment_seconds = configs['evaluate']['segment_seconds'] + sample_rate = configs['train']['sample_rate'] + test_segment_samples = int(test_segment_seconds * sample_rate) + test_batch_size = configs['evaluate']['batch_size'] + + evaluate_step_frequency = configs['train']['evaluate_step_frequency'] + save_step_frequency = configs['train']['save_step_frequency'] + + # save checkpoint callback + save_checkpoints_callback = SaveCheckpointsCallback( + model=model, + checkpoints_dir=checkpoints_dir, + save_step_frequency=save_step_frequency, + ) + + # evaluation callback + EvaluationCallback = _get_evaluation_callback_class(evaluation_callback) + + # statistics container + statistics_container = StatisticsContainer(statistics_path) + + # evaluation callback + evaluate_train_callback = EvaluationCallback( + dataset_dir=evaluation_audios_dir, + model=model, + target_source_types=target_source_types, + input_channels=input_channels, + sample_rate=sample_rate, + split='train', + segment_samples=test_segment_samples, + batch_size=test_batch_size, + device=evaluate_device, + evaluate_step_frequency=evaluate_step_frequency, + logger=logger, + statistics_container=statistics_container, + ) + + evaluate_test_callback = EvaluationCallback( + dataset_dir=evaluation_audios_dir, + model=model, + target_source_types=target_source_types, + input_channels=input_channels, + sample_rate=sample_rate, + split='test', + segment_samples=test_segment_samples, + batch_size=test_batch_size, + device=evaluate_device, + evaluate_step_frequency=evaluate_step_frequency, + logger=logger, + statistics_container=statistics_container, + ) + + # callbacks = [save_checkpoints_callback, evaluate_train_callback, evaluate_test_callback] + callbacks = [save_checkpoints_callback, evaluate_test_callback] + + return callbacks + + +def _get_evaluation_callback_class(evaluation_callback) -> pl.Callback: + r"""Get evaluation callback class.""" + if evaluation_callback == "Musdb18EvaluationCallback": + return Musdb18EvaluationCallback + + if evaluation_callback == 'Musdb18ConditionalEvaluationCallback': + return Musdb18ConditionalEvaluationCallback + + else: + raise NotImplementedError + + +class Musdb18EvaluationCallback(pl.Callback): + def __init__( + self, + dataset_dir: str, + model: nn.Module, + target_source_types: str, + input_channels: int, + split: str, + sample_rate: int, + segment_samples: int, + batch_size: int, + device: str, + evaluate_step_frequency: int, + logger: pl.loggers.TensorBoardLogger, + statistics_container: StatisticsContainer, + ): + r"""Callback to evaluate every #save_step_frequency steps. + + Args: + dataset_dir: str + model: nn.Module + target_source_types: List[str], e.g., ['vocals', 'bass', ...] 
+ input_channels: int + split: 'train' | 'test' + sample_rate: int + segment_samples: int, length of segments to be input to a model, e.g., 44100*30 + batch_size, int, e.g., 12 + device: str, e.g., 'cuda' + evaluate_step_frequency: int, evaluate every #save_step_frequency steps + logger: object + statistics_container: StatisticsContainer + """ + self.model = model + self.target_source_types = target_source_types + self.input_channels = input_channels + self.sample_rate = sample_rate + self.split = split + self.segment_samples = segment_samples + self.evaluate_step_frequency = evaluate_step_frequency + self.logger = logger + self.statistics_container = statistics_container + self.mono = input_channels == 1 + self.resample_type = "kaiser_fast" + + self.mus = musdb.DB(root=dataset_dir, subsets=[split]) + + error_msg = "The directory {} is empty!".format(dataset_dir) + assert len(self.mus) > 0, error_msg + + # separator + self.separator = Separator(model, self.segment_samples, batch_size, device) + + @rank_zero_only + def on_batch_end(self, trainer: pl.Trainer, _) -> NoReturn: + r"""Evaluate separation SDRs of audio recordings.""" + global_step = trainer.global_step + + if global_step % self.evaluate_step_frequency == 0: + + sdr_dict = {} + + logging.info("--- Step {} ---".format(global_step)) + logging.info("Total {} pieces for evaluation:".format(len(self.mus.tracks))) + + eval_time = time.time() + + for track in self.mus.tracks: + + audio_name = track.name + + # Get waveform of mixture. + mixture = track.audio.T + # (channels_num, audio_samples) + + mixture = preprocess_audio( + audio=mixture, + mono=self.mono, + origin_sr=track.rate, + sr=self.sample_rate, + resample_type=self.resample_type, + ) + # (channels_num, audio_samples) + + target_dict = {} + sdr_dict[audio_name] = {} + + # Get waveform of all target source types. + for j, source_type in enumerate(self.target_source_types): + # E.g., ['vocals', 'bass', ...] + + audio = track.targets[source_type].audio.T + + audio = preprocess_audio( + audio=audio, + mono=self.mono, + origin_sr=track.rate, + sr=self.sample_rate, + resample_type=self.resample_type, + ) + # (channels_num, audio_samples) + + target_dict[source_type] = audio + # (channels_num, audio_samples) + + # Separate. + input_dict = {'waveform': mixture} + + sep_wavs = self.separator.separate(input_dict) + # sep_wavs: (target_sources_num * channels_num, audio_samples) + + # Post process separation results. + sep_wavs = preprocess_audio( + audio=sep_wavs, + mono=self.mono, + origin_sr=self.sample_rate, + sr=track.rate, + resample_type=self.resample_type, + ) + # sep_wavs: (target_sources_num * channels_num, audio_samples) + + sep_wavs = librosa.util.fix_length( + sep_wavs, size=mixture.shape[1], axis=1 + ) + # sep_wavs: (target_sources_num * channels_num, audio_samples) + + sep_wav_dict = get_separated_wavs_from_simo_output( + sep_wavs, self.input_channels, self.target_source_types + ) + # output_dict: dict, e.g., { + # 'vocals': (channels_num, audio_samples), + # 'bass': (channels_num, audio_samples), + # ..., + # } + + # Evaluate for all target source types. + for source_type in self.target_source_types: + # E.g., ['vocals', 'bass', ...] + + # Calculate SDR using museval, input shape should be: (nsrc, nsampl, nchan). 
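+ # museval.evaluate returns framewise (SDR, ISR, SIR, SAR) arrays of shape (nsrc, nwin); np.nanmedian below + # aggregates the windowed SDRs of each source into a single score while skipping silent (NaN) windows.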
+ (sdrs, _, _, _) = museval.evaluate( + [target_dict[source_type].T], [sep_wav_dict[source_type].T] + ) + + sdr = np.nanmedian(sdrs) + sdr_dict[audio_name][source_type] = sdr + + logging.info( + "{}, {}, sdr: {:.3f}".format(audio_name, source_type, sdr) + ) + + logging.info("-----------------------------") + median_sdr_dict = {} + + # Calculate median SDRs of all songs. + for source_type in self.target_source_types: + # E.g., ['vocals', 'bass', ...] + + median_sdr = np.median( + [ + sdr_dict[audio_name][source_type] + for audio_name in sdr_dict.keys() + ] + ) + + median_sdr_dict[source_type] = median_sdr + + logging.info( + "Step: {}, {}, Median SDR: {:.3f}".format( + global_step, source_type, median_sdr + ) + ) + + logging.info("Evlauation time: {:.3f}".format(time.time() - eval_time)) + + statistics = {"sdr_dict": sdr_dict, "median_sdr_dict": median_sdr_dict} + self.statistics_container.append(global_step, statistics, self.split) + self.statistics_container.dump() + + +def get_separated_wavs_from_simo_output(x, input_channels, target_source_types) -> Dict: + r"""Get separated waveforms of target sources from a single input multiple + output (SIMO) system. + + Args: + x: (target_sources_num * channels_num, audio_samples) + input_channels: int + target_source_types: List[str], e.g., ['vocals', 'bass', ...] + + Returns: + output_dict: dict, e.g., { + 'vocals': (channels_num, audio_samples), + 'bass': (channels_num, audio_samples), + ..., + } + """ + output_dict = {} + + for j, source_type in enumerate(target_source_types): + output_dict[source_type] = x[j * input_channels : (j + 1) * input_channels] + + return output_dict + + +class Musdb18ConditionalEvaluationCallback(pl.Callback): + def __init__( + self, + dataset_dir: str, + model: nn.Module, + target_source_types: str, + input_channels: int, + split: str, + sample_rate: int, + segment_samples: int, + batch_size: int, + device: str, + evaluate_step_frequency: int, + logger: pl.loggers.TensorBoardLogger, + statistics_container: StatisticsContainer, + ): + r"""Callback to evaluate every #save_step_frequency steps. + + Args: + dataset_dir: str + model: nn.Module + target_source_types: List[str], e.g., ['vocals', 'bass', ...] 
+ input_channels: int + split: 'train' | 'test' + sample_rate: int + segment_samples: int, length of segments to be input to a model, e.g., 44100*30 + batch_size, int, e.g., 12 + device: str, e.g., 'cuda' + evaluate_step_frequency: int, evaluate every #save_step_frequency steps + logger: object + statistics_container: StatisticsContainer + """ + self.model = model + self.target_source_types = target_source_types + self.input_channels = input_channels + self.sample_rate = sample_rate + self.split = split + self.segment_samples = segment_samples + self.evaluate_step_frequency = evaluate_step_frequency + self.logger = logger + self.statistics_container = statistics_container + self.mono = input_channels == 1 + self.resample_type = "kaiser_fast" + + self.mus = musdb.DB(root=dataset_dir, subsets=[split]) + + error_msg = "The directory {} is empty!".format(dataset_dir) + assert len(self.mus) > 0, error_msg + + # separator + self.separator = Separator(model, self.segment_samples, batch_size, device) + + @rank_zero_only + def on_batch_end(self, trainer: pl.Trainer, _) -> NoReturn: + r"""Evaluate separation SDRs of audio recordings.""" + global_step = trainer.global_step + + if global_step % self.evaluate_step_frequency == 0: + + sdr_dict = {} + + logging.info("--- Step {} ---".format(global_step)) + logging.info("Total {} pieces for evaluation:".format(len(self.mus.tracks))) + + eval_time = time.time() + + for track in self.mus.tracks: + + audio_name = track.name + + # Get waveform of mixture. + mixture = track.audio.T + # (channels_num, audio_samples) + + mixture = preprocess_audio( + audio=mixture, + mono=self.mono, + origin_sr=track.rate, + sr=self.sample_rate, + resample_type=self.resample_type, + ) + # (channels_num, audio_samples) + + target_dict = {} + sdr_dict[audio_name] = {} + + # Get waveform of all target source types. + for j, source_type in enumerate(self.target_source_types): + # E.g., ['vocals', 'bass', ...] + + audio = track.targets[source_type].audio.T + + audio = preprocess_audio( + audio=audio, + mono=self.mono, + origin_sr=track.rate, + sr=self.sample_rate, + resample_type=self.resample_type, + ) + # (channels_num, audio_samples) + + target_dict[source_type] = audio + # (channels_num, audio_samples) + + condition = np.zeros(len(self.target_source_types)) + condition[j] = 1 + + input_dict = {'waveform': mixture, 'condition': condition} + + sep_wav = self.separator.separate(input_dict) + # sep_wav: (channels_num, audio_samples) + + sep_wav = preprocess_audio( + audio=sep_wav, + mono=self.mono, + origin_sr=self.sample_rate, + sr=track.rate, + resample_type=self.resample_type, + ) + # sep_wav: (channels_num, audio_samples) + + sep_wav = librosa.util.fix_length( + sep_wav, size=mixture.shape[1], axis=1 + ) + # sep_wav: (target_sources_num * channels_num, audio_samples) + + # Calculate SDR using museval, input shape should be: (nsrc, nsampl, nchan) + (sdrs, _, _, _) = museval.evaluate( + [target_dict[source_type].T], [sep_wav.T] + ) + + sdr = np.nanmedian(sdrs) + sdr_dict[audio_name][source_type] = sdr + + logging.info( + "{}, {}, sdr: {:.3f}".format(audio_name, source_type, sdr) + ) + + logging.info("-----------------------------") + median_sdr_dict = {} + + # Calculate median SDRs of all songs. 
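+ # As in the common MUSDB18/SiSEC reporting convention, the per-source score logged below is the median over + # tracks of each track's window-median SDR computed above.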
+ for source_type in self.target_source_types: + + median_sdr = np.median( + [ + sdr_dict[audio_name][source_type] + for audio_name in sdr_dict.keys() + ] + ) + + median_sdr_dict[source_type] = median_sdr + + logging.info( + "Step: {}, {}, Median SDR: {:.3f}".format( + global_step, source_type, median_sdr + ) + ) + + logging.info("Evlauation time: {:.3f}".format(time.time() - eval_time)) + + statistics = {"sdr_dict": sdr_dict, "median_sdr_dict": median_sdr_dict} + self.statistics_container.append(global_step, statistics, self.split) + self.statistics_container.dump() diff --git a/bytesep/callbacks/voicebank_demand.py b/bytesep/callbacks/voicebank_demand.py new file mode 100644 index 0000000000000000000000000000000000000000..7041596cdc9b36585c119b582176e5690c9930e7 --- /dev/null +++ b/bytesep/callbacks/voicebank_demand.py @@ -0,0 +1,231 @@ +import logging +import os +import time +from typing import List, NoReturn + +import librosa +import numpy as np +import pysepm +import pytorch_lightning as pl +import torch.nn as nn +from pesq import pesq +from pytorch_lightning.utilities import rank_zero_only + +from bytesep.callbacks.base_callbacks import SaveCheckpointsCallback +from bytesep.inference import Separator +from bytesep.utils import StatisticsContainer, read_yaml + + +def get_voicebank_demand_callbacks( + config_yaml: str, + workspace: str, + checkpoints_dir: str, + statistics_path: str, + logger: pl.loggers.TensorBoardLogger, + model: nn.Module, + evaluate_device: str, +) -> List[pl.Callback]: + """Get Voicebank-Demand callbacks of a config yaml. + + Args: + config_yaml: str + workspace: str + checkpoints_dir: str, directory to save checkpoints + statistics_dir: str, directory to save statistics + logger: pl.loggers.TensorBoardLogger + model: nn.Module + evaluate_device: str + + Return: + callbacks: List[pl.Callback] + """ + configs = read_yaml(config_yaml) + task_name = configs['task_name'] + target_source_types = configs['train']['target_source_types'] + input_channels = configs['train']['channels'] + evaluation_audios_dir = os.path.join(workspace, "evaluation_audios", task_name) + sample_rate = configs['train']['sample_rate'] + evaluate_step_frequency = configs['train']['evaluate_step_frequency'] + save_step_frequency = configs['train']['save_step_frequency'] + test_batch_size = configs['evaluate']['batch_size'] + test_segment_seconds = configs['evaluate']['segment_seconds'] + + test_segment_samples = int(test_segment_seconds * sample_rate) + assert len(target_source_types) == 1 + target_source_type = target_source_types[0] + assert target_source_type == 'speech' + + # save checkpoint callback + save_checkpoints_callback = SaveCheckpointsCallback( + model=model, + checkpoints_dir=checkpoints_dir, + save_step_frequency=save_step_frequency, + ) + + # statistics container + statistics_container = StatisticsContainer(statistics_path) + + # evaluation callback + evaluate_test_callback = EvaluationCallback( + model=model, + input_channels=input_channels, + sample_rate=sample_rate, + evaluation_audios_dir=evaluation_audios_dir, + segment_samples=test_segment_samples, + batch_size=test_batch_size, + device=evaluate_device, + evaluate_step_frequency=evaluate_step_frequency, + logger=logger, + statistics_container=statistics_container, + ) + + callbacks = [save_checkpoints_callback, evaluate_test_callback] + + return callbacks + + +class EvaluationCallback(pl.Callback): + def __init__( + self, + model: nn.Module, + input_channels: int, + evaluation_audios_dir, + sample_rate: int, + 
segment_samples: int, + batch_size: int, + device: str, + evaluate_step_frequency: int, + logger: pl.loggers.TensorBoardLogger, + statistics_container: StatisticsContainer, + ): + r"""Callback to evaluate every #save_step_frequency steps. + + Args: + model: nn.Module + input_channels: int + evaluation_audios_dir: str, directory containing audios for evaluation + sample_rate: int + segment_samples: int, length of segments to be input to a model, e.g., 44100*30 + batch_size, int, e.g., 12 + device: str, e.g., 'cuda' + evaluate_step_frequency: int, evaluate every #save_step_frequency steps + logger: pl.loggers.TensorBoardLogger + statistics_container: StatisticsContainer + """ + self.model = model + self.mono = True + self.sample_rate = sample_rate + self.segment_samples = segment_samples + self.evaluate_step_frequency = evaluate_step_frequency + self.logger = logger + self.statistics_container = statistics_container + + self.clean_dir = os.path.join(evaluation_audios_dir, "clean_testset_wav") + self.noisy_dir = os.path.join(evaluation_audios_dir, "noisy_testset_wav") + + self.EVALUATION_SAMPLE_RATE = 16000 # Evaluation sample rate of the + # Voicebank-Demand task. + + # separator + self.separator = Separator(model, self.segment_samples, batch_size, device) + + @rank_zero_only + def on_batch_end(self, trainer: pl.Trainer, _) -> NoReturn: + r"""Evaluate losses on a few mini-batches. Losses are only used for + observing training, and are not final F1 metrics. + """ + + global_step = trainer.global_step + + if global_step % self.evaluate_step_frequency == 0: + + audio_names = sorted( + [ + audio_name + for audio_name in sorted(os.listdir(self.clean_dir)) + if audio_name.endswith('.wav') + ] + ) + + error_str = "Directory {} does not contain audios for evaluation!".format( + self.clean_dir + ) + assert len(audio_names) > 0, error_str + + pesqs, csigs, cbaks, covls, ssnrs = [], [], [], [], [] + + logging.info("--- Step {} ---".format(global_step)) + logging.info("Total {} pieces for evaluation:".format(len(audio_names))) + + eval_time = time.time() + + for n, audio_name in enumerate(audio_names): + + # Load audio. + clean_path = os.path.join(self.clean_dir, audio_name) + mixture_path = os.path.join(self.noisy_dir, audio_name) + + mixture, _ = librosa.core.load( + mixture_path, sr=self.sample_rate, mono=self.mono + ) + + if mixture.ndim == 1: + mixture = mixture[None, :] + # (channels_num, audio_length) + + # Separate. + input_dict = {'waveform': mixture} + + sep_wav = self.separator.separate(input_dict) + # (channels_num, audio_length) + + # Target + clean, _ = librosa.core.load( + clean_path, sr=self.EVALUATION_SAMPLE_RATE, mono=self.mono + ) + + # to mono + sep_wav = np.squeeze(sep_wav) + + # Resample for evaluation. 
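+ # Wideband PESQ is defined at 16 kHz, so the separated waveform is resampled from the training sample rate + # to EVALUATION_SAMPLE_RATE before computing PESQ and the pysepm metrics.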
+ sep_wav = librosa.resample( + sep_wav, + orig_sr=self.sample_rate, + target_sr=self.EVALUATION_SAMPLE_RATE, + ) + + sep_wav = librosa.util.fix_length(sep_wav, size=len(clean), axis=0) + # (channels, audio_length) + + # Evaluate metrics + pesq_ = pesq(self.EVALUATION_SAMPLE_RATE, clean, sep_wav, 'wb') + + (csig, cbak, covl) = pysepm.composite( + clean, sep_wav, self.EVALUATION_SAMPLE_RATE + ) + + ssnr = pysepm.SNRseg(clean, sep_wav, self.EVALUATION_SAMPLE_RATE) + + pesqs.append(pesq_) + csigs.append(csig) + cbaks.append(cbak) + covls.append(covl) + ssnrs.append(ssnr) + print( + '{}, {}, PESQ: {:.3f}, CSIG: {:.3f}, CBAK: {:.3f}, COVL: {:.3f}, SSNR: {:.3f}'.format( + n, audio_name, pesq_, csig, cbak, covl, ssnr + ) + ) + + logging.info("-----------------------------") + logging.info('Avg PESQ: {:.3f}'.format(np.mean(pesqs))) + logging.info('Avg CSIG: {:.3f}'.format(np.mean(csigs))) + logging.info('Avg CBAK: {:.3f}'.format(np.mean(cbaks))) + logging.info('Avg COVL: {:.3f}'.format(np.mean(covls))) + logging.info('Avg SSNR: {:.3f}'.format(np.mean(ssnrs))) + + logging.info("Evlauation time: {:.3f}".format(time.time() - eval_time)) + + statistics = {"pesq": np.mean(pesqs)} + self.statistics_container.append(global_step, statistics, 'test') + self.statistics_container.dump() diff --git a/bytesep/data/__init__.py b/bytesep/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/bytesep/data/augmentors.py b/bytesep/data/augmentors.py new file mode 100644 index 0000000000000000000000000000000000000000..d4d40731d4a8f8123a29351b1f97c21f47cb6dc4 --- /dev/null +++ b/bytesep/data/augmentors.py @@ -0,0 +1,157 @@ +from typing import Dict + +import librosa +import numpy as np + +from bytesep.utils import db_to_magnitude, get_pitch_shift_factor, magnitude_to_db + + +class Augmentor: + def __init__(self, augmentations: Dict, random_seed=1234): + r"""Augmentor for data augmentation of a waveform. + + Args: + augmentations: Dict, e.g, { + 'mixaudio': {'vocals': 2, 'accompaniment': 2} + 'pitch_shift': {'vocals': 4, 'accompaniment': 4}, + ..., + } + random_seed: int + """ + self.augmentations = augmentations + self.random_state = np.random.RandomState(random_seed) + + def __call__(self, waveform: np.array, source_type: str) -> np.array: + r"""Augment a waveform. + + Args: + waveform: (channels_num, audio_samples) + source_type: str + + Returns: + new_waveform: (channels_num, new_audio_samples) + """ + if 'pitch_shift' in self.augmentations.keys(): + waveform = self.pitch_shift(waveform, source_type) + + if 'magnitude_scale' in self.augmentations.keys(): + waveform = self.magnitude_scale(waveform, source_type) + + if 'swap_channel' in self.augmentations.keys(): + waveform = self.swap_channel(waveform, source_type) + + if 'flip_axis' in self.augmentations.keys(): + waveform = self.flip_axis(waveform, source_type) + + return waveform + + def pitch_shift(self, waveform: np.array, source_type: str) -> np.array: + r"""Shift the pitch of a waveform. We use resampling for fast pitch + shifting, so the speed will also be chaneged. The length of the returned + waveform will be changed. + + Args: + waveform: (channels_num, audio_samples) + source_type: str + + Returns: + new_waveform: (channels_num, new_audio_samples) + """ + + # maximum pitch shift in semitones + max_pitch_shift = self.augmentations['pitch_shift'][source_type] + + if max_pitch_shift == 0: # No pitch shift augmentations. 
+ return waveform + + # random pitch shift + rand_pitch = self.random_state.uniform( + low=-max_pitch_shift, high=max_pitch_shift + ) + + # We use librosa.resample instead of librosa.effects.pitch_shift + # because it is 10x times faster. + pitch_shift_factor = get_pitch_shift_factor(rand_pitch) + dummy_sample_rate = 10000 # Dummy constant. + + channels_num = waveform.shape[0] + + if channels_num == 1: + waveform = np.squeeze(waveform) + + new_waveform = librosa.resample( + y=waveform, + orig_sr=dummy_sample_rate, + target_sr=dummy_sample_rate / pitch_shift_factor, + res_type='linear', + axis=-1, + ) + + if channels_num == 1: + new_waveform = new_waveform[None, :] + + return new_waveform + + def magnitude_scale(self, waveform: np.array, source_type: str) -> np.array: + r"""Scale the magnitude of a waveform. + + Args: + waveform: (channels_num, audio_samples) + source_type: str + + Returns: + new_waveform: (channels_num, audio_samples) + """ + lower_db = self.augmentations['magnitude_scale'][source_type]['lower_db'] + higher_db = self.augmentations['magnitude_scale'][source_type]['higher_db'] + + if lower_db == 0 and higher_db == 0: # No magnitude scale augmentation. + return waveform + + # The magnitude (in dB) of the sample with the maximum value. + waveform_db = magnitude_to_db(np.max(np.abs(waveform))) + + new_waveform_db = self.random_state.uniform( + waveform_db + lower_db, min(waveform_db + higher_db, 0) + ) + + relative_db = new_waveform_db - waveform_db + + relative_scale = db_to_magnitude(relative_db) + + new_waveform = waveform * relative_scale + + return new_waveform + + def swap_channel(self, waveform: np.array, source_type: str) -> np.array: + r"""Randomly swap channels. + + Args: + waveform: (channels_num, audio_samples) + source_type: str + + Returns: + new_waveform: (channels_num, audio_samples) + """ + ndim = waveform.shape[0] + + if ndim == 1: + return waveform + else: + random_axes = self.random_state.permutation(ndim) + return waveform[random_axes, :] + + def flip_axis(self, waveform: np.array, source_type: str) -> np.array: + r"""Randomly flip the waveform along x-axis. + + Args: + waveform: (channels_num, audio_samples) + source_type: str + + Returns: + new_waveform: (channels_num, audio_samples) + """ + ndim = waveform.shape[0] + random_values = self.random_state.choice([-1, 1], size=ndim) + + return waveform * random_values[:, None] diff --git a/bytesep/data/batch_data_preprocessors.py b/bytesep/data/batch_data_preprocessors.py new file mode 100644 index 0000000000000000000000000000000000000000..6fafa5ee6a999be3fb9ed467ef11d1021323cc79 --- /dev/null +++ b/bytesep/data/batch_data_preprocessors.py @@ -0,0 +1,141 @@ +from typing import Dict, List + +import torch + + +class BasicBatchDataPreprocessor: + def __init__(self, target_source_types: List[str]): + r"""Batch data preprocessor. Used for preparing mixtures and targets for + training. If there are multiple target source types, the waveforms of + those sources will be stacked along the channel dimension. + + Args: + target_source_types: List[str], e.g., ['vocals', 'bass', ...] + """ + self.target_source_types = target_source_types + + def __call__(self, batch_data_dict: Dict) -> List[Dict]: + r"""Format waveforms and targets for training. 
+ + Args: + batch_data_dict: dict, e.g., { + 'mixture': (batch_size, channels_num, segment_samples), + 'vocals': (batch_size, channels_num, segment_samples), + 'bass': (batch_size, channels_num, segment_samples), + ..., + } + + Returns: + input_dict: dict, e.g., { + 'waveform': (batch_size, channels_num, segment_samples), + } + output_dict: dict, e.g., { + 'target': (batch_size, target_sources_num * channels_num, segment_samples) + } + """ + mixtures = batch_data_dict['mixture'] + # mixtures: (batch_size, channels_num, segment_samples) + + # Concatenate waveforms of multiple targets along the channel axis. + targets = torch.cat( + [batch_data_dict[source_type] for source_type in self.target_source_types], + dim=1, + ) + # targets: (batch_size, target_sources_num * channels_num, segment_samples) + + input_dict = {'waveform': mixtures} + target_dict = {'waveform': targets} + + return input_dict, target_dict + + +class ConditionalSisoBatchDataPreprocessor: + def __init__(self, target_source_types: List[str]): + r"""Conditional single input single output (SISO) batch data + preprocessor. Select one target source from several target sources as + training target and prepare the corresponding conditional vector. + + Args: + target_source_types: List[str], e.g., ['vocals', 'bass', ...] + """ + self.target_source_types = target_source_types + + def __call__(self, batch_data_dict: Dict) -> List[Dict]: + r"""Format waveforms and targets for training. + + Args: + batch_data_dict: dict, e.g., { + 'mixture': (batch_size, channels_num, segment_samples), + 'vocals': (batch_size, channels_num, segment_samples), + 'bass': (batch_size, channels_num, segment_samples), + ..., + } + + Returns: + input_dict: dict, e.g., { + 'waveform': (batch_size, channels_num, segment_samples), + 'condition': (batch_size, target_sources_num), + } + output_dict: dict, e.g., { + 'target': (batch_size, channels_num, segment_samples) + } + """ + + batch_size = len(batch_data_dict['mixture']) + target_sources_num = len(self.target_source_types) + + assert ( + batch_size % target_sources_num == 0 + ), "Batch size should be \ + evenly divided by target sources number." 
+ + mixtures = batch_data_dict['mixture'] + # mixtures: (batch_size, channels_num, segment_samples) + + conditions = torch.zeros(batch_size, target_sources_num).to(mixtures.device) + # conditions: (batch_size, target_sources_num) + + targets = [] + + for n in range(batch_size): + + k = n % target_sources_num # source class index + source_type = self.target_source_types[k] + + targets.append(batch_data_dict[source_type][n]) + + conditions[n, k] = 1 + + # conditions will looks like: + # [[1, 0, 0, 0], + # [0, 1, 0, 0], + # [0, 0, 1, 0], + # [0, 0, 0, 1], + # [1, 0, 0, 0], + # [0, 1, 0, 0], + # ..., + # ] + + targets = torch.stack(targets, dim=0) + # targets: (batch_size, channels_num, segment_samples) + + input_dict = { + 'waveform': mixtures, + 'condition': conditions, + } + + target_dict = {'waveform': targets} + + return input_dict, target_dict + + +def get_batch_data_preprocessor_class(batch_data_preprocessor_type: str) -> object: + r"""Get batch data preprocessor class.""" + if batch_data_preprocessor_type == 'BasicBatchDataPreprocessor': + return BasicBatchDataPreprocessor + + elif batch_data_preprocessor_type == 'ConditionalSisoBatchDataPreprocessor': + return ConditionalSisoBatchDataPreprocessor + + else: + raise NotImplementedError diff --git a/bytesep/data/data_modules.py b/bytesep/data/data_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..e37b4109f8b915ea864b19795374038184388308 --- /dev/null +++ b/bytesep/data/data_modules.py @@ -0,0 +1,187 @@ +from typing import Dict, List, NoReturn, Optional + +import h5py +import librosa +import numpy as np +import torch +from pytorch_lightning.core.datamodule import LightningDataModule + +from bytesep.data.samplers import DistributedSamplerWrapper +from bytesep.utils import int16_to_float32 + + +class DataModule(LightningDataModule): + def __init__( + self, + train_sampler: object, + train_dataset: object, + num_workers: int, + distributed: bool, + ): + r"""Data module. + + Args: + train_sampler: Sampler object + train_dataset: Dataset object + num_workers: int + distributed: bool + """ + super().__init__() + self._train_sampler = train_sampler + self.train_dataset = train_dataset + self.num_workers = num_workers + self.distributed = distributed + + def setup(self, stage: Optional[str] = None) -> NoReturn: + r"""called on every device.""" + + # SegmentSampler is used for selecting segments for training. + # On multiple devices, each SegmentSampler samples a part of mini-batch + # data. + if self.distributed: + self.train_sampler = DistributedSamplerWrapper(self._train_sampler) + + else: + self.train_sampler = self._train_sampler + + def train_dataloader(self) -> torch.utils.data.DataLoader: + r"""Get train loader.""" + train_loader = torch.utils.data.DataLoader( + dataset=self.train_dataset, + batch_sampler=self.train_sampler, + collate_fn=collate_fn, + num_workers=self.num_workers, + pin_memory=True, + ) + + return train_loader + + +class Dataset: + def __init__(self, augmentor: object, segment_samples: int): + r"""Used for getting data according to a meta. + + Args: + augmentor: Augmentor class + segment_samples: int + """ + self.augmentor = augmentor + self.segment_samples = segment_samples + + def __getitem__(self, meta: Dict) -> Dict: + r"""Return data according to a meta. E.g., an input meta looks like: { + 'vocals': [['song_A.h5', 6332760, 6465060], ['song_B.h5', 198450, 330750]], + 'accompaniment': [['song_C.h5', 24232920, 24365250], ['song_D.h5', 1569960, 1702260]]}. 
+ } + + Then, vocals segments of song_A and song_B will be mixed (mix-audio augmentation). + Accompaniment segments of song_C and song_B will be mixed (mix-audio augmentation). + Finally, mixture is created by summing vocals and accompaniment. + + Args: + meta: dict, e.g., { + 'vocals': [['song_A.h5', 6332760, 6465060], ['song_B.h5', 198450, 330750]], + 'accompaniment': [['song_C.h5', 24232920, 24365250], ['song_D.h5', 1569960, 1702260]]} + } + + Returns: + data_dict: dict, e.g., { + 'vocals': (channels, segments_num), + 'accompaniment': (channels, segments_num), + 'mixture': (channels, segments_num), + } + """ + source_types = meta.keys() + data_dict = {} + + for source_type in source_types: + # E.g., ['vocals', 'bass', ...] + + waveforms = [] # Audio segments to be mix-audio augmented. + + for m in meta[source_type]: + # E.g., { + # 'hdf5_path': '.../song_A.h5', + # 'key_in_hdf5': 'vocals', + # 'begin_sample': '13406400', + # 'end_sample': 13538700, + # } + + hdf5_path = m['hdf5_path'] + key_in_hdf5 = m['key_in_hdf5'] + bgn_sample = m['begin_sample'] + end_sample = m['end_sample'] + + with h5py.File(hdf5_path, 'r') as hf: + + if source_type == 'audioset': + index_in_hdf5 = m['index_in_hdf5'] + waveform = int16_to_float32( + hf['waveform'][index_in_hdf5][bgn_sample:end_sample] + ) + waveform = waveform[None, :] + else: + waveform = int16_to_float32( + hf[key_in_hdf5][:, bgn_sample:end_sample] + ) + + if self.augmentor: + waveform = self.augmentor(waveform, source_type) + + waveform = librosa.util.fix_length( + waveform, size=self.segment_samples, axis=1 + ) + # (channels_num, segments_num) + + waveforms.append(waveform) + # E.g., waveforms: [(channels_num, audio_samples), (channels_num, audio_samples)] + + # mix-audio augmentation + data_dict[source_type] = np.sum(waveforms, axis=0) + # data_dict[source_type]: (channels_num, audio_samples) + + # data_dict looks like: { + # 'voclas': (channels_num, audio_samples), + # 'accompaniment': (channels_num, audio_samples) + # } + + # Mix segments from different sources. + mixture = np.sum( + [data_dict[source_type] for source_type in source_types], axis=0 + ) + data_dict['mixture'] = mixture + # shape: (channels_num, audio_samples) + + return data_dict + + +def collate_fn(list_data_dict: List[Dict]) -> Dict: + r"""Collate mini-batch data to inputs and targets for training. + + Args: + list_data_dict: e.g., [ + {'vocals': (channels_num, segment_samples), + 'accompaniment': (channels_num, segment_samples), + 'mixture': (channels_num, segment_samples) + }, + {'vocals': (channels_num, segment_samples), + 'accompaniment': (channels_num, segment_samples), + 'mixture': (channels_num, segment_samples) + }, + ...] + + Returns: + data_dict: e.g. 
{ + 'vocals': (batch_size, channels_num, segment_samples), + 'accompaniment': (batch_size, channels_num, segment_samples), + 'mixture': (batch_size, channels_num, segment_samples) + } + """ + data_dict = {} + + for key in list_data_dict[0].keys(): + data_dict[key] = torch.Tensor( + np.array([data_dict[key] for data_dict in list_data_dict]) + ) + + return data_dict diff --git a/bytesep/data/samplers.py b/bytesep/data/samplers.py new file mode 100644 index 0000000000000000000000000000000000000000..6b3cf99ecdf7f5e392da7b0cd2cc88ee4b8c90b5 --- /dev/null +++ b/bytesep/data/samplers.py @@ -0,0 +1,188 @@ +import pickle +from typing import Dict, List, NoReturn + +import numpy as np +import torch.distributed as dist + + +class SegmentSampler: + def __init__( + self, + indexes_path: str, + segment_samples: int, + mixaudio_dict: Dict, + batch_size: int, + steps_per_epoch: int, + random_seed=1234, + ): + r"""Sample training indexes of sources. + + Args: + indexes_path: str, path of indexes dict + segment_samplers: int + mixaudio_dict, dict, including hyper-parameters for mix-audio data + augmentation, e.g., {'voclas': 2, 'accompaniment': 2} + batch_size: int + steps_per_epoch: int, #steps_per_epoch is called an `epoch` + random_seed: int + """ + self.segment_samples = segment_samples + self.mixaudio_dict = mixaudio_dict + self.batch_size = batch_size + self.steps_per_epoch = steps_per_epoch + + self.meta_dict = pickle.load(open(indexes_path, "rb")) + # E.g., { + # 'vocals': [ + # {'hdf5_path': 'songA.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 0, 'end_sample': 132300}, + # {'hdf5_path': 'songB.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 4410, 'end_sample': 445410}, + # ... + # ], + # 'accompaniment': [ + # {'hdf5_path': 'songA.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 0, 'end_sample': 132300}, + # {'hdf5_path': 'songB.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 4410, 'end_sample': 445410}, + # ... + # ] + # } + + self.source_types = self.meta_dict.keys() + # E.g., ['vocals', 'accompaniment'] + + self.pointers_dict = {source_type: 0 for source_type in self.source_types} + # E.g., {'vocals': 0, 'accompaniment': 0} + + self.indexes_dict = { + source_type: np.arange(len(self.meta_dict[source_type])) + for source_type in self.source_types + } + # E.g. { + # 'vocals': [0, 1, ..., 225751], + # 'accompaniment': [0, 1, ..., 225751] + # } + + self.random_state = np.random.RandomState(random_seed) + + # Shuffle indexes. + for source_type in self.source_types: + self.random_state.shuffle(self.indexes_dict[source_type]) + print("{}: {}".format(source_type, len(self.indexes_dict[source_type]))) + + def __iter__(self) -> List[Dict]: + r"""Yield a batch of meta info. + + Returns: + batch_meta_list: (batch_size,) e.g., when mix-audio is 2, looks like [ + {'vocals': [ + {'hdf5_path': 'songA.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 13406400, 'end_sample': 13538700}, + {'hdf5_path': 'songB.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 4440870, 'end_sample': 4573170}] + 'accompaniment': [ + {'hdf5_path': 'songE.h5', 'key_in_hdf5': 'accompaniment', 'begin_sample': 14579460, 'end_sample': 14711760}, + {'hdf5_path': 'songF.h5', 'key_in_hdf5': 'accompaniment', 'begin_sample': 3995460, 'end_sample': 4127760}] + } + ... + ] + """ + batch_size = self.batch_size + + while True: + batch_meta_dict = {source_type: [] for source_type in self.source_types} + + for source_type in self.source_types: + # E.g., ['vocals', 'accompaniment'] + + # Loop until get a mini-batch. 
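+ # Each batch entry gathers mixaudio_dict[source_type] segment metas; the corresponding segments are summed + # later in the Dataset (mix-audio augmentation). The pointer walks the shuffled index list and the indexes + # are re-shuffled whenever the list is exhausted.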
+ while len(batch_meta_dict[source_type]) != batch_size: + + largest_index = ( + len(self.indexes_dict[source_type]) + - self.mixaudio_dict[source_type] + ) + # E.g., 225750 = 225752 - 2 + + if self.pointers_dict[source_type] > largest_index: + + # Reset pointer, and shuffle indexes. + self.pointers_dict[source_type] = 0 + self.random_state.shuffle(self.indexes_dict[source_type]) + + source_metas = [] + mix_audios_num = self.mixaudio_dict[source_type] + + for _ in range(mix_audios_num): + + pointer = self.pointers_dict[source_type] + # E.g., 1 + + index = self.indexes_dict[source_type][pointer] + # E.g., 12231 + + self.pointers_dict[source_type] += 1 + + source_meta = self.meta_dict[source_type][index] + # E.g., ['song_A.h5', 198450, 330750] + + # source_metas.append(new_source_meta) + source_metas.append(source_meta) + + batch_meta_dict[source_type].append(source_metas) + # When mix-audio is 2, batch_meta_dict looks like: { + # 'vocals': [ + # [{'hdf5_path': 'songA.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 13406400, 'end_sample': 13538700}, + # {'hdf5_path': 'songB.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 4440870, 'end_sample': 4573170}], + # [{'hdf5_path': 'songC.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 1186290, 'end_sample': 1318590}, + # {'hdf5_path': 'songD.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 8462790, 'end_sample': 8595090}] + # ] + # 'accompaniment': [ + # [{'hdf5_path': 'songE.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 24232950, 'end_sample': 24365250}, + # {'hdf5_path': 'songF.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 1569960, 'end_sample': 1702260}], + # [{'hdf5_path': 'songG.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 2795940, 'end_sample': 2928240}, + # {'hdf5_path': 'songH.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 10923570, 'end_sample': 11055870}] + # ] + # } + + batch_meta_list = [ + { + source_type: batch_meta_dict[source_type][i] + for source_type in self.source_types + } + for i in range(batch_size) + ] + # When mix-audio is 2, batch_meta_list looks like: [ + # {'vocals': [ + # {'hdf5_path': 'songA.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 13406400, 'end_sample': 13538700}, + # {'hdf5_path': 'songB.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 4440870, 'end_sample': 4573170}] + # 'accompaniment': [ + # {'hdf5_path': 'songE.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 14579460, 'end_sample': 14711760}, + # {'hdf5_path': 'songF.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 3995460, 'end_sample': 4127760}] + # } + # ... 
+ # ] + + yield batch_meta_list + + def __len__(self) -> int: + return self.steps_per_epoch + + def state_dict(self) -> Dict: + state = {'pointers_dict': self.pointers_dict, 'indexes_dict': self.indexes_dict} + return state + + def load_state_dict(self, state) -> NoReturn: + self.pointers_dict = state['pointers_dict'] + self.indexes_dict = state['indexes_dict'] + + +class DistributedSamplerWrapper: + def __init__(self, sampler): + r"""Distributed wrapper of sampler.""" + self.sampler = sampler + + def __iter__(self): + num_replicas = dist.get_world_size() + rank = dist.get_rank() + + for indices in self.sampler: + yield indices[rank::num_replicas] + + def __len__(self) -> int: + return len(self.sampler) diff --git a/bytesep/dataset_creation/__init__.py b/bytesep/dataset_creation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/bytesep/dataset_creation/create_evaluation_audios/__init__.py b/bytesep/dataset_creation/create_evaluation_audios/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/bytesep/dataset_creation/create_evaluation_audios/piano-symphony.py b/bytesep/dataset_creation/create_evaluation_audios/piano-symphony.py new file mode 100644 index 0000000000000000000000000000000000000000..1b632e58765aa2a3e1eeadc4c98183919b3bf247 --- /dev/null +++ b/bytesep/dataset_creation/create_evaluation_audios/piano-symphony.py @@ -0,0 +1,160 @@ +import argparse +import os +from typing import NoReturn + +import librosa +import numpy as np +import soundfile + +from bytesep.dataset_creation.pack_audios_to_hdf5s.instruments_solo import ( + read_csv as read_instruments_solo_csv, +) +from bytesep.dataset_creation.pack_audios_to_hdf5s.maestro import ( + read_csv as read_maestro_csv, +) +from bytesep.utils import load_random_segment + + +def create_evaluation(args) -> NoReturn: + r"""Random mix and write out audios for evaluation. + + Args: + piano_dataset_dir: str, the directory of the piano dataset + symphony_dataset_dir: str, the directory of the symphony dataset + evaluation_audios_dir: str, the directory to write out randomly selected and mixed audio segments + sample_rate: int + channels: int, e.g., 1 | 2 + evaluation_segments_num: int + mono: bool + + Returns: + NoReturn + """ + + # arguments & parameters + piano_dataset_dir = args.piano_dataset_dir + symphony_dataset_dir = args.symphony_dataset_dir + evaluation_audios_dir = args.evaluation_audios_dir + sample_rate = args.sample_rate + channels = args.channels + evaluation_segments_num = args.evaluation_segments_num + mono = True if channels == 1 else False + + split = 'test' + segment_seconds = 10.0 + + random_state = np.random.RandomState(1234) + + piano_meta_csv = os.path.join(piano_dataset_dir, 'maestro-v2.0.0.csv') + piano_names_dict = read_maestro_csv(piano_meta_csv) + piano_audio_names = piano_names_dict[split] + + symphony_meta_csv = os.path.join(symphony_dataset_dir, 'validation.csv') + symphony_names_dict = read_instruments_solo_csv(symphony_meta_csv) + symphony_audio_names = symphony_names_dict[split] + + for source_type in ['piano', 'symphony', 'mixture']: + output_dir = os.path.join(evaluation_audios_dir, split, source_type) + os.makedirs(output_dir, exist_ok=True) + + for n in range(evaluation_segments_num): + + print('{} / {}'.format(n, evaluation_segments_num)) + + # Randomly select and write out a clean piano segment. 
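+ # load_random_segment (from bytesep.utils) is expected to return a (channels_num, segment_samples) array cut + # at a random offset of the chosen recording; the fixed random_state keeps the evaluation set reproducible.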
+ piano_audio_name = random_state.choice(piano_audio_names) + piano_audio_path = os.path.join(piano_dataset_dir, piano_audio_name) + + piano_audio = load_random_segment( + audio_path=piano_audio_path, + random_state=random_state, + segment_seconds=segment_seconds, + mono=mono, + sample_rate=sample_rate, + ) + + output_piano_path = os.path.join( + evaluation_audios_dir, split, 'piano', '{:04d}.wav'.format(n) + ) + soundfile.write( + file=output_piano_path, data=piano_audio.T, samplerate=sample_rate + ) + print("Write out to {}".format(output_piano_path)) + + # Randomly select and write out a clean symphony segment. + symphony_audio_name = random_state.choice(symphony_audio_names) + symphony_audio_path = os.path.join( + symphony_dataset_dir, "mp3s", symphony_audio_name + ) + + symphony_audio = load_random_segment( + audio_path=symphony_audio_path, + random_state=random_state, + segment_seconds=segment_seconds, + mono=mono, + sample_rate=sample_rate, + ) + + output_symphony_path = os.path.join( + evaluation_audios_dir, split, 'symphony', '{:04d}.wav'.format(n) + ) + soundfile.write( + file=output_symphony_path, data=symphony_audio.T, samplerate=sample_rate + ) + print("Write out to {}".format(output_symphony_path)) + + # Mix piano and symphony segments and write out a mixture segment. + mixture_audio = symphony_audio + piano_audio + output_mixture_path = os.path.join( + evaluation_audios_dir, split, 'mixture', '{:04d}.wav'.format(n) + ) + soundfile.write( + file=output_mixture_path, data=mixture_audio.T, samplerate=sample_rate + ) + print("Write out to {}".format(output_mixture_path)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--piano_dataset_dir", + type=str, + required=True, + help="The directory of the piano dataset.", + ) + parser.add_argument( + "--symphony_dataset_dir", + type=str, + required=True, + help="The directory of the symphony dataset.", + ) + parser.add_argument( + "--evaluation_audios_dir", + type=str, + required=True, + help="The directory to write out randomly selected and mixed audio segments.", + ) + parser.add_argument( + "--sample_rate", + type=int, + required=True, + help="Sample rate.", + ) + parser.add_argument( + "--channels", + type=int, + required=True, + help="Audio channels, e.g, 1 or 2.", + ) + parser.add_argument( + "--evaluation_segments_num", + type=int, + required=True, + help="The number of segments to create for evaluation.", + ) + + # Parse arguments. + args = parser.parse_args() + + create_evaluation(args) diff --git a/bytesep/dataset_creation/create_evaluation_audios/vctk-musdb18.py b/bytesep/dataset_creation/create_evaluation_audios/vctk-musdb18.py new file mode 100644 index 0000000000000000000000000000000000000000..8e337feaa304f09b21fc400dfffd9c77a9961074 --- /dev/null +++ b/bytesep/dataset_creation/create_evaluation_audios/vctk-musdb18.py @@ -0,0 +1,164 @@ +import argparse +import os +import soundfile +from typing import NoReturn + +import musdb +import numpy as np + +from bytesep.utils import load_audio + + +def create_evaluation(args) -> NoReturn: + r"""Random mix and write out audios for evaluation. 
+ + Args: + vctk_dataset_dir: str, the directory of the VCTK dataset + symphony_dataset_dir: str, the directory of the symphony dataset + evaluation_audios_dir: str, the directory to write out randomly selected and mixed audio segments + sample_rate: int + channels: int, e.g., 1 | 2 + evaluation_segments_num: int + mono: bool + + Returns: + NoReturn + """ + + # arguments & parameters + vctk_dataset_dir = args.vctk_dataset_dir + musdb18_dataset_dir = args.musdb18_dataset_dir + evaluation_audios_dir = args.evaluation_audios_dir + sample_rate = args.sample_rate + channels = args.channels + evaluation_segments_num = args.evaluation_segments_num + mono = True if channels == 1 else False + + split = 'test' + random_state = np.random.RandomState(1234) + + # paths + audios_dir = os.path.join(vctk_dataset_dir, "wav48", split) + + for source_type in ['speech', 'music', 'mixture']: + output_dir = os.path.join(evaluation_audios_dir, split, source_type) + os.makedirs(output_dir, exist_ok=True) + + # Get VCTK audio paths. + speech_audio_paths = [] + speaker_ids = sorted(os.listdir(audios_dir)) + + for speaker_id in speaker_ids: + speaker_audios_dir = os.path.join(audios_dir, speaker_id) + + audio_names = sorted(os.listdir(speaker_audios_dir)) + + for audio_name in audio_names: + speaker_audio_path = os.path.join(speaker_audios_dir, audio_name) + speech_audio_paths.append(speaker_audio_path) + + # Get Musdb18 audio paths. + mus = musdb.DB(root=musdb18_dataset_dir, subsets=[split]) + track_indexes = np.arange(len(mus.tracks)) + + for n in range(evaluation_segments_num): + + print('{} / {}'.format(n, evaluation_segments_num)) + + # Randomly select and write out a clean speech segment. + speech_audio_path = random_state.choice(speech_audio_paths) + + speech_audio = load_audio( + audio_path=speech_audio_path, mono=mono, sample_rate=sample_rate + ) + # (channels_num, audio_samples) + + if channels == 2: + speech_audio = np.tile(speech_audio, (2, 1)) + # (channels_num, audio_samples) + + output_speech_path = os.path.join( + evaluation_audios_dir, split, 'speech', '{:04d}.wav'.format(n) + ) + soundfile.write( + file=output_speech_path, data=speech_audio.T, samplerate=sample_rate + ) + print("Write out to {}".format(output_speech_path)) + + # Randomly select and write out a clean music segment. + track_index = random_state.choice(track_indexes) + track = mus[track_index] + + segment_samples = speech_audio.shape[1] + start_sample = int( + random_state.uniform(0.0, segment_samples - speech_audio.shape[1]) + ) + + music_audio = track.audio[start_sample : start_sample + segment_samples, :].T + # (channels_num, audio_samples) + + output_music_path = os.path.join( + evaluation_audios_dir, split, 'music', '{:04d}.wav'.format(n) + ) + soundfile.write( + file=output_music_path, data=music_audio.T, samplerate=sample_rate + ) + print("Write out to {}".format(output_music_path)) + + # Mix speech and music segments and write out a mixture segment. 
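+        # Both sources have shape (channels_num, segment_samples), so the
+        # mixture is a plain element-wise sum with no gain adjustment.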
+ mixture_audio = speech_audio + music_audio + # (channels_num, audio_samples) + + output_mixture_path = os.path.join( + evaluation_audios_dir, split, 'mixture', '{:04d}.wav'.format(n) + ) + soundfile.write( + file=output_mixture_path, data=mixture_audio.T, samplerate=sample_rate + ) + print("Write out to {}".format(output_mixture_path)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--vctk_dataset_dir", + type=str, + required=True, + help="The directory of the VCTK dataset.", + ) + parser.add_argument( + "--musdb18_dataset_dir", + type=str, + required=True, + help="The directory of the MUSDB18 dataset.", + ) + parser.add_argument( + "--evaluation_audios_dir", + type=str, + required=True, + help="The directory to write out randomly selected and mixed audio segments.", + ) + parser.add_argument( + "--sample_rate", + type=int, + required=True, + help="Sample rate", + ) + parser.add_argument( + "--channels", + type=int, + required=True, + help="Audio channels, e.g, 1 or 2.", + ) + parser.add_argument( + "--evaluation_segments_num", + type=int, + required=True, + help="The number of segments to create for evaluation.", + ) + + # Parse arguments. + args = parser.parse_args() + + create_evaluation(args) diff --git a/bytesep/dataset_creation/create_evaluation_audios/violin-piano.py b/bytesep/dataset_creation/create_evaluation_audios/violin-piano.py new file mode 100644 index 0000000000000000000000000000000000000000..da36f43553f507c1b980fff826d443cdec113aa6 --- /dev/null +++ b/bytesep/dataset_creation/create_evaluation_audios/violin-piano.py @@ -0,0 +1,162 @@ +import argparse +import os +from typing import NoReturn + +import librosa +import numpy as np +import soundfile + +from bytesep.dataset_creation.pack_audios_to_hdf5s.instruments_solo import ( + read_csv as read_instruments_solo_csv, +) +from bytesep.dataset_creation.pack_audios_to_hdf5s.maestro import ( + read_csv as read_maestro_csv, +) +from bytesep.utils import load_random_segment + + +def create_evaluation(args) -> NoReturn: + r"""Random mix and write out audios for evaluation. 
+ + Args: + violin_dataset_dir: str, the directory of the violin dataset + piano_dataset_dir: str, the directory of the piano dataset + evaluation_audios_dir: str, the directory to write out randomly selected and mixed audio segments + sample_rate: int + channels: int, e.g., 1 | 2 + evaluation_segments_num: int + mono: bool + + Returns: + NoReturn + """ + + # arguments & parameters + violin_dataset_dir = args.violin_dataset_dir + piano_dataset_dir = args.piano_dataset_dir + evaluation_audios_dir = args.evaluation_audios_dir + sample_rate = args.sample_rate + channels = args.channels + evaluation_segments_num = args.evaluation_segments_num + mono = True if channels == 1 else False + + split = 'test' + segment_seconds = 10.0 + + random_state = np.random.RandomState(1234) + + violin_meta_csv = os.path.join(violin_dataset_dir, 'validation.csv') + violin_names_dict = read_instruments_solo_csv(violin_meta_csv) + violin_audio_names = violin_names_dict['{}'.format(split)] + + piano_meta_csv = os.path.join(piano_dataset_dir, 'maestro-v2.0.0.csv') + piano_names_dict = read_maestro_csv(piano_meta_csv) + piano_audio_names = piano_names_dict['{}'.format(split)] + + for source_type in ['violin', 'piano', 'mixture']: + output_dir = os.path.join(evaluation_audios_dir, split, source_type) + os.makedirs(output_dir, exist_ok=True) + + for n in range(evaluation_segments_num): + + print('{} / {}'.format(n, evaluation_segments_num)) + + # Randomly select and write out a clean violin segment. + violin_audio_name = random_state.choice(violin_audio_names) + violin_audio_path = os.path.join(violin_dataset_dir, "mp3s", violin_audio_name) + + violin_audio = load_random_segment( + audio_path=violin_audio_path, + random_state=random_state, + segment_seconds=segment_seconds, + mono=mono, + sample_rate=sample_rate, + ) + # (channels_num, audio_samples) + + output_violin_path = os.path.join( + evaluation_audios_dir, split, 'violin', '{:04d}.wav'.format(n) + ) + soundfile.write( + file=output_violin_path, data=violin_audio.T, samplerate=sample_rate + ) + print("Write out to {}".format(output_violin_path)) + + # Randomly select and write out a clean piano segment. + piano_audio_name = random_state.choice(piano_audio_names) + piano_audio_path = os.path.join(piano_dataset_dir, piano_audio_name) + + piano_audio = load_random_segment( + audio_path=piano_audio_path, + random_state=random_state, + segment_seconds=segment_seconds, + mono=mono, + sample_rate=sample_rate, + ) + # (channels_num, audio_samples) + + output_piano_path = os.path.join( + evaluation_audios_dir, split, 'piano', '{:04d}.wav'.format(n) + ) + soundfile.write( + file=output_piano_path, data=piano_audio.T, samplerate=sample_rate + ) + print("Write out to {}".format(output_piano_path)) + + # Mix violin and piano segments and write out a mixture segment. 
+ mixture_audio = violin_audio + piano_audio + # (channels_num, audio_samples) + + output_mixture_path = os.path.join( + evaluation_audios_dir, split, 'mixture', '{:04d}.wav'.format(n) + ) + soundfile.write( + file=output_mixture_path, data=mixture_audio.T, samplerate=sample_rate + ) + print("Write out to {}".format(output_mixture_path)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--violin_dataset_dir", + type=str, + required=True, + help="The directory of the violin dataset.", + ) + parser.add_argument( + "--piano_dataset_dir", + type=str, + required=True, + help="The directory of the piano dataset.", + ) + parser.add_argument( + "--evaluation_audios_dir", + type=str, + required=True, + help="The directory to write out randomly selected and mixed audio segments.", + ) + parser.add_argument( + "--sample_rate", + type=int, + required=True, + help="Sample rate", + ) + parser.add_argument( + "--channels", + type=int, + required=True, + help="Audio channels, e.g, 1 or 2.", + ) + parser.add_argument( + "--evaluation_segments_num", + type=int, + required=True, + help="The number of segments to create for evaluation.", + ) + + # Parse arguments. + args = parser.parse_args() + + create_evaluation(args) diff --git a/bytesep/dataset_creation/create_indexes/__init__.py b/bytesep/dataset_creation/create_indexes/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/bytesep/dataset_creation/create_indexes/create_indexes.py b/bytesep/dataset_creation/create_indexes/create_indexes.py new file mode 100644 index 0000000000000000000000000000000000000000..fdfac4e3370e06d69904f99f5852a4c9e824389b --- /dev/null +++ b/bytesep/dataset_creation/create_indexes/create_indexes.py @@ -0,0 +1,142 @@ +import argparse +import os +import pickle +from typing import NoReturn + +import h5py + +from bytesep.utils import read_yaml + + +def create_indexes(args) -> NoReturn: + r"""Create and write out training indexes into disk. The indexes may contain + information from multiple datasets. During training, training indexes will + be shuffled and iterated for selecting segments to be mixed. E.g., the + training indexes_dict looks like: { + 'vocals': [ + {'hdf5_path': '.../songA.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 0, 'end_sample': 132300} + {'hdf5_path': '.../songB.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 4410, 'end_sample': 136710} + ... + ] + 'accompaniment': [ + {'hdf5_path': '.../songA.h5', 'key_in_hdf5': 'accompaniment', 'begin_sample': 0, 'end_sample': 132300} + {'hdf5_path': '.../songB.h5', 'key_in_hdf5': 'accompaniment', 'begin_sample': 4410, 'end_sample': 136710} + ... + ] + } + """ + + # Arugments & parameters + workspace = args.workspace + config_yaml = args.config_yaml + + # Only create indexes for training, because evalution is on entire pieces. + split = "train" + + # Read config file. + configs = read_yaml(config_yaml) + + sample_rate = configs["sample_rate"] + segment_samples = int(configs["segment_seconds"] * sample_rate) + + # Path to write out index. 
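+    # The 'indexes' field of the train section in the config yaml is a path
+    # relative to the workspace, e.g., "indexes/musdb18/train.pkl" (illustrative).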
+ indexes_path = os.path.join(workspace, configs[split]["indexes"]) + os.makedirs(os.path.dirname(indexes_path), exist_ok=True) + + source_types = configs[split]["source_types"].keys() + # E.g., ['vocals', 'accompaniment'] + + indexes_dict = {source_type: [] for source_type in source_types} + # E.g., indexes_dict will looks like: { + # 'vocals': [ + # {'hdf5_path': '.../songA.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 0, 'end_sample': 132300} + # {'hdf5_path': '.../songB.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 4410, 'end_sample': 136710} + # ... + # ] + # 'accompaniment': [ + # {'hdf5_path': '.../songA.h5', 'key_in_hdf5': 'accompaniment', 'begin_sample': 0, 'end_sample': 132300} + # {'hdf5_path': '.../songB.h5', 'key_in_hdf5': 'accompaniment', 'begin_sample': 4410, 'end_sample': 136710} + # ... + # ] + # } + + # Get training indexes for each source type. + for source_type in source_types: + # E.g., ['vocals', 'bass', ...] + + print("--- {} ---".format(source_type)) + + dataset_types = configs[split]["source_types"][source_type] + # E.g., ['musdb18', ...] + + # Each source can come from mulitple datasets. + for dataset_type in dataset_types: + + hdf5s_dir = os.path.join( + workspace, dataset_types[dataset_type]["hdf5s_directory"] + ) + + hop_samples = int(dataset_types[dataset_type]["hop_seconds"] * sample_rate) + + key_in_hdf5 = dataset_types[dataset_type]["key_in_hdf5"] + # E.g., 'vocals' + + hdf5_names = sorted(os.listdir(hdf5s_dir)) + print("Hdf5 files num: {}".format(len(hdf5_names))) + + # Traverse all packed hdf5 files of a dataset. + for n, hdf5_name in enumerate(hdf5_names): + + print(n, hdf5_name) + hdf5_path = os.path.join(hdf5s_dir, hdf5_name) + + with h5py.File(hdf5_path, "r") as hf: + + bgn_sample = 0 + while bgn_sample + segment_samples < hf[key_in_hdf5].shape[-1]: + meta = { + 'hdf5_path': hdf5_path, + 'key_in_hdf5': key_in_hdf5, + 'begin_sample': bgn_sample, + 'end_sample': bgn_sample + segment_samples, + } + indexes_dict[source_type].append(meta) + + bgn_sample += hop_samples + + # If the audio length is shorter than the segment length, + # then use the entire audio as a segment. + if bgn_sample == 0: + meta = { + 'hdf5_path': hdf5_path, + 'key_in_hdf5': key_in_hdf5, + 'begin_sample': 0, + 'end_sample': segment_samples, + } + indexes_dict[source_type].append(meta) + + print( + "Total indexes for {}: {}".format( + source_type, len(indexes_dict[source_type]) + ) + ) + + pickle.dump(indexes_dict, open(indexes_path, "wb")) + print("Write index dict to {}".format(indexes_path)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--workspace", type=str, required=True, help="Directory of workspace." + ) + parser.add_argument( + "--config_yaml", type=str, required=True, help="User defined config file." + ) + + # Parse arguments. + args = parser.parse_args() + + # Create training indexes. 
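+    # Example invocation (paths are illustrative):
+    #   python3 bytesep/dataset_creation/create_indexes/create_indexes.py \
+    #     --workspace=./workspaces/bytesep \
+    #     --config_yaml=./configs/example_config.yaml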
+ create_indexes(args) diff --git a/bytesep/dataset_creation/pack_audios_to_hdf5s/__init__.py b/bytesep/dataset_creation/pack_audios_to_hdf5s/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/bytesep/dataset_creation/pack_audios_to_hdf5s/instruments_solo.py b/bytesep/dataset_creation/pack_audios_to_hdf5s/instruments_solo.py new file mode 100644 index 0000000000000000000000000000000000000000..18ac5a9ccb23c1f606d0052aae8818dda5e988c0 --- /dev/null +++ b/bytesep/dataset_creation/pack_audios_to_hdf5s/instruments_solo.py @@ -0,0 +1,173 @@ +import argparse +import os +import pathlib +import time +from concurrent.futures import ProcessPoolExecutor +from typing import Dict, List, NoReturn + +import h5py +import numpy as np +import pandas as pd + +from bytesep.utils import float32_to_int16, load_audio + + +def read_csv(meta_csv) -> Dict: + r"""Get train & test names from csv. + + Args: + meta_csv: str + + Returns: + names_dict: dict, e.g., { + 'train', ['songA.mp3', 'songB.mp3', ...], + 'test': ['songE.mp3', 'songF.mp3', ...] + } + """ + df = pd.read_csv(meta_csv, sep=',') + + names_dict = {} + + for split in ['train', 'test']: + audio_indexes = df['split'] == split + audio_names = list(df['audio_name'][audio_indexes]) + audio_names = [ + '{}.mp3'.format(pathlib.Path(audio_name).stem) for audio_name in audio_names + ] + names_dict[split] = audio_names + + return names_dict + + +def pack_audios_to_hdf5s(args) -> NoReturn: + r"""Pack (resampled) audio files into hdf5 files to speed up loading. + + Args: + dataset_dir: str + split: str, 'train' | 'test' + source_type: str + hdf5s_dir: str, directory to write out hdf5 files + sample_rate: int + channels_num: int + mono: bool + + Returns: + NoReturn + """ + + # arguments & parameters + dataset_dir = args.dataset_dir + split = args.split + source_type = args.source_type + hdf5s_dir = args.hdf5s_dir + sample_rate = args.sample_rate + channels = args.channels + mono = True if channels == 1 else False + + # Only pack data for training data. + assert split == "train" + + # paths + audios_dir = os.path.join(dataset_dir, 'mp3s') + meta_csv = os.path.join(dataset_dir, 'validation.csv') + + os.makedirs(hdf5s_dir, exist_ok=True) + + # Read train & test names. + names_dict = read_csv(meta_csv) + + audio_names = names_dict[split] + + params = [] + + for audio_index, audio_name in enumerate(audio_names): + + audio_path = os.path.join(audios_dir, audio_name) + + hdf5_path = os.path.join( + hdf5s_dir, "{}.h5".format(pathlib.Path(audio_name).stem) + ) + + param = ( + audio_index, + audio_name, + source_type, + audio_path, + mono, + sample_rate, + hdf5_path, + ) + params.append(param) + + # Uncomment for debug. 
+ # write_single_audio_to_hdf5(params[0]) + # os._exit() + + pack_hdf5s_time = time.time() + + with ProcessPoolExecutor(max_workers=None) as pool: + # Maximum works on the machine + pool.map(write_single_audio_to_hdf5, params) + + print("Pack hdf5 time: {:.3f} s".format(time.time() - pack_hdf5s_time)) + + +def write_single_audio_to_hdf5(param: List) -> NoReturn: + r"""Write single audio into hdf5 file.""" + + ( + audio_index, + audio_name, + source_type, + audio_path, + mono, + sample_rate, + hdf5_path, + ) = param + + with h5py.File(hdf5_path, "w") as hf: + + hf.attrs.create("audio_name", data=audio_name, dtype="S100") + hf.attrs.create("sample_rate", data=sample_rate, dtype=np.int32) + + audio = load_audio(audio_path=audio_path, mono=mono, sample_rate=sample_rate) + # audio: (channels_num, audio_samples) + + hf.create_dataset( + name=source_type, data=float32_to_int16(audio), dtype=np.int16 + ) + + print('{} Write hdf5 to {}'.format(audio_index, hdf5_path)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--dataset_dir", + type=str, + required=True, + help="Directory of the instruments solo dataset.", + ) + parser.add_argument("--split", type=str, required=True, choices=["train", "test"]) + parser.add_argument( + "--source_type", + type=str, + required=True, + ) + parser.add_argument( + "--hdf5s_dir", + type=str, + required=True, + help="Directory to write out hdf5 files.", + ) + parser.add_argument("--sample_rate", type=int, required=True, help="Sample rate.") + parser.add_argument( + "--channels", type=int, required=True, help="Use 1 for mono, 2 for stereo." + ) + + # Parse arguments. + args = parser.parse_args() + + # Pack audios to hdf5 files. + pack_audios_to_hdf5s(args) diff --git a/bytesep/dataset_creation/pack_audios_to_hdf5s/maestro.py b/bytesep/dataset_creation/pack_audios_to_hdf5s/maestro.py new file mode 100644 index 0000000000000000000000000000000000000000..bb349816b19f2ca44e4f3c3e70d03da49b74cbc9 --- /dev/null +++ b/bytesep/dataset_creation/pack_audios_to_hdf5s/maestro.py @@ -0,0 +1,136 @@ +import argparse +import os +import pathlib +import time +from concurrent.futures import ProcessPoolExecutor +from typing import Dict, NoReturn + +import pandas as pd + +from bytesep.dataset_creation.pack_audios_to_hdf5s.instruments_solo import ( + write_single_audio_to_hdf5, +) + + +def read_csv(meta_csv) -> Dict: + r"""Get train & test names from csv. + + Args: + meta_csv: str + + Returns: + names_dict: dict, e.g., { + 'train', ['a1.mp3', 'a2.mp3'], + 'test': ['b1.mp3', 'b2.mp3'] + } + """ + df = pd.read_csv(meta_csv, sep=',') + + names_dict = {} + + for split in ['train', 'test']: + audio_indexes = df['split'] == split + audio_names = list(df['audio_filename'][audio_indexes]) + names_dict[split] = audio_names + + return names_dict + + +def pack_audios_to_hdf5s(args) -> NoReturn: + r"""Pack (resampled) audio files into hdf5 files to speed up loading. + + Args: + dataset_dir: str + split: str, 'train' | 'test' + hdf5s_dir: str, directory to write out hdf5 files + sample_rate: int + channels_num: int + mono: bool + + Returns: + NoReturn + """ + + # arguments & parameters + dataset_dir = args.dataset_dir + split = args.split + hdf5s_dir = args.hdf5s_dir + sample_rate = args.sample_rate + channels = args.channels + mono = True if channels == 1 else False + + source_type = "piano" + + # Only pack data for training data. 
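+    # (Test-split audios are read directly from the dataset by the
+    # create_evaluation_audios scripts, so only the train split is packed here.)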
+ assert split == "train" + + # paths + meta_csv = os.path.join(dataset_dir, 'maestro-v2.0.0.csv') + + os.makedirs(hdf5s_dir, exist_ok=True) + + # Read train & test names. + names_dict = read_csv(meta_csv) + + audio_names = names_dict['{}'.format(split)] + + params = [] + + for audio_index, audio_name in enumerate(audio_names): + + audio_path = os.path.join(dataset_dir, audio_name) + + hdf5_path = os.path.join( + hdf5s_dir, "{}.h5".format(pathlib.Path(audio_name).stem) + ) + + param = ( + audio_index, + audio_name, + source_type, + audio_path, + mono, + sample_rate, + hdf5_path, + ) + params.append(param) + + # Uncomment for debug. + # write_single_audio_to_hdf5(params[0]) + # os._exit(0) + + pack_hdf5s_time = time.time() + + with ProcessPoolExecutor(max_workers=None) as pool: + # Maximum works on the machine + pool.map(write_single_audio_to_hdf5, params) + + print("Pack hdf5 time: {:.3f} s".format(time.time() - pack_hdf5s_time)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--dataset_dir", + type=str, + required=True, + help="Directory of the MAESTRO dataset.", + ) + parser.add_argument("--split", type=str, required=True, choices=["train", "test"]) + parser.add_argument( + "--hdf5s_dir", + type=str, + required=True, + help="Directory to write out hdf5 files.", + ) + parser.add_argument("--sample_rate", type=int, required=True, help="Sample rate.") + parser.add_argument( + "--channels", type=int, required=True, help="Use 1 for mono, 2 for stereo." + ) + + # Parse arguments. + args = parser.parse_args() + + # Pack audios to hdf5 files. + pack_audios_to_hdf5s(args) diff --git a/bytesep/dataset_creation/pack_audios_to_hdf5s/musdb18.py b/bytesep/dataset_creation/pack_audios_to_hdf5s/musdb18.py new file mode 100644 index 0000000000000000000000000000000000000000..4f242de04d850527531794a2a85f3454191adede --- /dev/null +++ b/bytesep/dataset_creation/pack_audios_to_hdf5s/musdb18.py @@ -0,0 +1,207 @@ +import argparse +import os +import time +from concurrent.futures import ProcessPoolExecutor +from typing import NoReturn + +import h5py +import librosa +import musdb +import numpy as np + +from bytesep.utils import float32_to_int16 + +# Source types of the MUSDB18 dataset. +SOURCE_TYPES = ["vocals", "drums", "bass", "other", "accompaniment"] + + +def pack_audios_to_hdf5s(args) -> NoReturn: + r"""Pack (resampled) audio files into hdf5 files to speed up loading. + + Args: + dataset_dir: str + subset: str, 'train' | 'test' + split: str, '' | 'train' | 'valid' + hdf5s_dir: str, directory to write out hdf5 files + sample_rate: int + channels_num: int + mono: bool + + Returns: + NoReturn + """ + + # arguments & parameters + dataset_dir = args.dataset_dir + subset = args.subset + split = None if args.split == "" else args.split + hdf5s_dir = args.hdf5s_dir + sample_rate = args.sample_rate + channels = args.channels + + mono = True if channels == 1 else False + source_types = SOURCE_TYPES + resample_type = "kaiser_fast" + + # Paths + os.makedirs(hdf5s_dir, exist_ok=True) + + # Dataset of corresponding subset and split. + mus = musdb.DB(root=dataset_dir, subsets=[subset], split=split) + print("Subset: {}, Split: {}, Total pieces: {}".format(subset, split, len(mus))) + + params = [] # A list of params for multiple processing. + + for track_index in range(len(mus.tracks)): + + param = ( + dataset_dir, + subset, + split, + track_index, + source_types, + mono, + sample_rate, + resample_type, + hdf5s_dir, + ) + + params.append(param) + + # Uncomment for debug. 
+ # write_single_audio_to_hdf5(params[0]) + # os._exit(0) + + pack_hdf5s_time = time.time() + + with ProcessPoolExecutor(max_workers=None) as pool: + # Maximum works on the machine + pool.map(write_single_audio_to_hdf5, params) + + print("Pack hdf5 time: {:.3f} s".format(time.time() - pack_hdf5s_time)) + + +def write_single_audio_to_hdf5(param) -> NoReturn: + r"""Write single audio into hdf5 file.""" + ( + dataset_dir, + subset, + split, + track_index, + source_types, + mono, + sample_rate, + resample_type, + hdf5s_dir, + ) = param + + # Dataset of corresponding subset and split. + mus = musdb.DB(root=dataset_dir, subsets=[subset], split=split) + track = mus.tracks[track_index] + + # Path to write out hdf5 file. + hdf5_path = os.path.join(hdf5s_dir, "{}.h5".format(track.name)) + + with h5py.File(hdf5_path, "w") as hf: + + hf.attrs.create("audio_name", data=track.name.encode(), dtype="S100") + hf.attrs.create("sample_rate", data=sample_rate, dtype=np.int32) + + for source_type in source_types: + + audio = track.targets[source_type].audio.T + # (channels_num, audio_samples) + + # Preprocess audio to mono / stereo, and resample. + audio = preprocess_audio( + audio, mono, track.rate, sample_rate, resample_type + ) + # audio = load_audio(audio_path=audio_path, mono=mono, sample_rate=sample_rate) + # (channels_num, audio_samples) | (audio_samples,) + + hf.create_dataset( + name=source_type, data=float32_to_int16(audio), dtype=np.int16 + ) + + # Mixture + audio = track.audio.T + # (channels_num, audio_samples) + + # Preprocess audio to mono / stereo, and resample. + audio = preprocess_audio(audio, mono, track.rate, sample_rate, resample_type) + # (channels_num, audio_samples) + + hf.create_dataset(name="mixture", data=float32_to_int16(audio), dtype=np.int16) + + print("{} Write to {}, {}".format(track_index, hdf5_path, audio.shape)) + + +def preprocess_audio(audio, mono, origin_sr, sr, resample_type) -> np.array: + r"""Preprocess audio to mono / stereo, and resample. + + Args: + audio: (channels_num, audio_samples), input audio + mono: bool + origin_sr: float, original sample rate + sr: float, target sample rate + resample_type: str, e.g., 'kaiser_fast' + + Returns: + output: ndarray, output audio + """ + if mono: + audio = np.mean(audio, axis=0) + # (audio_samples,) + + output = librosa.core.resample( + audio, orig_sr=origin_sr, target_sr=sr, res_type=resample_type + ) + # (audio_samples,) | (channels_num, audio_samples) + + if output.ndim == 1: + output = output[None, :] + # (1, audio_samples,) + + return output + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--dataset_dir", + type=str, + required=True, + help="Directory of the MUSDB18 dataset.", + ) + parser.add_argument( + "--subset", + type=str, + required=True, + choices=["train", "test"], + help="Train subset: 100 pieces; test subset: 50 pieces.", + ) + parser.add_argument( + "--split", + type=str, + required=True, + choices=["", "train", "valid"], + help="Use '' to use all 100 pieces to train. Use 'train' to use 86 \ + pieces for train, and use 'test' to use 14 pieces for valid.", + ) + parser.add_argument( + "--hdf5s_dir", + type=str, + required=True, + help="Directory to write out hdf5 files.", + ) + parser.add_argument("--sample_rate", type=int, required=True, help="Sample rate.") + parser.add_argument( + "--channels", type=int, required=True, help="Use 1 for mono, 2 for stereo." + ) + + # Parse arguments. + args = parser.parse_args() + + # Pack audios into hdf5 files. 
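+    # Example invocation (paths are illustrative):
+    #   python3 bytesep/dataset_creation/pack_audios_to_hdf5s/musdb18.py \
+    #     --dataset_dir=./datasets/musdb18 --subset=train --split="" \
+    #     --hdf5s_dir=./workspaces/bytesep/hdf5s/musdb18/train \
+    #     --sample_rate=44100 --channels=2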
+ pack_audios_to_hdf5s(args) diff --git a/bytesep/dataset_creation/pack_audios_to_hdf5s/vctk.py b/bytesep/dataset_creation/pack_audios_to_hdf5s/vctk.py new file mode 100644 index 0000000000000000000000000000000000000000..eb9c23c761ee41476605b04fe55d250bf61ac5a2 --- /dev/null +++ b/bytesep/dataset_creation/pack_audios_to_hdf5s/vctk.py @@ -0,0 +1,114 @@ +import argparse +import os +import pathlib +import time +from concurrent.futures import ProcessPoolExecutor +from typing import NoReturn + +from bytesep.dataset_creation.pack_audios_to_hdf5s.instruments_solo import ( + write_single_audio_to_hdf5, +) + + +def pack_audios_to_hdf5s(args) -> NoReturn: + r"""Pack (resampled) audio files into hdf5 files to speed up loading. + + Args: + dataset_dir: str + split: str, 'train' | 'test' + hdf5s_dir: str, directory to write out hdf5 files + sample_rate: int + channels_num: int + mono: bool + + Returns: + NoReturn + """ + + # arguments & parameters + dataset_dir = args.dataset_dir + split = args.split + hdf5s_dir = args.hdf5s_dir + sample_rate = args.sample_rate + channels = args.channels + mono = True if channels == 1 else False + + source_type = "speech" + + # Only pack data for training data. + assert split == "train" + + audios_dir = os.path.join(dataset_dir, 'wav48', split) + os.makedirs(hdf5s_dir, exist_ok=True) + + speaker_ids = sorted(os.listdir(audios_dir)) + + params = [] + audio_index = 0 + + for speaker_id in speaker_ids: + + speaker_audios_dir = os.path.join(audios_dir, speaker_id) + + audio_names = sorted(os.listdir(speaker_audios_dir)) + + for audio_name in audio_names: + + audio_path = os.path.join(speaker_audios_dir, audio_name) + + hdf5_path = os.path.join( + hdf5s_dir, "{}.h5".format(pathlib.Path(audio_name).stem) + ) + + param = ( + audio_index, + audio_name, + source_type, + audio_path, + mono, + sample_rate, + hdf5_path, + ) + params.append(param) + + audio_index += 1 + + # Uncomment for debug. + # write_single_audio_to_hdf5(params[0]) + # os._exit(0) + + pack_hdf5s_time = time.time() + + with ProcessPoolExecutor(max_workers=None) as pool: + # Maximum works on the machine + pool.map(write_single_audio_to_hdf5, params) + + print("Pack hdf5 time: {:.3f} s".format(time.time() - pack_hdf5s_time)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--dataset_dir", + type=str, + required=True, + help="Directory of the VCTK dataset.", + ) + parser.add_argument("--split", type=str, required=True, choices=["train", "test"]) + parser.add_argument( + "--hdf5s_dir", + type=str, + required=True, + help="Directory to write out hdf5 files.", + ) + parser.add_argument("--sample_rate", type=int, required=True, help="Sample rate.") + parser.add_argument( + "--channels", type=int, required=True, help="Use 1 for mono, 2 for stereo." + ) + + # Parse arguments. + args = parser.parse_args() + + # Pack audios into hdf5 files. 
+ pack_audios_to_hdf5s(args) diff --git a/bytesep/dataset_creation/pack_audios_to_hdf5s/voicebank-demand.py b/bytesep/dataset_creation/pack_audios_to_hdf5s/voicebank-demand.py new file mode 100644 index 0000000000000000000000000000000000000000..7e166cea948c6458faa78740a8297112e17f74ec --- /dev/null +++ b/bytesep/dataset_creation/pack_audios_to_hdf5s/voicebank-demand.py @@ -0,0 +1,143 @@ +import argparse +import os +import pathlib +import time +from concurrent.futures import ProcessPoolExecutor +from typing import List, NoReturn + +import h5py +import numpy as np + +from bytesep.utils import float32_to_int16, load_audio + + +def pack_audios_to_hdf5s(args) -> NoReturn: + r"""Pack (resampled) audio files into hdf5 files to speed up loading. + + Args: + dataset_dir: str + split: str, 'train' | 'test' + hdf5s_dir: str, directory to write out hdf5 files + sample_rate: int + channels_num: int + mono: bool + + Returns: + NoReturn + """ + + # arguments & parameters + dataset_dir = args.dataset_dir + split = args.split + hdf5s_dir = args.hdf5s_dir + sample_rate = args.sample_rate + channels = args.channels + mono = True if channels == 1 else False + + # Only pack data for training data. + assert split == "train" + + speech_dir = os.path.join(dataset_dir, "clean_{}set_wav".format(split)) + mixture_dir = os.path.join(dataset_dir, "noisy_{}set_wav".format(split)) + + os.makedirs(hdf5s_dir, exist_ok=True) + + # Read names. + audio_names = sorted(os.listdir(speech_dir)) + + params = [] + + for audio_index, audio_name in enumerate(audio_names): + + speech_path = os.path.join(speech_dir, audio_name) + mixture_path = os.path.join(mixture_dir, audio_name) + + hdf5_path = os.path.join( + hdf5s_dir, "{}.h5".format(pathlib.Path(audio_name).stem) + ) + + param = ( + audio_index, + audio_name, + speech_path, + mixture_path, + mono, + sample_rate, + hdf5_path, + ) + params.append(param) + + # Uncomment for debug. 
+ # write_single_audio_to_hdf5(params[0]) + # os._exit(0) + + pack_hdf5s_time = time.time() + + with ProcessPoolExecutor(max_workers=None) as pool: + # Maximum works on the machine + pool.map(write_single_audio_to_hdf5, params) + + print("Pack hdf5 time: {:.3f} s".format(time.time() - pack_hdf5s_time)) + + +def write_single_audio_to_hdf5(param: List) -> NoReturn: + r"""Write single audio into hdf5 file.""" + + ( + audio_index, + audio_name, + speech_path, + mixture_path, + mono, + sample_rate, + hdf5_path, + ) = param + + with h5py.File(hdf5_path, "w") as hf: + + hf.attrs.create("audio_name", data=audio_name, dtype="S100") + hf.attrs.create("sample_rate", data=sample_rate, dtype=np.int32) + + speech = load_audio(audio_path=speech_path, mono=mono, sample_rate=sample_rate) + # speech: (channels_num, audio_samples) + + mixture = load_audio( + audio_path=mixture_path, mono=mono, sample_rate=sample_rate + ) + # mixture: (channels_num, audio_samples) + + noise = mixture - speech + # noise: (channels_num, audio_samples) + + hf.create_dataset(name='speech', data=float32_to_int16(speech), dtype=np.int16) + hf.create_dataset(name='noise', data=float32_to_int16(noise), dtype=np.int16) + + print('{} Write hdf5 to {}'.format(audio_index, hdf5_path)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--dataset_dir", + type=str, + required=True, + help="Directory of the Voicebank-Demand dataset.", + ) + parser.add_argument("--split", type=str, required=True, choices=["train", "test"]) + parser.add_argument( + "--hdf5s_dir", + type=str, + required=True, + help="Directory to write out hdf5 files.", + ) + parser.add_argument("--sample_rate", type=int, required=True, help="Sample rate.") + parser.add_argument( + "--channels", type=int, required=True, help="Use 1 for mono, 2 for stereo." + ) + + # Parse arguments. + args = parser.parse_args() + + # Pack audios into hdf5 files. + pack_audios_to_hdf5s(args) diff --git a/bytesep/inference.py b/bytesep/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..ef7556d239273cf568528e7d2a7e41feea32a81f --- /dev/null +++ b/bytesep/inference.py @@ -0,0 +1,402 @@ +import argparse +import os +import time +from typing import Dict +import pathlib + +import librosa +import numpy as np +import soundfile +import torch +import torch.nn as nn + +from bytesep.models.lightning_modules import get_model_class +from bytesep.utils import read_yaml + + +class Separator: + def __init__( + self, model: nn.Module, segment_samples: int, batch_size: int, device: str + ): + r"""Separate to separate an audio clip into a target source. + + Args: + model: nn.Module, trained model + segment_samples: int, length of segments to be input to a model, e.g., 44100*30 + batch_size, int, e.g., 12 + device: str, e.g., 'cuda' + """ + self.model = model + self.segment_samples = segment_samples + self.batch_size = batch_size + self.device = device + + def separate(self, input_dict: Dict) -> np.array: + r"""Separate an audio clip into a target source. + + Args: + input_dict: dict, e.g., { + waveform: (channels_num, audio_samples), + ..., + } + + Returns: + sep_audio: (channels_num, audio_samples) | (target_sources_num, channels_num, audio_samples) + """ + audio = input_dict['waveform'] + + audio_samples = audio.shape[-1] + + # Pad the audio with zero in the end so that the length of audio can be + # evenly divided by segment_samples. + audio = self.pad_audio(audio) + + # Enframe long audio into segments. 
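+        # Adjacent segments overlap by half a segment (hop = segment_samples // 2);
+        # deframe() below keeps only the central part of each separated segment,
+        # which reduces artifacts at segment boundaries.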
+ segments = self.enframe(audio, self.segment_samples) + # (segments_num, channels_num, segment_samples) + + segments_input_dict = {'waveform': segments} + + if 'condition' in input_dict.keys(): + segments_num = len(segments) + segments_input_dict['condition'] = np.tile( + input_dict['condition'][None, :], (segments_num, 1) + ) + # (batch_size, segments_num) + + # Separate in mini-batches. + sep_segments = self._forward_in_mini_batches( + self.model, segments_input_dict, self.batch_size + )['waveform'] + # (segments_num, channels_num, segment_samples) + + # Deframe segments into long audio. + sep_audio = self.deframe(sep_segments) + # (channels_num, padded_audio_samples) + + sep_audio = sep_audio[:, 0:audio_samples] + # (channels_num, audio_samples) + + return sep_audio + + def pad_audio(self, audio: np.array) -> np.array: + r"""Pad the audio with zero in the end so that the length of audio can + be evenly divided by segment_samples. + + Args: + audio: (channels_num, audio_samples) + + Returns: + padded_audio: (channels_num, audio_samples) + """ + channels_num, audio_samples = audio.shape + + # Number of segments + segments_num = int(np.ceil(audio_samples / self.segment_samples)) + + pad_samples = segments_num * self.segment_samples - audio_samples + + padded_audio = np.concatenate( + (audio, np.zeros((channels_num, pad_samples))), axis=1 + ) + # (channels_num, padded_audio_samples) + + return padded_audio + + def enframe(self, audio: np.array, segment_samples: int) -> np.array: + r"""Enframe long audio into segments. + + Args: + audio: (channels_num, audio_samples) + segment_samples: int + + Returns: + segments: (segments_num, channels_num, segment_samples) + """ + audio_samples = audio.shape[1] + assert audio_samples % segment_samples == 0 + + hop_samples = segment_samples // 2 + segments = [] + + pointer = 0 + while pointer + segment_samples <= audio_samples: + segments.append(audio[:, pointer : pointer + segment_samples]) + pointer += hop_samples + + segments = np.array(segments) + + return segments + + def deframe(self, segments: np.array) -> np.array: + r"""Deframe segments into long audio. + + Args: + segments: (segments_num, channels_num, segment_samples) + + Returns: + output: (channels_num, audio_samples) + """ + (segments_num, _, segment_samples) = segments.shape + + if segments_num == 1: + return segments[0] + + assert self._is_integer(segment_samples * 0.25) + assert self._is_integer(segment_samples * 0.75) + + output = [] + + output.append(segments[0, :, 0 : int(segment_samples * 0.75)]) + + for i in range(1, segments_num - 1): + output.append( + segments[ + i, :, int(segment_samples * 0.25) : int(segment_samples * 0.75) + ] + ) + + output.append(segments[-1, :, int(segment_samples * 0.25) :]) + + output = np.concatenate(output, axis=-1) + + return output + + def _is_integer(self, x: float) -> bool: + if x - int(x) < 1e-10: + return True + else: + return False + + def _forward_in_mini_batches( + self, model: nn.Module, segments_input_dict: Dict, batch_size: int + ) -> Dict: + r"""Forward data to model in mini-batch. + + Args: + model: nn.Module + segments_input_dict: dict, e.g., { + 'waveform': (segments_num, channels_num, segment_samples), + ..., + } + batch_size: int + + Returns: + output_dict: dict, e.g. 
{ + 'waveform': (segments_num, channels_num, segment_samples), + } + """ + output_dict = {} + + pointer = 0 + segments_num = len(segments_input_dict['waveform']) + + while True: + if pointer >= segments_num: + break + + batch_input_dict = {} + + for key in segments_input_dict.keys(): + batch_input_dict[key] = torch.Tensor( + segments_input_dict[key][pointer : pointer + batch_size] + ).to(self.device) + + pointer += batch_size + + with torch.no_grad(): + model.eval() + batch_output_dict = model(batch_input_dict) + + for key in batch_output_dict.keys(): + self._append_to_dict( + output_dict, key, batch_output_dict[key].data.cpu().numpy() + ) + + for key in output_dict.keys(): + output_dict[key] = np.concatenate(output_dict[key], axis=0) + + return output_dict + + def _append_to_dict(self, dict, key, value): + if key in dict.keys(): + dict[key].append(value) + else: + dict[key] = [value] + + +class SeparatorWrapper: + def __init__( + self, source_type='vocals', model=None, checkpoint_path=None, device='cuda' + ): + + input_channels = 2 + target_sources_num = 1 + model_type = "ResUNet143_Subbandtime" + segment_samples = 44100 * 10 + batch_size = 1 + + self.checkpoint_path = self.download_checkpoints(checkpoint_path, source_type) + + if device == 'cuda' and torch.cuda.is_available(): + self.device = 'cuda' + else: + self.device = 'cpu' + + # Get model class. + Model = get_model_class(model_type) + + # Create model. + self.model = Model( + input_channels=input_channels, target_sources_num=target_sources_num + ) + + # Load checkpoint. + checkpoint = torch.load(self.checkpoint_path, map_location='cpu') + self.model.load_state_dict(checkpoint["model"]) + + # Move model to device. + self.model.to(self.device) + + # Create separator. + self.separator = Separator( + model=self.model, + segment_samples=segment_samples, + batch_size=batch_size, + device=self.device, + ) + + def download_checkpoints(self, checkpoint_path, source_type): + + if source_type == "vocals": + checkpoint_bare_name = "resunet143_subbtandtime_vocals_8.8dB_350k_steps" + + elif source_type == "accompaniment": + checkpoint_bare_name = ( + "resunet143_subbtandtime_accompaniment_16.4dB_350k_steps.pth" + ) + + else: + raise NotImplementedError + + if not checkpoint_path: + checkpoint_path = '{}/bytesep_data/{}.pth'.format( + str(pathlib.Path.home()), checkpoint_bare_name + ) + + print('Checkpoint path: {}'.format(checkpoint_path)) + + if ( + not os.path.exists(checkpoint_path) + or os.path.getsize(checkpoint_path) < 4e8 + ): + + os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True) + + zenodo_dir = "https://zenodo.org/record/5507029/files" + zenodo_path = os.path.join( + zenodo_dir, "{}?download=1".format(checkpoint_bare_name) + ) + + os.system('wget -O "{}" "{}"'.format(checkpoint_path, zenodo_path)) + + return checkpoint_path + + def separate(self, audio): + + input_dict = {'waveform': audio} + + sep_wav = self.separator.separate(input_dict) + + return sep_wav + + +def inference(args): + + # Need to use torch.distributed if models contain inplace_abn.abn.InPlaceABNSync. 
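+    # A single-process 'gloo' group (rank=0, world_size=1) backed by a file://
+    # store is created here; no multi-GPU setup is needed for inference.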
+ import torch.distributed as dist + + dist.init_process_group( + 'gloo', init_method='file:///tmp/somefile', rank=0, world_size=1 + ) + + # Arguments & parameters + config_yaml = args.config_yaml + checkpoint_path = args.checkpoint_path + audio_path = args.audio_path + output_path = args.output_path + device = ( + torch.device('cuda') + if args.cuda and torch.cuda.is_available() + else torch.device('cpu') + ) + + configs = read_yaml(config_yaml) + sample_rate = configs['train']['sample_rate'] + input_channels = configs['train']['channels'] + target_source_types = configs['train']['target_source_types'] + target_sources_num = len(target_source_types) + model_type = configs['train']['model_type'] + + segment_samples = int(30 * sample_rate) + batch_size = 1 + + print("Using {} for separating ..".format(device)) + + # paths + if os.path.dirname(output_path) != "": + os.makedirs(os.path.dirname(output_path), exist_ok=True) + + # Get model class. + Model = get_model_class(model_type) + + # Create model. + model = Model(input_channels=input_channels, target_sources_num=target_sources_num) + + # Load checkpoint. + checkpoint = torch.load(checkpoint_path, map_location='cpu') + model.load_state_dict(checkpoint["model"]) + + # Move model to device. + model.to(device) + + # Create separator. + separator = Separator( + model=model, + segment_samples=segment_samples, + batch_size=batch_size, + device=device, + ) + + # Load audio. + audio, _ = librosa.load(audio_path, sr=sample_rate, mono=False) + + # audio = audio[None, :] + + input_dict = {'waveform': audio} + + # Separate + separate_time = time.time() + + sep_wav = separator.separate(input_dict) + # (channels_num, audio_samples) + + print('Separate time: {:.3f} s'.format(time.time() - separate_time)) + + # Write out separated audio. + soundfile.write(file='_zz.wav', data=sep_wav.T, samplerate=sample_rate) + os.system("ffmpeg -y -loglevel panic -i _zz.wav {}".format(output_path)) + print('Write out to {}'.format(output_path)) + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser(description="") + parser.add_argument("--config_yaml", type=str, required=True) + parser.add_argument("--checkpoint_path", type=str, required=True) + parser.add_argument("--audio_path", type=str, required=True) + parser.add_argument("--output_path", type=str, required=True) + parser.add_argument("--cuda", action='store_true', default=True) + + args = parser.parse_args() + inference(args) diff --git a/bytesep/inference_many.py b/bytesep/inference_many.py new file mode 100644 index 0000000000000000000000000000000000000000..154eb7d64a93dfd670dc7cad56c4e67eb8a63fe3 --- /dev/null +++ b/bytesep/inference_many.py @@ -0,0 +1,163 @@ +import argparse +import os +import pathlib +import time +from typing import NoReturn + +import librosa +import numpy as np +import soundfile +import torch + +from bytesep.inference import Separator +from bytesep.models.lightning_modules import get_model_class +from bytesep.utils import read_yaml + + +def inference(args) -> NoReturn: + r"""Separate all audios in a directory. + + Args: + config_yaml: str, the config file of a model being trained + checkpoint_path: str, the path of checkpoint to be loaded + audios_dir: str, the directory of audios to be separated + output_dir: str, the directory to write out separated audios + scale_volume: bool, if True then the volume is scaled to the maximum value of 1. 
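+        cuda: bool, set to True to run separation on GPU when available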
+ + Returns: + NoReturn + """ + + # Arguments & parameters + config_yaml = args.config_yaml + checkpoint_path = args.checkpoint_path + audios_dir = args.audios_dir + output_dir = args.output_dir + scale_volume = args.scale_volume + device = ( + torch.device('cuda') + if args.cuda and torch.cuda.is_available() + else torch.device('cpu') + ) + + configs = read_yaml(config_yaml) + sample_rate = configs['train']['sample_rate'] + input_channels = configs['train']['channels'] + target_source_types = configs['train']['target_source_types'] + target_sources_num = len(target_source_types) + model_type = configs['train']['model_type'] + mono = input_channels == 1 + + segment_samples = int(30 * sample_rate) + batch_size = 1 + device = "cuda" + + models_contains_inplaceabn = True + + # Need to use torch.distributed if models contain inplace_abn.abn.InPlaceABNSync. + if models_contains_inplaceabn: + + import torch.distributed as dist + + dist.init_process_group( + 'gloo', init_method='file:///tmp/somefile', rank=0, world_size=1 + ) + + print("Using {} for separating ..".format(device)) + + # paths + os.makedirs(output_dir, exist_ok=True) + + # Get model class. + Model = get_model_class(model_type) + + # Create model. + model = Model(input_channels=input_channels, target_sources_num=target_sources_num) + + # Load checkpoint. + checkpoint = torch.load(checkpoint_path, map_location='cpu') + model.load_state_dict(checkpoint["model"]) + + # Move model to device. + model.to(device) + + # Create separator. + separator = Separator( + model=model, + segment_samples=segment_samples, + batch_size=batch_size, + device=device, + ) + + audio_names = sorted(os.listdir(audios_dir)) + + for audio_name in audio_names: + audio_path = os.path.join(audios_dir, audio_name) + + # Load audio. + audio, _ = librosa.load(audio_path, sr=sample_rate, mono=mono) + + if audio.ndim == 1: + audio = audio[None, :] + + input_dict = {'waveform': audio} + + # Separate + separate_time = time.time() + + sep_wav = separator.separate(input_dict) + # (channels_num, audio_samples) + + print('Separate time: {:.3f} s'.format(time.time() - separate_time)) + + # Write out separated audio. 
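+        # Optionally peak-normalize so the maximum absolute sample value is 1,
+        # then encode the temporary wav file to mp3 with ffmpeg.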
+ if scale_volume: + sep_wav /= np.max(np.abs(sep_wav)) + + soundfile.write(file='_zz.wav', data=sep_wav.T, samplerate=sample_rate) + + output_path = os.path.join( + output_dir, '{}.mp3'.format(pathlib.Path(audio_name).stem) + ) + os.system('ffmpeg -y -loglevel panic -i _zz.wav "{}"'.format(output_path)) + print('Write out to {}'.format(output_path)) + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser(description="") + parser.add_argument( + "--config_yaml", + type=str, + required=True, + help="The config file of a model being trained.", + ) + parser.add_argument( + "--checkpoint_path", + type=str, + required=True, + help="The path of checkpoint to be loaded.", + ) + parser.add_argument( + "--audios_dir", + type=str, + required=True, + help="The directory of audios to be separated.", + ) + parser.add_argument( + "--output_dir", + type=str, + required=True, + help="The directory to write out separated audios.", + ) + parser.add_argument( + '--scale_volume', + action='store_true', + default=False, + help="set to True if separated audios are scaled to the maximum value of 1.", + ) + parser.add_argument("--cuda", action='store_true', default=True) + + args = parser.parse_args() + + inference(args) diff --git a/bytesep/losses.py b/bytesep/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..58e79fb10c3a6cec7493ddd77e9137ab7ddc1de3 --- /dev/null +++ b/bytesep/losses.py @@ -0,0 +1,106 @@ +import math +from typing import Callable + +import torch +import torch.nn as nn +from torchlibrosa.stft import STFT + +from bytesep.models.pytorch_modules import Base + + +def l1(output: torch.Tensor, target: torch.Tensor, **kwargs) -> torch.Tensor: + r"""L1 loss. + + Args: + output: torch.Tensor + target: torch.Tensor + + Returns: + loss: torch.float + """ + return torch.mean(torch.abs(output - target)) + + +def l1_wav(output: torch.Tensor, target: torch.Tensor, **kwargs) -> torch.Tensor: + r"""L1 loss in the time-domain. + + Args: + output: torch.Tensor + target: torch.Tensor + + Returns: + loss: torch.float + """ + return l1(output, target) + + +class L1_Wav_L1_Sp(nn.Module, Base): + def __init__(self): + r"""L1 loss in the time-domain and L1 loss on the spectrogram.""" + super(L1_Wav_L1_Sp, self).__init__() + + self.window_size = 2048 + hop_size = 441 + center = True + pad_mode = "reflect" + window = "hann" + + self.stft = STFT( + n_fft=self.window_size, + hop_length=hop_size, + win_length=self.window_size, + window=window, + center=center, + pad_mode=pad_mode, + freeze_parameters=True, + ) + + def __call__( + self, output: torch.Tensor, target: torch.Tensor, **kwargs + ) -> torch.Tensor: + r"""L1 loss in the time-domain and on the spectrogram. + + Args: + output: torch.Tensor + target: torch.Tensor + + Returns: + loss: torch.float + """ + + # L1 loss in the time-domain. + wav_loss = l1_wav(output, target) + + # L1 loss on the spectrogram. + sp_loss = l1( + self.wav_to_spectrogram(output, eps=1e-8), + self.wav_to_spectrogram(target, eps=1e-8), + ) + + # sp_loss /= math.sqrt(self.window_size) + # sp_loss *= 1. + + # Total loss. + return wav_loss + sp_loss + + return sp_loss + + +def get_loss_function(loss_type: str) -> Callable: + r"""Get loss function. 
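+
+    The returned callable maps (output, target) tensors to a scalar loss.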
+ + Args: + loss_type: str + + Returns: + loss function: Callable + """ + + if loss_type == "l1_wav": + return l1_wav + + elif loss_type == "l1_wav_l1_sp": + return L1_Wav_L1_Sp() + + else: + raise NotImplementedError diff --git a/bytesep/models/__init__.py b/bytesep/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/bytesep/models/conditional_unet.py b/bytesep/models/conditional_unet.py new file mode 100644 index 0000000000000000000000000000000000000000..1e925c11308b04ba195db83b08c2718930b1b4c6 --- /dev/null +++ b/bytesep/models/conditional_unet.py @@ -0,0 +1,496 @@ +import math +from typing import List + +import numpy as np +import matplotlib.pyplot as plt +import pytorch_lightning as pl +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.optim.lr_scheduler import LambdaLR +from torchlibrosa.stft import STFT, ISTFT, magphase + +from bytesep.models.pytorch_modules import ( + Base, + init_bn, + init_embedding, + init_layer, + act, + Subband, +) + + +class ConvBlock(nn.Module): + def __init__( + self, + in_channels, + out_channels, + condition_size, + kernel_size, + activation, + momentum, + ): + super(ConvBlock, self).__init__() + + self.activation = activation + padding = (kernel_size[0] // 2, kernel_size[1] // 2) + + self.conv1 = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=(1, 1), + dilation=(1, 1), + padding=padding, + bias=False, + ) + + self.bn1 = nn.BatchNorm2d(out_channels, momentum=momentum) + + self.conv2 = nn.Conv2d( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=(1, 1), + dilation=(1, 1), + padding=padding, + bias=False, + ) + + self.bn2 = nn.BatchNorm2d(out_channels, momentum=momentum) + + self.beta1 = nn.Linear(condition_size, out_channels, bias=True) + self.beta2 = nn.Linear(condition_size, out_channels, bias=True) + + self.init_weights() + + def init_weights(self): + init_layer(self.conv1) + init_layer(self.conv2) + init_bn(self.bn1) + init_bn(self.bn2) + init_embedding(self.beta1) + init_embedding(self.beta2) + + def forward(self, x, condition): + + b1 = self.beta1(condition)[:, :, None, None] + b2 = self.beta2(condition)[:, :, None, None] + + x = act(self.bn1(self.conv1(x)) + b1, self.activation) + x = act(self.bn2(self.conv2(x)) + b2, self.activation) + return x + + +class EncoderBlock(nn.Module): + def __init__( + self, + in_channels, + out_channels, + condition_size, + kernel_size, + downsample, + activation, + momentum, + ): + super(EncoderBlock, self).__init__() + + self.conv_block = ConvBlock( + in_channels, out_channels, condition_size, kernel_size, activation, momentum + ) + self.downsample = downsample + + def forward(self, x, condition): + encoder = self.conv_block(x, condition) + encoder_pool = F.avg_pool2d(encoder, kernel_size=self.downsample) + return encoder_pool, encoder + + +class DecoderBlock(nn.Module): + def __init__( + self, + in_channels, + out_channels, + condition_size, + kernel_size, + upsample, + activation, + momentum, + ): + super(DecoderBlock, self).__init__() + self.kernel_size = kernel_size + self.stride = upsample + self.activation = activation + + self.conv1 = torch.nn.ConvTranspose2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=self.stride, + stride=self.stride, + padding=(0, 0), + bias=False, + dilation=(1, 1), + ) + + self.bn1 = nn.BatchNorm2d(out_channels, 
momentum=momentum) + + self.conv_block2 = ConvBlock( + out_channels * 2, + out_channels, + condition_size, + kernel_size, + activation, + momentum, + ) + + self.beta1 = nn.Linear(condition_size, out_channels, bias=True) + + self.init_weights() + + def init_weights(self): + init_layer(self.conv1) + init_bn(self.bn1) + init_embedding(self.beta1) + + def forward(self, input_tensor, concat_tensor, condition): + b1 = self.beta1(condition)[:, :, None, None] + x = act(self.bn1(self.conv1(input_tensor)) + b1, self.activation) + x = torch.cat((x, concat_tensor), dim=1) + x = self.conv_block2(x, condition) + return x + + +class ConditionalUNet(nn.Module, Base): + def __init__(self, input_channels, target_sources_num): + super(ConditionalUNet, self).__init__() + + self.input_channels = input_channels + condition_size = target_sources_num + self.output_sources_num = 1 + + window_size = 2048 + hop_size = 441 + center = True + pad_mode = "reflect" + window = "hann" + activation = "relu" + momentum = 0.01 + + self.subbands_num = 4 + self.K = 3 # outputs: |M|, cos∠M, sin∠M + + self.downsample_ratio = 2 ** 6 # This number equals 2^{#encoder_blcoks} + + self.stft = STFT( + n_fft=window_size, + hop_length=hop_size, + win_length=window_size, + window=window, + center=center, + pad_mode=pad_mode, + freeze_parameters=True, + ) + + self.istft = ISTFT( + n_fft=window_size, + hop_length=hop_size, + win_length=window_size, + window=window, + center=center, + pad_mode=pad_mode, + freeze_parameters=True, + ) + + self.bn0 = nn.BatchNorm2d(window_size // 2 + 1, momentum=momentum) + + self.subband = Subband(subbands_num=self.subbands_num) + + self.encoder_block1 = EncoderBlock( + in_channels=input_channels * self.subbands_num, + out_channels=32, + condition_size=condition_size, + kernel_size=(3, 3), + downsample=(2, 2), + activation=activation, + momentum=momentum, + ) + self.encoder_block2 = EncoderBlock( + in_channels=32, + out_channels=64, + condition_size=condition_size, + kernel_size=(3, 3), + downsample=(2, 2), + activation=activation, + momentum=momentum, + ) + self.encoder_block3 = EncoderBlock( + in_channels=64, + out_channels=128, + condition_size=condition_size, + kernel_size=(3, 3), + downsample=(2, 2), + activation=activation, + momentum=momentum, + ) + self.encoder_block4 = EncoderBlock( + in_channels=128, + out_channels=256, + condition_size=condition_size, + kernel_size=(3, 3), + downsample=(2, 2), + activation=activation, + momentum=momentum, + ) + self.encoder_block5 = EncoderBlock( + in_channels=256, + out_channels=384, + condition_size=condition_size, + kernel_size=(3, 3), + downsample=(2, 2), + activation=activation, + momentum=momentum, + ) + self.encoder_block6 = EncoderBlock( + in_channels=384, + out_channels=384, + condition_size=condition_size, + kernel_size=(3, 3), + downsample=(2, 2), + activation=activation, + momentum=momentum, + ) + self.conv_block7 = ConvBlock( + in_channels=384, + out_channels=384, + condition_size=condition_size, + kernel_size=(3, 3), + activation=activation, + momentum=momentum, + ) + self.decoder_block1 = DecoderBlock( + in_channels=384, + out_channels=384, + condition_size=condition_size, + kernel_size=(3, 3), + upsample=(2, 2), + activation=activation, + momentum=momentum, + ) + self.decoder_block2 = DecoderBlock( + in_channels=384, + out_channels=384, + condition_size=condition_size, + kernel_size=(3, 3), + upsample=(2, 2), + activation=activation, + momentum=momentum, + ) + self.decoder_block3 = DecoderBlock( + in_channels=384, + out_channels=256, + 
condition_size=condition_size, + kernel_size=(3, 3), + upsample=(2, 2), + activation=activation, + momentum=momentum, + ) + self.decoder_block4 = DecoderBlock( + in_channels=256, + out_channels=128, + condition_size=condition_size, + kernel_size=(3, 3), + upsample=(2, 2), + activation=activation, + momentum=momentum, + ) + self.decoder_block5 = DecoderBlock( + in_channels=128, + out_channels=64, + condition_size=condition_size, + kernel_size=(3, 3), + upsample=(2, 2), + activation=activation, + momentum=momentum, + ) + self.decoder_block6 = DecoderBlock( + in_channels=64, + out_channels=32, + condition_size=condition_size, + kernel_size=(3, 3), + upsample=(2, 2), + activation=activation, + momentum=momentum, + ) + + self.after_conv_block1 = ConvBlock( + in_channels=32, + out_channels=32, + condition_size=condition_size, + kernel_size=(3, 3), + activation=activation, + momentum=momentum, + ) + + self.after_conv2 = nn.Conv2d( + in_channels=32, + out_channels=input_channels + * self.subbands_num + * self.output_sources_num + * self.K, + kernel_size=(1, 1), + stride=(1, 1), + padding=(0, 0), + bias=True, + ) + + self.init_weights() + + def init_weights(self): + init_bn(self.bn0) + init_layer(self.after_conv2) + + def feature_maps_to_wav(self, x, sp, sin_in, cos_in, audio_length): + + batch_size, _, time_steps, freq_bins = x.shape + + x = x.reshape( + batch_size, + self.output_sources_num, + self.input_channels, + self.K, + time_steps, + freq_bins, + ) + # x: (batch_size, output_sources_num, input_channles, K, time_steps, freq_bins) + + mask_mag = torch.sigmoid(x[:, :, :, 0, :, :]) + _mask_real = torch.tanh(x[:, :, :, 1, :, :]) + _mask_imag = torch.tanh(x[:, :, :, 2, :, :]) + _, mask_cos, mask_sin = magphase(_mask_real, _mask_imag) + # mask_cos, mask_sin: (batch_size, output_sources_num, input_channles, time_steps, freq_bins) + + # Y = |Y|cos∠Y + j|Y|sin∠Y + # = |Y|cos(∠X + ∠M) + j|Y|sin(∠X + ∠M) + # = |Y|(cos∠X cos∠M - sin∠X sin∠M) + j|Y|(sin∠X cos∠M + cos∠X sin∠M) + out_cos = ( + cos_in[:, None, :, :, :] * mask_cos - sin_in[:, None, :, :, :] * mask_sin + ) + out_sin = ( + sin_in[:, None, :, :, :] * mask_cos + cos_in[:, None, :, :, :] * mask_sin + ) + # out_cos, out_sin: (batch_size, output_sources_num, input_channles, time_steps, freq_bins) + + # Calculate |Y|. + out_mag = F.relu_(sp[:, None, :, :, :] * mask_mag) + # out_mag: (batch_size, output_sources_num, input_channles, time_steps, freq_bins) + + # Calculate Y_{real} and Y_{imag} for ISTFT. + out_real = out_mag * out_cos + out_imag = out_mag * out_sin + # out_real, out_imag: (batch_size, output_sources_num, input_channles, time_steps, freq_bins) + + # Reformat shape to (n, 1, time_steps, freq_bins) for ISTFT. + shape = ( + batch_size * self.output_sources_num * self.input_channels, + 1, + time_steps, + freq_bins, + ) + out_real = out_real.reshape(shape) + out_imag = out_imag.reshape(shape) + + # ISTFT. + wav_out = self.istft(out_real, out_imag, audio_length) + # (batch_size * output_sources_num * input_channels, segments_num) + + # Reshape. 
+ wav_out = wav_out.reshape( + batch_size, self.output_sources_num * self.input_channels, audio_length + ) + # (batch_size, output_sources_num * input_channels, segments_num) + + return wav_out + + def forward(self, input_dict): + """ + Args: + input: (batch_size, segment_samples, channels_num) + + Outputs: + output_dict: { + 'wav': (batch_size, segment_samples, channels_num), + 'sp': (batch_size, channels_num, time_steps, freq_bins)} + """ + + mixture = input_dict['waveform'] + condition = input_dict['condition'] + + sp, cos_in, sin_in = self.wav_to_spectrogram_phase(mixture) + """(batch_size, channels_num, time_steps, freq_bins)""" + + # Batch normalization + x = sp.transpose(1, 3) + x = self.bn0(x) + x = x.transpose(1, 3) + """(batch_size, chanenls, time_steps, freq_bins)""" + + # Pad spectrogram to be evenly divided by downsample ratio. + origin_len = x.shape[2] + pad_len = ( + int(np.ceil(x.shape[2] / self.downsample_ratio)) * self.downsample_ratio + - origin_len + ) + x = F.pad(x, pad=(0, 0, 0, pad_len)) + """(batch_size, channels, padded_time_steps, freq_bins)""" + + # Let frequency bins be evenly divided by 2, e.g., 513 -> 512 + x = x[..., 0 : x.shape[-1] - 1] # (bs, channels, T, F) + + x = self.subband.analysis(x) + + # UNet + (x1_pool, x1) = self.encoder_block1( + x, condition + ) # x1_pool: (bs, 32, T / 2, F / 2) + (x2_pool, x2) = self.encoder_block2( + x1_pool, condition + ) # x2_pool: (bs, 64, T / 4, F / 4) + (x3_pool, x3) = self.encoder_block3( + x2_pool, condition + ) # x3_pool: (bs, 128, T / 8, F / 8) + (x4_pool, x4) = self.encoder_block4( + x3_pool, condition + ) # x4_pool: (bs, 256, T / 16, F / 16) + (x5_pool, x5) = self.encoder_block5( + x4_pool, condition + ) # x5_pool: (bs, 512, T / 32, F / 32) + (x6_pool, x6) = self.encoder_block6( + x5_pool, condition + ) # x6_pool: (bs, 1024, T / 64, F / 64) + x_center = self.conv_block7(x6_pool, condition) # (bs, 2048, T / 64, F / 64) + x7 = self.decoder_block1(x_center, x6, condition) # (bs, 1024, T / 32, F / 32) + x8 = self.decoder_block2(x7, x5, condition) # (bs, 512, T / 16, F / 16) + x9 = self.decoder_block3(x8, x4, condition) # (bs, 256, T / 8, F / 8) + x10 = self.decoder_block4(x9, x3, condition) # (bs, 128, T / 4, F / 4) + x11 = self.decoder_block5(x10, x2, condition) # (bs, 64, T / 2, F / 2) + x12 = self.decoder_block6(x11, x1, condition) # (bs, 32, T, F) + x = self.after_conv_block1(x12, condition) # (bs, 32, T, F) + x = self.after_conv2(x) + # (batch_size, input_channles * subbands_num * targets_num * k, T, F // subbands_num) + + x = self.subband.synthesis(x) + # (batch_size, input_channles * targets_num * K, T, F) + + # Recover shape + x = F.pad(x, pad=(0, 1)) # Pad frequency, e.g., 1024 -> 1025. 
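+        # Trim the time frames that were zero-padded before the encoder so the
+        # estimated masks line up again with `sp`, `cos_in`, and `sin_in`
+        # inside feature_maps_to_wav (comment added; inferred from the padding
+        # done earlier in this forward pass).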
+ x = x[:, :, 0:origin_len, :] # (bs, feature_maps, T, F) + + audio_length = mixture.shape[2] + + separated_audio = self.feature_maps_to_wav(x, sp, sin_in, cos_in, audio_length) + # separated_audio: (batch_size, output_sources_num * input_channels, segments_num) + + output_dict = {'waveform': separated_audio} + + return output_dict diff --git a/bytesep/models/lightning_modules.py b/bytesep/models/lightning_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..698f04116ed94521eb49e1995d967534dba62cd9 --- /dev/null +++ b/bytesep/models/lightning_modules.py @@ -0,0 +1,188 @@ +from typing import Any, Callable, Dict + +import pytorch_lightning as pl +import torch +import torch.nn as nn +import torch.optim as optim +from torch.optim.lr_scheduler import LambdaLR + + +class LitSourceSeparation(pl.LightningModule): + def __init__( + self, + batch_data_preprocessor, + model: nn.Module, + loss_function: Callable, + optimizer_type: str, + learning_rate: float, + lr_lambda: Callable, + ): + r"""Pytorch Lightning wrapper of PyTorch model, including forward, + optimization of model, etc. + + Args: + batch_data_preprocessor: object, used for preparing inputs and + targets for training. E.g., BasicBatchDataPreprocessor is used + for preparing data in dictionary into tensor. + model: nn.Module + loss_function: function + learning_rate: float + lr_lambda: function + """ + super().__init__() + + self.batch_data_preprocessor = batch_data_preprocessor + self.model = model + self.optimizer_type = optimizer_type + self.loss_function = loss_function + self.learning_rate = learning_rate + self.lr_lambda = lr_lambda + + def training_step(self, batch_data_dict: Dict, batch_idx: int) -> torch.float: + r"""Forward a mini-batch data to model, calculate loss function, and + train for one step. A mini-batch data is evenly distributed to multiple + devices (if there are) for parallel training. + + Args: + batch_data_dict: e.g. { + 'vocals': (batch_size, channels_num, segment_samples), + 'accompaniment': (batch_size, channels_num, segment_samples), + 'mixture': (batch_size, channels_num, segment_samples) + } + batch_idx: int + + Returns: + loss: float, loss function of this mini-batch + """ + input_dict, target_dict = self.batch_data_preprocessor(batch_data_dict) + # input_dict: { + # 'waveform': (batch_size, channels_num, segment_samples), + # (if_exist) 'condition': (batch_size, channels_num), + # } + # target_dict: { + # 'waveform': (batch_size, target_sources_num * channels_num, segment_samples), + # } + + # Forward. + self.model.train() + + output_dict = self.model(input_dict) + # output_dict: { + # 'waveform': (batch_size, target_sources_num * channels_num, segment_samples), + # } + + outputs = output_dict['waveform'] + # outputs:, e.g, (batch_size, target_sources_num * channels_num, segment_samples) + + # Calculate loss. 
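+        # Note (added): the loss callable returned by get_loss_function (e.g.
+        # l1_wav or L1_Wav_L1_Sp) is invoked here with the keyword arguments
+        # `output`, `target`, and `mixture`; losses that do not use the
+        # mixture still receive it and can simply ignore it.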
+ loss = self.loss_function( + output=outputs, + target=target_dict['waveform'], + mixture=input_dict['waveform'], + ) + + return loss + + def configure_optimizers(self) -> Any: + r"""Configure optimizer.""" + + if self.optimizer_type == "Adam": + optimizer = optim.Adam( + self.model.parameters(), + lr=self.learning_rate, + betas=(0.9, 0.999), + eps=1e-08, + weight_decay=0.0, + amsgrad=True, + ) + + elif self.optimizer_type == "AdamW": + optimizer = optim.AdamW( + self.model.parameters(), + lr=self.learning_rate, + betas=(0.9, 0.999), + eps=1e-08, + weight_decay=0.0, + amsgrad=True, + ) + + else: + raise NotImplementedError + + scheduler = { + 'scheduler': LambdaLR(optimizer, self.lr_lambda), + 'interval': 'step', + 'frequency': 1, + } + + return [optimizer], [scheduler] + + +def get_model_class(model_type): + r"""Get model. + + Args: + model_type: str, e.g., 'ResUNet143_DecouplePlusInplaceABN' + + Returns: + nn.Module + """ + if model_type == 'ResUNet143_DecouplePlusInplaceABN_ISMIR2021': + from bytesep.models.resunet_ismir2021 import ( + ResUNet143_DecouplePlusInplaceABN_ISMIR2021, + ) + + return ResUNet143_DecouplePlusInplaceABN_ISMIR2021 + + elif model_type == 'UNet': + from bytesep.models.unet import UNet + + return UNet + + elif model_type == 'UNetSubbandTime': + from bytesep.models.unet_subbandtime import UNetSubbandTime + + return UNetSubbandTime + + elif model_type == 'ResUNet143_Subbandtime': + from bytesep.models.resunet_subbandtime import ResUNet143_Subbandtime + + return ResUNet143_Subbandtime + + elif model_type == 'ResUNet143_DecouplePlus': + from bytesep.models.resunet import ResUNet143_DecouplePlus + + return ResUNet143_DecouplePlus + + elif model_type == 'ConditionalUNet': + from bytesep.models.conditional_unet import ConditionalUNet + + return ConditionalUNet + + elif model_type == 'LevelRNN': + from bytesep.models.levelrnn import LevelRNN + + return LevelRNN + + elif model_type == 'WavUNet': + from bytesep.models.wavunet import WavUNet + + return WavUNet + + elif model_type == 'WavUNetLevelRNN': + from bytesep.models.wavunet_levelrnn import WavUNetLevelRNN + + return WavUNetLevelRNN + + elif model_type == 'TTnet': + from bytesep.models.ttnet import TTnet + + return TTnet + + elif model_type == 'TTnetNoTransformer': + from bytesep.models.ttnet_no_transformer import TTnetNoTransformer + + return TTnetNoTransformer + + else: + raise NotImplementedError diff --git a/bytesep/models/pytorch_modules.py b/bytesep/models/pytorch_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..0bc51f0945d2764b8428611a8ecf109a0b344884 --- /dev/null +++ b/bytesep/models/pytorch_modules.py @@ -0,0 +1,204 @@ +from typing import List, NoReturn + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def init_embedding(layer: nn.Module) -> NoReturn: + r"""Initialize a Linear or Convolutional layer.""" + nn.init.uniform_(layer.weight, -1.0, 1.0) + + if hasattr(layer, 'bias'): + if layer.bias is not None: + layer.bias.data.fill_(0.0) + + +def init_layer(layer: nn.Module) -> NoReturn: + r"""Initialize a Linear or Convolutional layer.""" + nn.init.xavier_uniform_(layer.weight) + + if hasattr(layer, "bias"): + if layer.bias is not None: + layer.bias.data.fill_(0.0) + + +def init_bn(bn: nn.Module) -> NoReturn: + r"""Initialize a Batchnorm layer.""" + bn.bias.data.fill_(0.0) + bn.weight.data.fill_(1.0) + bn.running_mean.data.fill_(0.0) + bn.running_var.data.fill_(1.0) + + +def act(x: torch.Tensor, activation: str) -> torch.Tensor: + + if 
activation == "relu": + return F.relu_(x) + + elif activation == "leaky_relu": + return F.leaky_relu_(x, negative_slope=0.01) + + elif activation == "swish": + return x * torch.sigmoid(x) + + else: + raise Exception("Incorrect activation!") + + +class Base: + def __init__(self): + r"""Base function for extracting spectrogram, cos, and sin, etc.""" + pass + + def spectrogram(self, input: torch.Tensor, eps: float = 0.0) -> torch.Tensor: + r"""Calculate spectrogram. + + Args: + input: (batch_size, segments_num) + eps: float + + Returns: + spectrogram: (batch_size, time_steps, freq_bins) + """ + (real, imag) = self.stft(input) + return torch.clamp(real ** 2 + imag ** 2, eps, np.inf) ** 0.5 + + def spectrogram_phase( + self, input: torch.Tensor, eps: float = 0.0 + ) -> List[torch.Tensor]: + r"""Calculate the magnitude, cos, and sin of the STFT of input. + + Args: + input: (batch_size, segments_num) + eps: float + + Returns: + mag: (batch_size, time_steps, freq_bins) + cos: (batch_size, time_steps, freq_bins) + sin: (batch_size, time_steps, freq_bins) + """ + (real, imag) = self.stft(input) + mag = torch.clamp(real ** 2 + imag ** 2, eps, np.inf) ** 0.5 + cos = real / mag + sin = imag / mag + return mag, cos, sin + + def wav_to_spectrogram_phase( + self, input: torch.Tensor, eps: float = 1e-10 + ) -> List[torch.Tensor]: + r"""Convert waveforms to magnitude, cos, and sin of STFT. + + Args: + input: (batch_size, channels_num, segment_samples) + eps: float + + Outputs: + mag: (batch_size, channels_num, time_steps, freq_bins) + cos: (batch_size, channels_num, time_steps, freq_bins) + sin: (batch_size, channels_num, time_steps, freq_bins) + """ + batch_size, channels_num, segment_samples = input.shape + + # Reshape input with shapes of (n, segments_num) to meet the + # requirements of the stft function. + x = input.reshape(batch_size * channels_num, segment_samples) + + mag, cos, sin = self.spectrogram_phase(x, eps=eps) + # mag, cos, sin: (batch_size * channels_num, 1, time_steps, freq_bins) + + _, _, time_steps, freq_bins = mag.shape + mag = mag.reshape(batch_size, channels_num, time_steps, freq_bins) + cos = cos.reshape(batch_size, channels_num, time_steps, freq_bins) + sin = sin.reshape(batch_size, channels_num, time_steps, freq_bins) + + return mag, cos, sin + + def wav_to_spectrogram( + self, input: torch.Tensor, eps: float = 1e-10 + ) -> List[torch.Tensor]: + + mag, cos, sin = self.wav_to_spectrogram_phase(input, eps) + return mag + + +class Subband: + def __init__(self, subbands_num: int): + r"""Warning!! This class is not used!! + + This class does not work as good as [1] which split subbands in the + time-domain. Please refere to [1] for formal implementation. + + [1] Liu, Haohe, et al. "Channel-wise subband input for better voice and + accompaniment separation on high resolution music." arXiv preprint arXiv:2008.05216 (2020). + + Args: + subbands_num: int, e.g., 4 + """ + self.subbands_num = subbands_num + + def analysis(self, x: torch.Tensor) -> torch.Tensor: + r"""Analysis time-frequency representation into subbands. Stack the + subbands along the channel axis. 
+ + Args: + x: (batch_size, channels_num, time_steps, freq_bins) + + Returns: + output: (batch_size, channels_num * subbands_num, time_steps, freq_bins // subbands_num) + """ + batch_size, channels_num, time_steps, freq_bins = x.shape + + x = x.reshape( + batch_size, + channels_num, + time_steps, + self.subbands_num, + freq_bins // self.subbands_num, + ) + # x: (batch_size, channels_num, time_steps, subbands_num, freq_bins // subbands_num) + + x = x.transpose(2, 3) + + output = x.reshape( + batch_size, + channels_num * self.subbands_num, + time_steps, + freq_bins // self.subbands_num, + ) + # output: (batch_size, channels_num * subbands_num, time_steps, freq_bins // subbands_num) + + return output + + def synthesis(self, x: torch.Tensor) -> torch.Tensor: + r"""Synthesis subband time-frequency representations into original + time-frequency representation. + + Args: + x: (batch_size, channels_num * subbands_num, time_steps, freq_bins // subbands_num) + + Returns: + output: (batch_size, channels_num, time_steps, freq_bins) + """ + batch_size, subband_channels_num, time_steps, subband_freq_bins = x.shape + + channels_num = subband_channels_num // self.subbands_num + freq_bins = subband_freq_bins * self.subbands_num + + x = x.reshape( + batch_size, + channels_num, + self.subbands_num, + time_steps, + subband_freq_bins, + ) + # x: (batch_size, channels_num, subbands_num, time_steps, freq_bins // subbands_num) + + x = x.transpose(2, 3) + # x: (batch_size, channels_num, time_steps, subbands_num, freq_bins // subbands_num) + + output = x.reshape(batch_size, channels_num, time_steps, freq_bins) + # x: (batch_size, channels_num, time_steps, freq_bins) + + return output diff --git a/bytesep/models/resunet.py b/bytesep/models/resunet.py new file mode 100644 index 0000000000000000000000000000000000000000..0f3939bc03da821898731484330e5dcc098cafc0 --- /dev/null +++ b/bytesep/models/resunet.py @@ -0,0 +1,516 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchlibrosa.stft import ISTFT, STFT, magphase + +from bytesep.models.pytorch_modules import Base, Subband, act, init_bn, init_layer + + +class ConvBlockRes(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, activation, momentum): + r"""Residual block.""" + super(ConvBlockRes, self).__init__() + + self.activation = activation + padding = [kernel_size[0] // 2, kernel_size[1] // 2] + + self.bn1 = nn.BatchNorm2d(in_channels, momentum=momentum) + self.bn2 = nn.BatchNorm2d(out_channels, momentum=momentum) + + self.conv1 = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=(1, 1), + dilation=(1, 1), + padding=padding, + bias=False, + ) + + self.conv2 = nn.Conv2d( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=(1, 1), + dilation=(1, 1), + padding=padding, + bias=False, + ) + + if in_channels != out_channels: + self.shortcut = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=(1, 1), + stride=(1, 1), + padding=(0, 0), + ) + + self.is_shortcut = True + else: + self.is_shortcut = False + + self.init_weights() + + def init_weights(self): + init_bn(self.bn1) + init_bn(self.bn2) + init_layer(self.conv1) + init_layer(self.conv2) + + if self.is_shortcut: + init_layer(self.shortcut) + + def forward(self, x): + origin = x + x = self.conv1(act(self.bn1(x), self.activation)) + x = self.conv2(act(self.bn2(x), self.activation)) + + if self.is_shortcut: + return 
self.shortcut(origin) + x + else: + return origin + x + + +class EncoderBlockRes4B(nn.Module): + def __init__( + self, in_channels, out_channels, kernel_size, downsample, activation, momentum + ): + r"""Encoder block, contains 8 convolutional layers.""" + super(EncoderBlockRes4B, self).__init__() + + self.conv_block1 = ConvBlockRes( + in_channels, out_channels, kernel_size, activation, momentum + ) + self.conv_block2 = ConvBlockRes( + out_channels, out_channels, kernel_size, activation, momentum + ) + self.conv_block3 = ConvBlockRes( + out_channels, out_channels, kernel_size, activation, momentum + ) + self.conv_block4 = ConvBlockRes( + out_channels, out_channels, kernel_size, activation, momentum + ) + self.downsample = downsample + + def forward(self, x): + encoder = self.conv_block1(x) + encoder = self.conv_block2(encoder) + encoder = self.conv_block3(encoder) + encoder = self.conv_block4(encoder) + encoder_pool = F.avg_pool2d(encoder, kernel_size=self.downsample) + return encoder_pool, encoder + + +class DecoderBlockRes4B(nn.Module): + def __init__( + self, in_channels, out_channels, kernel_size, upsample, activation, momentum + ): + r"""Decoder block, contains 1 transpose convolutional and 8 convolutional layers.""" + super(DecoderBlockRes4B, self).__init__() + self.kernel_size = kernel_size + self.stride = upsample + self.activation = activation + + self.conv1 = torch.nn.ConvTranspose2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=self.stride, + stride=self.stride, + padding=(0, 0), + bias=False, + dilation=(1, 1), + ) + + self.bn1 = nn.BatchNorm2d(in_channels, momentum=momentum) + self.conv_block2 = ConvBlockRes( + out_channels * 2, out_channels, kernel_size, activation, momentum + ) + self.conv_block3 = ConvBlockRes( + out_channels, out_channels, kernel_size, activation, momentum + ) + self.conv_block4 = ConvBlockRes( + out_channels, out_channels, kernel_size, activation, momentum + ) + self.conv_block5 = ConvBlockRes( + out_channels, out_channels, kernel_size, activation, momentum + ) + + self.init_weights() + + def init_weights(self): + init_bn(self.bn1) + init_layer(self.conv1) + + def forward(self, input_tensor, concat_tensor): + x = self.conv1(act(self.bn1(input_tensor), self.activation)) + x = torch.cat((x, concat_tensor), dim=1) + x = self.conv_block2(x) + x = self.conv_block3(x) + x = self.conv_block4(x) + x = self.conv_block5(x) + return x + + +class ResUNet143_DecouplePlus(nn.Module, Base): + def __init__(self, input_channels, target_sources_num): + super(ResUNet143_DecouplePlus, self).__init__() + + self.input_channels = input_channels + self.target_sources_num = target_sources_num + + window_size = 2048 + hop_size = 441 + center = True + pad_mode = "reflect" + window = "hann" + activation = "relu" + momentum = 0.01 + + self.subbands_num = 4 + self.K = 4 # outputs: |M|, cos∠M, sin∠M, |M2| + + self.downsample_ratio = 2 ** 6 # This number equals 2^{#encoder_blcoks} + + self.stft = STFT( + n_fft=window_size, + hop_length=hop_size, + win_length=window_size, + window=window, + center=center, + pad_mode=pad_mode, + freeze_parameters=True, + ) + + self.istft = ISTFT( + n_fft=window_size, + hop_length=hop_size, + win_length=window_size, + window=window, + center=center, + pad_mode=pad_mode, + freeze_parameters=True, + ) + + self.bn0 = nn.BatchNorm2d(window_size // 2 + 1, momentum=momentum) + + self.subband = Subband(subbands_num=self.subbands_num) + + self.encoder_block1 = EncoderBlockRes4B( + in_channels=input_channels * self.subbands_num, + 
out_channels=32, + kernel_size=(3, 3), + downsample=(2, 2), + activation=activation, + momentum=momentum, + ) + self.encoder_block2 = EncoderBlockRes4B( + in_channels=32, + out_channels=64, + kernel_size=(3, 3), + downsample=(2, 2), + activation=activation, + momentum=momentum, + ) + self.encoder_block3 = EncoderBlockRes4B( + in_channels=64, + out_channels=128, + kernel_size=(3, 3), + downsample=(2, 2), + activation=activation, + momentum=momentum, + ) + self.encoder_block4 = EncoderBlockRes4B( + in_channels=128, + out_channels=256, + kernel_size=(3, 3), + downsample=(2, 2), + activation=activation, + momentum=momentum, + ) + self.encoder_block5 = EncoderBlockRes4B( + in_channels=256, + out_channels=384, + kernel_size=(3, 3), + downsample=(2, 2), + activation=activation, + momentum=momentum, + ) + self.encoder_block6 = EncoderBlockRes4B( + in_channels=384, + out_channels=384, + kernel_size=(3, 3), + downsample=(1, 2), + activation=activation, + momentum=momentum, + ) + self.conv_block7a = EncoderBlockRes4B( + in_channels=384, + out_channels=384, + kernel_size=(3, 3), + downsample=(1, 1), + activation=activation, + momentum=momentum, + ) + self.conv_block7b = EncoderBlockRes4B( + in_channels=384, + out_channels=384, + kernel_size=(3, 3), + downsample=(1, 1), + activation=activation, + momentum=momentum, + ) + self.conv_block7c = EncoderBlockRes4B( + in_channels=384, + out_channels=384, + kernel_size=(3, 3), + downsample=(1, 1), + activation=activation, + momentum=momentum, + ) + self.conv_block7d = EncoderBlockRes4B( + in_channels=384, + out_channels=384, + kernel_size=(3, 3), + downsample=(1, 1), + activation=activation, + momentum=momentum, + ) + self.decoder_block1 = DecoderBlockRes4B( + in_channels=384, + out_channels=384, + kernel_size=(3, 3), + upsample=(1, 2), + activation=activation, + momentum=momentum, + ) + self.decoder_block2 = DecoderBlockRes4B( + in_channels=384, + out_channels=384, + kernel_size=(3, 3), + upsample=(2, 2), + activation=activation, + momentum=momentum, + ) + self.decoder_block3 = DecoderBlockRes4B( + in_channels=384, + out_channels=256, + kernel_size=(3, 3), + upsample=(2, 2), + activation=activation, + momentum=momentum, + ) + self.decoder_block4 = DecoderBlockRes4B( + in_channels=256, + out_channels=128, + kernel_size=(3, 3), + upsample=(2, 2), + activation=activation, + momentum=momentum, + ) + self.decoder_block5 = DecoderBlockRes4B( + in_channels=128, + out_channels=64, + kernel_size=(3, 3), + upsample=(2, 2), + activation=activation, + momentum=momentum, + ) + self.decoder_block6 = DecoderBlockRes4B( + in_channels=64, + out_channels=32, + kernel_size=(3, 3), + upsample=(2, 2), + activation=activation, + momentum=momentum, + ) + + self.after_conv_block1 = EncoderBlockRes4B( + in_channels=32, + out_channels=32, + kernel_size=(3, 3), + downsample=(1, 1), + activation=activation, + momentum=momentum, + ) + + self.after_conv2 = nn.Conv2d( + in_channels=32, + out_channels=input_channels + * self.subbands_num + * target_sources_num + * self.K, + kernel_size=(1, 1), + stride=(1, 1), + padding=(0, 0), + bias=True, + ) + + self.init_weights() + + def init_weights(self): + init_bn(self.bn0) + init_layer(self.after_conv2) + + def feature_maps_to_wav( + self, + input_tensor: torch.Tensor, + sp: torch.Tensor, + sin_in: torch.Tensor, + cos_in: torch.Tensor, + audio_length: int, + ) -> torch.Tensor: + r"""Convert feature maps to waveform. 
+ + Args: + input_tensor: (batch_size, feature_maps, time_steps, freq_bins) + sp: (batch_size, feature_maps, time_steps, freq_bins) + sin_in: (batch_size, feature_maps, time_steps, freq_bins) + cos_in: (batch_size, feature_maps, time_steps, freq_bins) + + Outputs: + waveform: (batch_size, target_sources_num * input_channels, segment_samples) + """ + batch_size, _, time_steps, freq_bins = input_tensor.shape + + x = input_tensor.reshape( + batch_size, + self.target_sources_num, + self.input_channels, + self.K, + time_steps, + freq_bins, + ) + # x: (batch_size, target_sources_num, input_channles, K, time_steps, freq_bins) + + mask_mag = torch.sigmoid(x[:, :, :, 0, :, :]) + _mask_real = torch.tanh(x[:, :, :, 1, :, :]) + _mask_imag = torch.tanh(x[:, :, :, 2, :, :]) + linear_mag = x[:, :, :, 3, :, :] + _, mask_cos, mask_sin = magphase(_mask_real, _mask_imag) + # mask_cos, mask_sin: (batch_size, target_sources_num, input_channles, time_steps, freq_bins) + + # Y = |Y|cos∠Y + j|Y|sin∠Y + # = |Y|cos(∠X + ∠M) + j|Y|sin(∠X + ∠M) + # = |Y|(cos∠X cos∠M - sin∠X sin∠M) + j|Y|(sin∠X cos∠M + cos∠X sin∠M) + out_cos = ( + cos_in[:, None, :, :, :] * mask_cos - sin_in[:, None, :, :, :] * mask_sin + ) + out_sin = ( + sin_in[:, None, :, :, :] * mask_cos + cos_in[:, None, :, :, :] * mask_sin + ) + # out_cos: (batch_size, target_sources_num, input_channles, time_steps, freq_bins) + # out_sin: (batch_size, target_sources_num, input_channles, time_steps, freq_bins) + + # Calculate |Y|. + out_mag = F.relu_(sp[:, None, :, :, :] * mask_mag + linear_mag) + # out_mag: (batch_size, target_sources_num, input_channles, time_steps, freq_bins) + + # Calculate Y_{real} and Y_{imag} for ISTFT. + out_real = out_mag * out_cos + out_imag = out_mag * out_sin + # out_real, out_imag: (batch_size, target_sources_num, input_channles, time_steps, freq_bins) + + # Reformat shape to (n, 1, time_steps, freq_bins) for ISTFT. + shape = ( + batch_size * self.target_sources_num * self.input_channels, + 1, + time_steps, + freq_bins, + ) + out_real = out_real.reshape(shape) + out_imag = out_imag.reshape(shape) + + # ISTFT. + x = self.istft(out_real, out_imag, audio_length) + # (batch_size * target_sources_num * input_channels, segments_num) + + # Reshape. + waveform = x.reshape( + batch_size, self.target_sources_num * self.input_channels, audio_length + ) + # (batch_size, target_sources_num * input_channels, segments_num) + + return waveform + + def forward(self, input_dict): + r""" + Args: + input: (batch_size, channels_num, segment_samples) + + Outputs: + output_dict: { + 'wav': (batch_size, channels_num, segment_samples) + } + """ + mixtures = input_dict['waveform'] + # (batch_size, input_channels, segment_samples) + + mag, cos_in, sin_in = self.wav_to_spectrogram_phase(mixtures) + # mag, cos_in, sin_in: (batch_size, input_channels, time_steps, freq_bins) + + # Batch normalize on individual frequency bins. + x = mag.transpose(1, 3) + x = self.bn0(x) + x = x.transpose(1, 3) + """(batch_size, input_channels, time_steps, freq_bins)""" + + # Pad spectrogram to be evenly divided by downsample ratio. 
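+        # Worked example (added): with downsample_ratio = 2 ** 6 = 64, a
+        # 301-frame spectrogram is padded to ceil(301 / 64) * 64 = 320 frames
+        # (pad_len = 19); the padding is trimmed again after the decoder.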
+ origin_len = x.shape[2] + pad_len = ( + int(np.ceil(x.shape[2] / self.downsample_ratio)) * self.downsample_ratio + - origin_len + ) + x = F.pad(x, pad=(0, 0, 0, pad_len)) + """(batch_size, input_channels, padded_time_steps, freq_bins)""" + + # Let frequency bins be evenly divided by 2, e.g., 1025 -> 1024 + x = x[..., 0 : x.shape[-1] - 1] # (bs, input_channels, T, F) + + x = self.subband.analysis(x) + # (bs, input_channels, T, F'), where F' = F // subbands_num + + # UNet + (x1_pool, x1) = self.encoder_block1(x) # x1_pool: (bs, 32, T / 2, F / 2) + (x2_pool, x2) = self.encoder_block2(x1_pool) # x2_pool: (bs, 64, T / 4, F / 4) + (x3_pool, x3) = self.encoder_block3(x2_pool) # x3_pool: (bs, 128, T / 8, F / 8) + (x4_pool, x4) = self.encoder_block4( + x3_pool + ) # x4_pool: (bs, 256, T / 16, F / 16) + (x5_pool, x5) = self.encoder_block5( + x4_pool + ) # x5_pool: (bs, 384, T / 32, F / 32) + (x6_pool, x6) = self.encoder_block6( + x5_pool + ) # x6_pool: (bs, 384, T / 32, F / 64) + (x_center, _) = self.conv_block7a(x6_pool) # (bs, 384, T / 32, F / 64) + (x_center, _) = self.conv_block7b(x_center) # (bs, 384, T / 32, F / 64) + (x_center, _) = self.conv_block7c(x_center) # (bs, 384, T / 32, F / 64) + (x_center, _) = self.conv_block7d(x_center) # (bs, 384, T / 32, F / 64) + x7 = self.decoder_block1(x_center, x6) # (bs, 384, T / 32, F / 32) + x8 = self.decoder_block2(x7, x5) # (bs, 384, T / 16, F / 16) + x9 = self.decoder_block3(x8, x4) # (bs, 256, T / 8, F / 8) + x10 = self.decoder_block4(x9, x3) # (bs, 128, T / 4, F / 4) + x11 = self.decoder_block5(x10, x2) # (bs, 64, T / 2, F / 2) + x12 = self.decoder_block6(x11, x1) # (bs, 32, T, F) + (x, _) = self.after_conv_block1(x12) # (bs, 32, T, F) + + x = self.after_conv2(x) # (bs, channels * 3, T, F) + # (batch_size, input_channles * subbands_num * targets_num * k, T, F') + + x = self.subband.synthesis(x) + # (batch_size, input_channles * targets_num * K, T, F) + + # Recover shape + x = F.pad(x, pad=(0, 1)) # Pad frequency, e.g., 1024 -> 1025. + x = x[:, :, 0:origin_len, :] # (bs, feature_maps, time_steps, freq_bins) + + audio_length = mixtures.shape[2] + + separated_audio = self.feature_maps_to_wav(x, mag, sin_in, cos_in, audio_length) + # separated_audio: (batch_size, target_sources_num * input_channels, segments_num) + + output_dict = {'waveform': separated_audio} + + return output_dict diff --git a/bytesep/models/resunet_ismir2021.py b/bytesep/models/resunet_ismir2021.py new file mode 100644 index 0000000000000000000000000000000000000000..effb0900e49f3078429b1442a35caca715d7ba3c --- /dev/null +++ b/bytesep/models/resunet_ismir2021.py @@ -0,0 +1,534 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from inplace_abn.abn import InPlaceABNSync +from torchlibrosa.stft import ISTFT, STFT, magphase + +from bytesep.models.pytorch_modules import Base, init_bn, init_layer + + +class ConvBlockRes(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, activation, momentum): + r"""Residual block.""" + super(ConvBlockRes, self).__init__() + + self.activation = activation + padding = [kernel_size[0] // 2, kernel_size[1] // 2] + + # ABN is not used for bn1 because we found using abn1 will degrade performance. 
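+        # Note (added): InPlaceABNSync is the synchronized variant of in-place
+        # activated batch norm; it fuses batch norm and activation into one
+        # in-place op to save activation memory, so abn2 takes the place of
+        # the bn2 + leaky_relu pair of the plain residual block, while bn1
+        # stays a standard BatchNorm2d.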
+ self.bn1 = nn.BatchNorm2d(in_channels, momentum=momentum) + + self.abn2 = InPlaceABNSync( + num_features=out_channels, momentum=momentum, activation='leaky_relu' + ) + + self.conv1 = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=(1, 1), + dilation=(1, 1), + padding=padding, + bias=False, + ) + + self.conv2 = nn.Conv2d( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=(1, 1), + dilation=(1, 1), + padding=padding, + bias=False, + ) + + if in_channels != out_channels: + self.shortcut = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=(1, 1), + stride=(1, 1), + padding=(0, 0), + ) + self.is_shortcut = True + else: + self.is_shortcut = False + + self.init_weights() + + def init_weights(self): + init_bn(self.bn1) + init_layer(self.conv1) + init_layer(self.conv2) + + if self.is_shortcut: + init_layer(self.shortcut) + + def forward(self, x): + origin = x + x = self.conv1(F.leaky_relu_(self.bn1(x), negative_slope=0.01)) + x = self.conv2(self.abn2(x)) + + if self.is_shortcut: + return self.shortcut(origin) + x + else: + return origin + x + + +class EncoderBlockRes4B(nn.Module): + def __init__( + self, in_channels, out_channels, kernel_size, downsample, activation, momentum + ): + r"""Encoder block, contains 8 convolutional layers.""" + super(EncoderBlockRes4B, self).__init__() + + self.conv_block1 = ConvBlockRes( + in_channels, out_channels, kernel_size, activation, momentum + ) + self.conv_block2 = ConvBlockRes( + out_channels, out_channels, kernel_size, activation, momentum + ) + self.conv_block3 = ConvBlockRes( + out_channels, out_channels, kernel_size, activation, momentum + ) + self.conv_block4 = ConvBlockRes( + out_channels, out_channels, kernel_size, activation, momentum + ) + self.downsample = downsample + + def forward(self, x): + encoder = self.conv_block1(x) + encoder = self.conv_block2(encoder) + encoder = self.conv_block3(encoder) + encoder = self.conv_block4(encoder) + encoder_pool = F.avg_pool2d(encoder, kernel_size=self.downsample) + return encoder_pool, encoder + + +class DecoderBlockRes4B(nn.Module): + def __init__( + self, in_channels, out_channels, kernel_size, upsample, activation, momentum + ): + r"""Decoder block, contains 1 transpose convolutional and 8 convolutional layers.""" + super(DecoderBlockRes4B, self).__init__() + self.kernel_size = kernel_size + self.stride = upsample + self.activation = activation + + self.conv1 = torch.nn.ConvTranspose2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=self.stride, + stride=self.stride, + padding=(0, 0), + bias=False, + dilation=(1, 1), + ) + + self.bn1 = nn.BatchNorm2d(in_channels, momentum=momentum) + self.conv_block2 = ConvBlockRes( + out_channels * 2, out_channels, kernel_size, activation, momentum + ) + self.conv_block3 = ConvBlockRes( + out_channels, out_channels, kernel_size, activation, momentum + ) + self.conv_block4 = ConvBlockRes( + out_channels, out_channels, kernel_size, activation, momentum + ) + self.conv_block5 = ConvBlockRes( + out_channels, out_channels, kernel_size, activation, momentum + ) + + self.init_weights() + + def init_weights(self): + init_bn(self.bn1) + init_layer(self.conv1) + + def forward(self, input_tensor, concat_tensor): + x = self.conv1(F.relu_(self.bn1(input_tensor))) + x = torch.cat((x, concat_tensor), dim=1) + x = self.conv_block2(x) + x = self.conv_block3(x) + x = self.conv_block4(x) + x = self.conv_block5(x) + return x + + +class 
ResUNet143_DecouplePlusInplaceABN_ISMIR2021(nn.Module, Base): + def __init__(self, input_channels, target_sources_num): + super(ResUNet143_DecouplePlusInplaceABN_ISMIR2021, self).__init__() + + self.input_channels = input_channels + self.target_sources_num = target_sources_num + + window_size = 2048 + hop_size = 441 + center = True + pad_mode = 'reflect' + window = 'hann' + activation = 'leaky_relu' + momentum = 0.01 + + self.subbands_num = 1 + + assert ( + self.subbands_num == 1 + ), "Using subbands_num > 1 on spectrogram \ + will lead to unexpected performance sometimes. Suggest to use \ + subband method on waveform." + + # Downsample rate along the time axis. + self.K = 4 # outputs: |M|, cos∠M, sin∠M, Q + self.time_downsample_ratio = 2 ** 5 # This number equals 2^{#encoder_blcoks} + + self.stft = STFT( + n_fft=window_size, + hop_length=hop_size, + win_length=window_size, + window=window, + center=center, + pad_mode=pad_mode, + freeze_parameters=True, + ) + + self.istft = ISTFT( + n_fft=window_size, + hop_length=hop_size, + win_length=window_size, + window=window, + center=center, + pad_mode=pad_mode, + freeze_parameters=True, + ) + + self.bn0 = nn.BatchNorm2d(window_size // 2 + 1, momentum=momentum) + + self.encoder_block1 = EncoderBlockRes4B( + in_channels=input_channels * self.subbands_num, + out_channels=32, + kernel_size=(3, 3), + downsample=(2, 2), + activation=activation, + momentum=momentum, + ) + self.encoder_block2 = EncoderBlockRes4B( + in_channels=32, + out_channels=64, + kernel_size=(3, 3), + downsample=(2, 2), + activation=activation, + momentum=momentum, + ) + self.encoder_block3 = EncoderBlockRes4B( + in_channels=64, + out_channels=128, + kernel_size=(3, 3), + downsample=(2, 2), + activation=activation, + momentum=momentum, + ) + self.encoder_block4 = EncoderBlockRes4B( + in_channels=128, + out_channels=256, + kernel_size=(3, 3), + downsample=(2, 2), + activation=activation, + momentum=momentum, + ) + self.encoder_block5 = EncoderBlockRes4B( + in_channels=256, + out_channels=384, + kernel_size=(3, 3), + downsample=(2, 2), + activation=activation, + momentum=momentum, + ) + self.encoder_block6 = EncoderBlockRes4B( + in_channels=384, + out_channels=384, + kernel_size=(3, 3), + downsample=(1, 2), + activation=activation, + momentum=momentum, + ) + self.conv_block7a = EncoderBlockRes4B( + in_channels=384, + out_channels=384, + kernel_size=(3, 3), + downsample=(1, 1), + activation=activation, + momentum=momentum, + ) + self.conv_block7b = EncoderBlockRes4B( + in_channels=384, + out_channels=384, + kernel_size=(3, 3), + downsample=(1, 1), + activation=activation, + momentum=momentum, + ) + self.conv_block7c = EncoderBlockRes4B( + in_channels=384, + out_channels=384, + kernel_size=(3, 3), + downsample=(1, 1), + activation=activation, + momentum=momentum, + ) + self.conv_block7d = EncoderBlockRes4B( + in_channels=384, + out_channels=384, + kernel_size=(3, 3), + downsample=(1, 1), + activation=activation, + momentum=momentum, + ) + self.decoder_block1 = DecoderBlockRes4B( + in_channels=384, + out_channels=384, + kernel_size=(3, 3), + upsample=(1, 2), + activation=activation, + momentum=momentum, + ) + self.decoder_block2 = DecoderBlockRes4B( + in_channels=384, + out_channels=384, + kernel_size=(3, 3), + upsample=(2, 2), + activation=activation, + momentum=momentum, + ) + self.decoder_block3 = DecoderBlockRes4B( + in_channels=384, + out_channels=256, + kernel_size=(3, 3), + upsample=(2, 2), + activation=activation, + momentum=momentum, + ) + self.decoder_block4 = DecoderBlockRes4B( 
+ in_channels=256, + out_channels=128, + kernel_size=(3, 3), + upsample=(2, 2), + activation=activation, + momentum=momentum, + ) + self.decoder_block5 = DecoderBlockRes4B( + in_channels=128, + out_channels=64, + kernel_size=(3, 3), + upsample=(2, 2), + activation=activation, + momentum=momentum, + ) + self.decoder_block6 = DecoderBlockRes4B( + in_channels=64, + out_channels=32, + kernel_size=(3, 3), + upsample=(2, 2), + activation=activation, + momentum=momentum, + ) + + self.after_conv_block1 = EncoderBlockRes4B( + in_channels=32, + out_channels=32, + kernel_size=(3, 3), + downsample=(1, 1), + activation=activation, + momentum=momentum, + ) + + self.after_conv2 = nn.Conv2d( + in_channels=32, + out_channels=target_sources_num + * input_channels + * self.K + * self.subbands_num, + kernel_size=(1, 1), + stride=(1, 1), + padding=(0, 0), + bias=True, + ) + + self.init_weights() + + def init_weights(self): + init_bn(self.bn0) + init_layer(self.after_conv2) + + def feature_maps_to_wav( + self, + input_tensor: torch.Tensor, + sp: torch.Tensor, + sin_in: torch.Tensor, + cos_in: torch.Tensor, + audio_length: int, + ) -> torch.Tensor: + r"""Convert feature maps to waveform. + + Args: + input_tensor: (batch_size, target_sources_num * input_channels * self.K, time_steps, freq_bins) + sp: (batch_size, target_sources_num * input_channels, time_steps, freq_bins) + sin_in: (batch_size, target_sources_num * input_channels, time_steps, freq_bins) + cos_in: (batch_size, target_sources_num * input_channels, time_steps, freq_bins) + + Outputs: + waveform: (batch_size, target_sources_num * input_channels, segment_samples) + """ + batch_size, _, time_steps, freq_bins = input_tensor.shape + + x = input_tensor.reshape( + batch_size, + self.target_sources_num, + self.input_channels, + self.K, + time_steps, + freq_bins, + ) + # x: (batch_size, target_sources_num, input_channles, K, time_steps, freq_bins) + + mask_mag = torch.sigmoid(x[:, :, :, 0, :, :]) + _mask_real = torch.tanh(x[:, :, :, 1, :, :]) + _mask_imag = torch.tanh(x[:, :, :, 2, :, :]) + _, mask_cos, mask_sin = magphase(_mask_real, _mask_imag) + linear_mag = x[:, :, :, 3, :, :] + # mask_cos, mask_sin: (batch_size, target_sources_num, input_channles, time_steps, freq_bins) + + # Y = |Y|cos∠Y + j|Y|sin∠Y + # = |Y|cos(∠X + ∠M) + j|Y|sin(∠X + ∠M) + # = |Y|(cos∠X cos∠M - sin∠X sin∠M) + j|Y|(sin∠X cos∠M + cos∠X sin∠M) + out_cos = ( + cos_in[:, None, :, :, :] * mask_cos - sin_in[:, None, :, :, :] * mask_sin + ) + out_sin = ( + sin_in[:, None, :, :, :] * mask_cos + cos_in[:, None, :, :, :] * mask_sin + ) + # out_cos: (batch_size, target_sources_num, input_channles, time_steps, freq_bins) + # out_sin: (batch_size, target_sources_num, input_channles, time_steps, freq_bins) + + # Calculate |Y|. + out_mag = F.relu_(sp[:, None, :, :, :] * mask_mag + linear_mag) + # out_mag: (batch_size, target_sources_num, input_channles, time_steps, freq_bins) + + # Calculate Y_{real} and Y_{imag} for ISTFT. + out_real = out_mag * out_cos + out_imag = out_mag * out_sin + # out_real, out_imag: (batch_size, target_sources_num, input_channles, time_steps, freq_bins) + + # Reformat shape to (n, 1, time_steps, freq_bins) for ISTFT. + shape = ( + batch_size * self.target_sources_num * self.input_channels, + 1, + time_steps, + freq_bins, + ) + out_real = out_real.reshape(shape) + out_imag = out_imag.reshape(shape) + + # ISTFT. + x = self.istft(out_real, out_imag, audio_length) + # (batch_size * target_sources_num * input_channels, segments_num) + + # Reshape. 
+ waveform = x.reshape( + batch_size, self.target_sources_num * self.input_channels, audio_length + ) + # (batch_size, target_sources_num * input_channels, segments_num) + + return waveform + + def forward(self, input_dict): + r"""Forward data into the module. + + Args: + input_dict: dict, e.g., { + waveform: (batch_size, input_channels, segment_samples), + ..., + } + + Outputs: + output_dict: dict, e.g., { + 'waveform': (batch_size, input_channels, segment_samples), + ..., + } + """ + mixtures = input_dict['waveform'] + # (batch_size, input_channels, segment_samples) + + mag, cos_in, sin_in = self.wav_to_spectrogram_phase(mixtures) + # mag, cos_in, sin_in: (batch_size, input_channels, time_steps, freq_bins) + + # Batch normalize on individual frequency bins. + x = mag.transpose(1, 3) + x = self.bn0(x) + x = x.transpose(1, 3) + # x: (batch_size, input_channels, time_steps, freq_bins) + + # Pad spectrogram to be evenly divided by downsample ratio. + origin_len = x.shape[2] + pad_len = ( + int(np.ceil(x.shape[2] / self.time_downsample_ratio)) + * self.time_downsample_ratio + - origin_len + ) + x = F.pad(x, pad=(0, 0, 0, pad_len)) + # (batch_size, channels, padded_time_steps, freq_bins) + + # Let frequency bins be evenly divided by 2, e.g., 1025 -> 1024. + x = x[..., 0 : x.shape[-1] - 1] # (bs, channels, T, F) + + if self.subbands_num > 1: + x = self.subband.analysis(x) + # (bs, input_channels, T, F'), where F' = F // subbands_num + + # UNet + (x1_pool, x1) = self.encoder_block1(x) # x1_pool: (bs, 32, T / 2, F / 2) + (x2_pool, x2) = self.encoder_block2(x1_pool) # x2_pool: (bs, 64, T / 4, F / 4) + (x3_pool, x3) = self.encoder_block3(x2_pool) # x3_pool: (bs, 128, T / 8, F / 8) + (x4_pool, x4) = self.encoder_block4( + x3_pool + ) # x4_pool: (bs, 256, T / 16, F / 16) + (x5_pool, x5) = self.encoder_block5( + x4_pool + ) # x5_pool: (bs, 384, T / 32, F / 32) + (x6_pool, x6) = self.encoder_block6( + x5_pool + ) # x6_pool: (bs, 384, T / 32, F / 64) + (x_center, _) = self.conv_block7a(x6_pool) # (bs, 384, T / 32, F / 64) + (x_center, _) = self.conv_block7b(x_center) # (bs, 384, T / 32, F / 64) + (x_center, _) = self.conv_block7c(x_center) # (bs, 384, T / 32, F / 64) + (x_center, _) = self.conv_block7d(x_center) # (bs, 384, T / 32, F / 64) + x7 = self.decoder_block1(x_center, x6) # (bs, 384, T / 32, F / 32) + x8 = self.decoder_block2(x7, x5) # (bs, 384, T / 16, F / 16) + x9 = self.decoder_block3(x8, x4) # (bs, 256, T / 8, F / 8) + x10 = self.decoder_block4(x9, x3) # (bs, 128, T / 4, F / 4) + x11 = self.decoder_block5(x10, x2) # (bs, 64, T / 2, F / 2) + x12 = self.decoder_block6(x11, x1) # (bs, 32, T, F) + (x, _) = self.after_conv_block1(x12) # (bs, 32, T, F) + + x = self.after_conv2(x) # (bs, channels * 3, T, F) + # (batch_size, target_sources_num * input_channles * self.K * subbands_num, T, F') + + if self.subbands_num > 1: + x = self.subband.synthesis(x) + # (batch_size, target_sources_num * input_channles * self.K, T, F) + + # Recover shape + x = F.pad(x, pad=(0, 1)) # Pad frequency, e.g., 1024 -> 1025. 
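+        # Note (added): restoring the frequency bin dropped before the encoder
+        # keeps the predicted masks aligned with `mag`, `cos_in`, and `sin_in`
+        # (window_size // 2 + 1 bins) inside feature_maps_to_wav.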
+ + x = x[:, :, 0:origin_len, :] + # (batch_size, target_sources_num * input_channles * self.K, T, F) + + audio_length = mixtures.shape[2] + + separated_audio = self.feature_maps_to_wav(x, mag, sin_in, cos_in, audio_length) + # separated_audio: (batch_size, target_sources_num * input_channels, segments_num) + + output_dict = {'waveform': separated_audio} + + return output_dict diff --git a/bytesep/models/resunet_subbandtime.py b/bytesep/models/resunet_subbandtime.py new file mode 100644 index 0000000000000000000000000000000000000000..d5ac3c5cd5aa2e3d49b7513b2e577ba148c80d7e --- /dev/null +++ b/bytesep/models/resunet_subbandtime.py @@ -0,0 +1,545 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchlibrosa.stft import ISTFT, STFT, magphase + +from bytesep.models.pytorch_modules import Base, init_bn, init_layer +from bytesep.models.subband_tools.pqmf import PQMF + + +class ConvBlockRes(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, activation, momentum): + r"""Residual block.""" + super(ConvBlockRes, self).__init__() + + self.activation = activation + padding = [kernel_size[0] // 2, kernel_size[1] // 2] + + self.bn1 = nn.BatchNorm2d(in_channels, momentum=momentum) + self.bn2 = nn.BatchNorm2d(out_channels, momentum=momentum) + + self.conv1 = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=(1, 1), + dilation=(1, 1), + padding=padding, + bias=False, + ) + + self.conv2 = nn.Conv2d( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=(1, 1), + dilation=(1, 1), + padding=padding, + bias=False, + ) + + if in_channels != out_channels: + self.shortcut = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=(1, 1), + stride=(1, 1), + padding=(0, 0), + ) + self.is_shortcut = True + else: + self.is_shortcut = False + + self.init_weights() + + def init_weights(self): + init_bn(self.bn1) + init_bn(self.bn2) + init_layer(self.conv1) + init_layer(self.conv2) + + if self.is_shortcut: + init_layer(self.shortcut) + + def forward(self, x): + origin = x + x = self.conv1(F.leaky_relu_(self.bn1(x), negative_slope=0.01)) + x = self.conv2(F.leaky_relu_(self.bn2(x), negative_slope=0.01)) + + if self.is_shortcut: + return self.shortcut(origin) + x + else: + return origin + x + + +class EncoderBlockRes4B(nn.Module): + def __init__( + self, in_channels, out_channels, kernel_size, downsample, activation, momentum + ): + r"""Encoder block, contains 8 convolutional layers.""" + super(EncoderBlockRes4B, self).__init__() + + self.conv_block1 = ConvBlockRes( + in_channels, out_channels, kernel_size, activation, momentum + ) + self.conv_block2 = ConvBlockRes( + out_channels, out_channels, kernel_size, activation, momentum + ) + self.conv_block3 = ConvBlockRes( + out_channels, out_channels, kernel_size, activation, momentum + ) + self.conv_block4 = ConvBlockRes( + out_channels, out_channels, kernel_size, activation, momentum + ) + self.downsample = downsample + + def forward(self, x): + encoder = self.conv_block1(x) + encoder = self.conv_block2(encoder) + encoder = self.conv_block3(encoder) + encoder = self.conv_block4(encoder) + encoder_pool = F.avg_pool2d(encoder, kernel_size=self.downsample) + return encoder_pool, encoder + + +class DecoderBlockRes4B(nn.Module): + def __init__( + self, in_channels, out_channels, kernel_size, upsample, activation, momentum + ): + r"""Decoder block, contains 1 transpose convolutional and 8 
convolutional layers.""" + super(DecoderBlockRes4B, self).__init__() + self.kernel_size = kernel_size + self.stride = upsample + self.activation = activation + + self.conv1 = torch.nn.ConvTranspose2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=self.stride, + stride=self.stride, + padding=(0, 0), + bias=False, + dilation=(1, 1), + ) + + self.bn1 = nn.BatchNorm2d(in_channels, momentum=momentum) + self.conv_block2 = ConvBlockRes( + out_channels * 2, out_channels, kernel_size, activation, momentum + ) + self.conv_block3 = ConvBlockRes( + out_channels, out_channels, kernel_size, activation, momentum + ) + self.conv_block4 = ConvBlockRes( + out_channels, out_channels, kernel_size, activation, momentum + ) + self.conv_block5 = ConvBlockRes( + out_channels, out_channels, kernel_size, activation, momentum + ) + + self.init_weights() + + def init_weights(self): + init_bn(self.bn1) + init_layer(self.conv1) + + def forward(self, input_tensor, concat_tensor): + x = self.conv1(F.relu_(self.bn1(input_tensor))) + x = torch.cat((x, concat_tensor), dim=1) + x = self.conv_block2(x) + x = self.conv_block3(x) + x = self.conv_block4(x) + x = self.conv_block5(x) + return x + + +class ResUNet143_Subbandtime(nn.Module, Base): + def __init__(self, input_channels, target_sources_num): + super(ResUNet143_Subbandtime, self).__init__() + + self.input_channels = input_channels + self.target_sources_num = target_sources_num + + window_size = 512 + hop_size = 110 + center = True + pad_mode = "reflect" + window = "hann" + activation = "leaky_relu" + momentum = 0.01 + + self.subbands_num = 4 + self.K = 4 # outputs: |M|, cos∠M, sin∠M, Q + + self.downsample_ratio = 2 ** 5 # This number equals 2^{#encoder_blcoks} + + self.pqmf = PQMF( + N=self.subbands_num, + M=64, + project_root='bytesep/models/subband_tools/filters', + ) + + self.stft = STFT( + n_fft=window_size, + hop_length=hop_size, + win_length=window_size, + window=window, + center=center, + pad_mode=pad_mode, + freeze_parameters=True, + ) + + self.istft = ISTFT( + n_fft=window_size, + hop_length=hop_size, + win_length=window_size, + window=window, + center=center, + pad_mode=pad_mode, + freeze_parameters=True, + ) + + self.bn0 = nn.BatchNorm2d(window_size // 2 + 1, momentum=momentum) + + self.encoder_block1 = EncoderBlockRes4B( + in_channels=input_channels * self.subbands_num, + out_channels=32, + kernel_size=(3, 3), + downsample=(2, 2), + activation=activation, + momentum=momentum, + ) + self.encoder_block2 = EncoderBlockRes4B( + in_channels=32, + out_channels=64, + kernel_size=(3, 3), + downsample=(2, 2), + activation=activation, + momentum=momentum, + ) + self.encoder_block3 = EncoderBlockRes4B( + in_channels=64, + out_channels=128, + kernel_size=(3, 3), + downsample=(2, 2), + activation=activation, + momentum=momentum, + ) + self.encoder_block4 = EncoderBlockRes4B( + in_channels=128, + out_channels=256, + kernel_size=(3, 3), + downsample=(2, 2), + activation=activation, + momentum=momentum, + ) + self.encoder_block5 = EncoderBlockRes4B( + in_channels=256, + out_channels=384, + kernel_size=(3, 3), + downsample=(2, 2), + activation=activation, + momentum=momentum, + ) + self.encoder_block6 = EncoderBlockRes4B( + in_channels=384, + out_channels=384, + kernel_size=(3, 3), + downsample=(1, 2), + activation=activation, + momentum=momentum, + ) + self.conv_block7a = EncoderBlockRes4B( + in_channels=384, + out_channels=384, + kernel_size=(3, 3), + downsample=(1, 1), + activation=activation, + momentum=momentum, + ) + self.conv_block7b = 
EncoderBlockRes4B( + in_channels=384, + out_channels=384, + kernel_size=(3, 3), + downsample=(1, 1), + activation=activation, + momentum=momentum, + ) + self.conv_block7c = EncoderBlockRes4B( + in_channels=384, + out_channels=384, + kernel_size=(3, 3), + downsample=(1, 1), + activation=activation, + momentum=momentum, + ) + self.conv_block7d = EncoderBlockRes4B( + in_channels=384, + out_channels=384, + kernel_size=(3, 3), + downsample=(1, 1), + activation=activation, + momentum=momentum, + ) + self.decoder_block1 = DecoderBlockRes4B( + in_channels=384, + out_channels=384, + kernel_size=(3, 3), + upsample=(1, 2), + activation=activation, + momentum=momentum, + ) + self.decoder_block2 = DecoderBlockRes4B( + in_channels=384, + out_channels=384, + kernel_size=(3, 3), + upsample=(2, 2), + activation=activation, + momentum=momentum, + ) + self.decoder_block3 = DecoderBlockRes4B( + in_channels=384, + out_channels=256, + kernel_size=(3, 3), + upsample=(2, 2), + activation=activation, + momentum=momentum, + ) + self.decoder_block4 = DecoderBlockRes4B( + in_channels=256, + out_channels=128, + kernel_size=(3, 3), + upsample=(2, 2), + activation=activation, + momentum=momentum, + ) + self.decoder_block5 = DecoderBlockRes4B( + in_channels=128, + out_channels=64, + kernel_size=(3, 3), + upsample=(2, 2), + activation=activation, + momentum=momentum, + ) + self.decoder_block6 = DecoderBlockRes4B( + in_channels=64, + out_channels=32, + kernel_size=(3, 3), + upsample=(2, 2), + activation=activation, + momentum=momentum, + ) + + self.after_conv_block1 = EncoderBlockRes4B( + in_channels=32, + out_channels=32, + kernel_size=(3, 3), + downsample=(1, 1), + activation=activation, + momentum=momentum, + ) + + self.after_conv2 = nn.Conv2d( + in_channels=32, + out_channels=target_sources_num + * input_channels + * self.K + * self.subbands_num, + kernel_size=(1, 1), + stride=(1, 1), + padding=(0, 0), + bias=True, + ) + + self.init_weights() + + def init_weights(self): + init_bn(self.bn0) + init_layer(self.after_conv2) + + def feature_maps_to_wav( + self, + input_tensor: torch.Tensor, + sp: torch.Tensor, + sin_in: torch.Tensor, + cos_in: torch.Tensor, + audio_length: int, + ) -> torch.Tensor: + r"""Convert feature maps to waveform. 
+ + Args: + input_tensor: (batch_size, target_sources_num * input_channels * self.K, time_steps, freq_bins) + sp: (batch_size, target_sources_num * input_channels, time_steps, freq_bins) + sin_in: (batch_size, target_sources_num * input_channels, time_steps, freq_bins) + cos_in: (batch_size, target_sources_num * input_channels, time_steps, freq_bins) + + Outputs: + waveform: (batch_size, target_sources_num * input_channels, segment_samples) + """ + batch_size, _, time_steps, freq_bins = input_tensor.shape + + x = input_tensor.reshape( + batch_size, + self.target_sources_num, + self.input_channels, + self.K, + time_steps, + freq_bins, + ) + # x: (batch_size, target_sources_num, input_channles, K, time_steps, freq_bins) + + mask_mag = torch.sigmoid(x[:, :, :, 0, :, :]) + _mask_real = torch.tanh(x[:, :, :, 1, :, :]) + _mask_imag = torch.tanh(x[:, :, :, 2, :, :]) + linear_mag = torch.tanh(x[:, :, :, 3, :, :]) + _, mask_cos, mask_sin = magphase(_mask_real, _mask_imag) + # mask_cos, mask_sin: (batch_size, target_sources_num, input_channles, time_steps, freq_bins) + + # Y = |Y|cos∠Y + j|Y|sin∠Y + # = |Y|cos(∠X + ∠M) + j|Y|sin(∠X + ∠M) + # = |Y|(cos∠X cos∠M - sin∠X sin∠M) + j|Y|(sin∠X cos∠M + cos∠X sin∠M) + out_cos = ( + cos_in[:, None, :, :, :] * mask_cos - sin_in[:, None, :, :, :] * mask_sin + ) + out_sin = ( + sin_in[:, None, :, :, :] * mask_cos + cos_in[:, None, :, :, :] * mask_sin + ) + # out_cos: (batch_size, target_sources_num, input_channles, time_steps, freq_bins) + # out_sin: (batch_size, target_sources_num, input_channles, time_steps, freq_bins) + + # Calculate |Y|. + out_mag = F.relu_(sp[:, None, :, :, :] * mask_mag + linear_mag) + # out_mag: (batch_size, target_sources_num, input_channles, time_steps, freq_bins) + + # Calculate Y_{real} and Y_{imag} for ISTFT. + out_real = out_mag * out_cos + out_imag = out_mag * out_sin + # out_real, out_imag: (batch_size, target_sources_num, input_channles, time_steps, freq_bins) + + # Reformat shape to (n, 1, time_steps, freq_bins) for ISTFT. + shape = ( + batch_size * self.target_sources_num * self.input_channels, + 1, + time_steps, + freq_bins, + ) + out_real = out_real.reshape(shape) + out_imag = out_imag.reshape(shape) + + # ISTFT. + x = self.istft(out_real, out_imag, audio_length) + # (batch_size * target_sources_num * input_channels, segments_num) + + # Reshape. + waveform = x.reshape( + batch_size, self.target_sources_num * self.input_channels, audio_length + ) + # (batch_size, target_sources_num * input_channels, segments_num) + + return waveform + + def forward(self, input_dict): + r"""Forward data into the module. + + Args: + input_dict: dict, e.g., { + waveform: (batch_size, input_channels, segment_samples), + ..., + } + + Outputs: + output_dict: dict, e.g., { + 'waveform': (batch_size, input_channels, segment_samples), + ..., + } + """ + mixtures = input_dict['waveform'] + # (batch_size, input_channels, segment_samples) + + subband_x = self.pqmf.analysis(mixtures) + # subband_x: (batch_size, input_channels * subbands_num, segment_samples) + + mag, cos_in, sin_in = self.wav_to_spectrogram_phase(subband_x) + # mag, cos_in, sin_in: (batch_size, input_channels * subbands_num, time_steps, freq_bins) + + # Batch normalize on individual frequency bins. + x = mag.transpose(1, 3) + x = self.bn0(x) + x = x.transpose(1, 3) + # (batch_size, input_channels * subbands_num, time_steps, freq_bins) + + # Pad spectrogram to be evenly divided by downsample ratio. 
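+        # Note (added): here downsample_ratio = 2 ** 5 = 32 because only the
+        # first five encoder blocks halve the time axis; encoder_block6
+        # downsamples the frequency axis only (downsample=(1, 2)).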
+ origin_len = x.shape[2] + pad_len = ( + int(np.ceil(x.shape[2] / self.downsample_ratio)) * self.downsample_ratio + - origin_len + ) + x = F.pad(x, pad=(0, 0, 0, pad_len)) + # x: (batch_size, input_channels * subbands_num, padded_time_steps, freq_bins) + + # Let frequency bins be evenly divided by 2, e.g., 257 -> 256 + x = x[..., 0 : x.shape[-1] - 1] # (bs, input_channels, T, F) + # x: (batch_size, input_channels * subbands_num, padded_time_steps, freq_bins) + + # UNet + (x1_pool, x1) = self.encoder_block1(x) # x1_pool: (bs, 32, T / 2, F / 2) + (x2_pool, x2) = self.encoder_block2(x1_pool) # x2_pool: (bs, 64, T / 4, F / 4) + (x3_pool, x3) = self.encoder_block3(x2_pool) # x3_pool: (bs, 128, T / 8, F / 8) + (x4_pool, x4) = self.encoder_block4( + x3_pool + ) # x4_pool: (bs, 256, T / 16, F / 16) + (x5_pool, x5) = self.encoder_block5( + x4_pool + ) # x5_pool: (bs, 384, T / 32, F / 32) + (x6_pool, x6) = self.encoder_block6( + x5_pool + ) # x6_pool: (bs, 384, T / 32, F / 64) + (x_center, _) = self.conv_block7a(x6_pool) # (bs, 384, T / 32, F / 64) + (x_center, _) = self.conv_block7b(x_center) # (bs, 384, T / 32, F / 64) + (x_center, _) = self.conv_block7c(x_center) # (bs, 384, T / 32, F / 64) + (x_center, _) = self.conv_block7d(x_center) # (bs, 384, T / 32, F / 64) + x7 = self.decoder_block1(x_center, x6) # (bs, 384, T / 32, F / 32) + x8 = self.decoder_block2(x7, x5) # (bs, 384, T / 16, F / 16) + x9 = self.decoder_block3(x8, x4) # (bs, 256, T / 8, F / 8) + x10 = self.decoder_block4(x9, x3) # (bs, 128, T / 4, F / 4) + x11 = self.decoder_block5(x10, x2) # (bs, 64, T / 2, F / 2) + x12 = self.decoder_block6(x11, x1) # (bs, 32, T, F) + (x, _) = self.after_conv_block1(x12) # (bs, 32, T, F) + + x = self.after_conv2(x) + # (batch_size, subbands_num * target_sources_num * input_channles * self.K, T, F') + + # Recover shape + x = F.pad(x, pad=(0, 1)) # Pad frequency, e.g., 256 -> 257. + + x = x[:, :, 0:origin_len, :] + # (batch_size, subbands_num * target_sources_num * input_channles * self.K, T, F') + + audio_length = subband_x.shape[2] + + # Recover each subband spectrograms to subband waveforms. Then synthesis + # the subband waveforms to a waveform. 
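# C1 is the number of mask channels per subband (target_sources_num * input_channels * K)
# and C2 the number of spectrogram channels per subband (input_channels), so each slice
# below decodes one subband independently before PQMF synthesis merges them back to full band.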
+ C1 = x.shape[1] // self.subbands_num + C2 = mag.shape[1] // self.subbands_num + + separated_subband_audio = torch.cat( + [ + self.feature_maps_to_wav( + input_tensor=x[:, j * C1 : (j + 1) * C1, :, :], + sp=mag[:, j * C2 : (j + 1) * C2, :, :], + sin_in=sin_in[:, j * C2 : (j + 1) * C2, :, :], + cos_in=cos_in[:, j * C2 : (j + 1) * C2, :, :], + audio_length=audio_length, + ) + for j in range(self.subbands_num) + ], + dim=1, + ) + # (batch_size, subbands_num * target_sources_num * input_channles, segment_samples) + + separated_audio = self.pqmf.synthesis(separated_subband_audio) + # (batch_size, input_channles, segment_samples) + + output_dict = {'waveform': separated_audio} + + return output_dict diff --git a/bytesep/models/subband_tools/__init__.py b/bytesep/models/subband_tools/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/bytesep/models/subband_tools/fDomainHelper.py b/bytesep/models/subband_tools/fDomainHelper.py new file mode 100644 index 0000000000000000000000000000000000000000..9f5356888b441905536af88e65215f1238756e8f --- /dev/null +++ b/bytesep/models/subband_tools/fDomainHelper.py @@ -0,0 +1,255 @@ +from torchlibrosa.stft import STFT, ISTFT, magphase +import torch +import torch.nn as nn +import numpy as np +from tools.pytorch.modules.pqmf import PQMF + + +class FDomainHelper(nn.Module): + def __init__( + self, + window_size=2048, + hop_size=441, + center=True, + pad_mode='reflect', + window='hann', + freeze_parameters=True, + subband=None, + root="/Users/admin/Documents/projects/", + ): + super(FDomainHelper, self).__init__() + self.subband = subband + if self.subband is None: + self.stft = STFT( + n_fft=window_size, + hop_length=hop_size, + win_length=window_size, + window=window, + center=center, + pad_mode=pad_mode, + freeze_parameters=freeze_parameters, + ) + + self.istft = ISTFT( + n_fft=window_size, + hop_length=hop_size, + win_length=window_size, + window=window, + center=center, + pad_mode=pad_mode, + freeze_parameters=freeze_parameters, + ) + else: + self.stft = STFT( + n_fft=window_size // self.subband, + hop_length=hop_size // self.subband, + win_length=window_size // self.subband, + window=window, + center=center, + pad_mode=pad_mode, + freeze_parameters=freeze_parameters, + ) + + self.istft = ISTFT( + n_fft=window_size // self.subband, + hop_length=hop_size // self.subband, + win_length=window_size // self.subband, + window=window, + center=center, + pad_mode=pad_mode, + freeze_parameters=freeze_parameters, + ) + + if subband is not None and root is not None: + self.qmf = PQMF(subband, 64, root) + + def complex_spectrogram(self, input, eps=0.0): + # [batchsize, samples] + # return [batchsize, 2, t-steps, f-bins] + real, imag = self.stft(input) + return torch.cat([real, imag], dim=1) + + def reverse_complex_spectrogram(self, input, eps=0.0, length=None): + # [batchsize, 2[real,imag], t-steps, f-bins] + wav = self.istft(input[:, 0:1, ...], input[:, 1:2, ...], length=length) + return wav + + def spectrogram(self, input, eps=0.0): + (real, imag) = self.stft(input.float()) + return torch.clamp(real ** 2 + imag ** 2, eps, np.inf) ** 0.5 + + def spectrogram_phase(self, input, eps=0.0): + (real, imag) = self.stft(input.float()) + mag = torch.clamp(real ** 2 + imag ** 2, eps, np.inf) ** 0.5 + cos = real / mag + sin = imag / mag + return mag, cos, sin + + def wav_to_spectrogram_phase(self, input, eps=1e-8): + """Waveform to spectrogram. 
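Each channel is transformed separately; the magnitudes, cosines and sines of the STFT
phase are then concatenated along the channel dimension.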
+ + Args: + input: (batch_size, channels_num, segment_samples) + + Outputs: + output: (batch_size, channels_num, time_steps, freq_bins) + """ + sp_list = [] + cos_list = [] + sin_list = [] + channels_num = input.shape[1] + for channel in range(channels_num): + mag, cos, sin = self.spectrogram_phase(input[:, channel, :], eps=eps) + sp_list.append(mag) + cos_list.append(cos) + sin_list.append(sin) + + sps = torch.cat(sp_list, dim=1) + coss = torch.cat(cos_list, dim=1) + sins = torch.cat(sin_list, dim=1) + return sps, coss, sins + + def spectrogram_phase_to_wav(self, sps, coss, sins, length): + channels_num = sps.size()[1] + res = [] + for i in range(channels_num): + res.append( + self.istft( + sps[:, i : i + 1, ...] * coss[:, i : i + 1, ...], + sps[:, i : i + 1, ...] * sins[:, i : i + 1, ...], + length, + ) + ) + res[-1] = res[-1].unsqueeze(1) + return torch.cat(res, dim=1) + + def wav_to_spectrogram(self, input, eps=1e-8): + """Waveform to spectrogram. + + Args: + input: (batch_size,channels_num, segment_samples) + + Outputs: + output: (batch_size, channels_num, time_steps, freq_bins) + """ + sp_list = [] + channels_num = input.shape[1] + for channel in range(channels_num): + sp_list.append(self.spectrogram(input[:, channel, :], eps=eps)) + output = torch.cat(sp_list, dim=1) + return output + + def spectrogram_to_wav(self, input, spectrogram, length=None): + """Spectrogram to waveform. + Args: + input: (batch_size, segment_samples, channels_num) + spectrogram: (batch_size, channels_num, time_steps, freq_bins) + + Outputs: + output: (batch_size, segment_samples, channels_num) + """ + channels_num = input.shape[1] + wav_list = [] + for channel in range(channels_num): + (real, imag) = self.stft(input[:, channel, :]) + (_, cos, sin) = magphase(real, imag) + wav_list.append( + self.istft( + spectrogram[:, channel : channel + 1, :, :] * cos, + spectrogram[:, channel : channel + 1, :, :] * sin, + length, + ) + ) + + output = torch.stack(wav_list, dim=1) + return output + + # todo the following code is not bug free! 
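# The helpers below build complex spectrograms by stacking (real, imag) pairs per channel
# along dim=1; the *_subband_* variants additionally run PQMF analysis on the waveform
# before the STFT and PQMF synthesis after the inverse STFT.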
+ def wav_to_complex_spectrogram(self, input, eps=0.0): + # [batchsize , channels, samples] + # [batchsize, 2[real,imag]*channels, t-steps, f-bins] + res = [] + channels_num = input.shape[1] + for channel in range(channels_num): + res.append(self.complex_spectrogram(input[:, channel, :], eps=eps)) + return torch.cat(res, dim=1) + + def complex_spectrogram_to_wav(self, input, eps=0.0, length=None): + # [batchsize, 2[real,imag]*channels, t-steps, f-bins] + # return [batchsize, channels, samples] + channels = input.size()[1] // 2 + wavs = [] + for i in range(channels): + wavs.append( + self.reverse_complex_spectrogram( + input[:, 2 * i : 2 * i + 2, ...], eps=eps, length=length + ) + ) + wavs[-1] = wavs[-1].unsqueeze(1) + return torch.cat(wavs, dim=1) + + def wav_to_complex_subband_spectrogram(self, input, eps=0.0): + # [batchsize, channels, samples] + # [batchsize, 2[real,imag]*subband*channels, t-steps, f-bins] + subwav = self.qmf.analysis(input) # [batchsize, subband*channels, samples] + subspec = self.wav_to_complex_spectrogram(subwav) + return subspec + + def complex_subband_spectrogram_to_wav(self, input, eps=0.0): + # [batchsize, 2[real,imag]*subband*channels, t-steps, f-bins] + # [batchsize, channels, samples] + subwav = self.complex_spectrogram_to_wav(input) + data = self.qmf.synthesis(subwav) + return data + + def wav_to_mag_phase_subband_spectrogram(self, input, eps=1e-8): + """ + :param input: + :param eps: + :return: + loss = torch.nn.L1Loss() + model = FDomainHelper(subband=4) + data = torch.randn((3,1, 44100*3)) + + sps, coss, sins = model.wav_to_mag_phase_subband_spectrogram(data) + wav = model.mag_phase_subband_spectrogram_to_wav(sps,coss,sins,44100*3//4) + + print(loss(data,wav)) + print(torch.max(torch.abs(data-wav))) + + """ + # [batchsize, channels, samples] + # [batchsize, 2[real,imag]*subband*channels, t-steps, f-bins] + subwav = self.qmf.analysis(input) # [batchsize, subband*channels, samples] + sps, coss, sins = self.wav_to_spectrogram_phase(subwav, eps=eps) + return sps, coss, sins + + def mag_phase_subband_spectrogram_to_wav(self, sps, coss, sins, length, eps=0.0): + # [batchsize, 2[real,imag]*subband*channels, t-steps, f-bins] + # [batchsize, channels, samples] + subwav = self.spectrogram_phase_to_wav(sps, coss, sins, length) + data = self.qmf.synthesis(subwav) + return data + + +if __name__ == "__main__": + # from thop import profile + # from thop import clever_format + # from tools.file.wav import * + # import time + # + # wav = torch.randn((1,2,44100)) + # model = FDomainHelper() + + from tools.file.wav import * + + loss = torch.nn.L1Loss() + model = FDomainHelper() + data = torch.randn((3, 1, 44100 * 5)) + + sps = model.wav_to_complex_spectrogram(data) + print(sps.size()) + wav = model.complex_spectrogram_to_wav(sps, 44100 * 5) + + print(loss(data, wav)) + print(torch.max(torch.abs(data - wav))) diff --git a/bytesep/models/subband_tools/filters/f_4_64.mat b/bytesep/models/subband_tools/filters/f_4_64.mat new file mode 100644 index 0000000000000000000000000000000000000000..234e264863c19f3e82b7c1c16c1a2e127959a8fc Binary files /dev/null and b/bytesep/models/subband_tools/filters/f_4_64.mat differ diff --git a/bytesep/models/subband_tools/filters/h_4_64.mat b/bytesep/models/subband_tools/filters/h_4_64.mat new file mode 100644 index 0000000000000000000000000000000000000000..88ff6b8c7b7db856bee2a022764dabf036fa1a34 Binary files /dev/null and b/bytesep/models/subband_tools/filters/h_4_64.mat differ diff --git a/bytesep/models/subband_tools/pqmf.py 
b/bytesep/models/subband_tools/pqmf.py new file mode 100644 index 0000000000000000000000000000000000000000..f282678ee0803c3721cca7008754516ac28f632a --- /dev/null +++ b/bytesep/models/subband_tools/pqmf.py @@ -0,0 +1,136 @@ +''' +@File : subband_util.py +@Contact : liu.8948@buckeyemail.osu.edu +@License : (C)Copyright 2020-2021 +@Modify Time @Author @Version @Desciption +------------ ------- -------- ----------- +2020/4/3 4:54 PM Haohe Liu 1.0 None +''' + +import torch +import torch.nn.functional as F +import torch.nn as nn +import numpy as np +import os.path as op +from scipy.io import loadmat + + +def load_mat2numpy(fname=""): + ''' + Args: + fname: pth to mat + type: + Returns: dic object + ''' + if len(fname) == 0: + return None + else: + return loadmat(fname) + + +class PQMF(nn.Module): + def __init__(self, N, M, project_root): + super().__init__() + self.N = N # nsubband + self.M = M # nfilter + try: + assert (N, M) in [(8, 64), (4, 64), (2, 64)] + except: + print("Warning:", N, "subbandand ", M, " filter is not supported") + self.pad_samples = 64 + self.name = str(N) + "_" + str(M) + ".mat" + self.ana_conv_filter = nn.Conv1d( + 1, out_channels=N, kernel_size=M, stride=N, bias=False + ) + data = load_mat2numpy(op.join(project_root, "f_" + self.name)) + data = data['f'].astype(np.float32) / N + data = np.flipud(data.T).T + data = np.reshape(data, (N, 1, M)).copy() + dict_new = self.ana_conv_filter.state_dict().copy() + dict_new['weight'] = torch.from_numpy(data) + self.ana_pad = nn.ConstantPad1d((M - N, 0), 0) + self.ana_conv_filter.load_state_dict(dict_new) + + self.syn_pad = nn.ConstantPad1d((0, M // N - 1), 0) + self.syn_conv_filter = nn.Conv1d( + N, out_channels=N, kernel_size=M // N, stride=1, bias=False + ) + gk = load_mat2numpy(op.join(project_root, "h_" + self.name)) + gk = gk['h'].astype(np.float32) + gk = np.transpose(np.reshape(gk, (N, M // N, N)), (1, 0, 2)) * N + gk = np.transpose(gk[::-1, :, :], (2, 1, 0)).copy() + dict_new = self.syn_conv_filter.state_dict().copy() + dict_new['weight'] = torch.from_numpy(gk) + self.syn_conv_filter.load_state_dict(dict_new) + + for param in self.parameters(): + param.requires_grad = False + + def __analysis_channel(self, inputs): + return self.ana_conv_filter(self.ana_pad(inputs)) + + def __systhesis_channel(self, inputs): + ret = self.syn_conv_filter(self.syn_pad(inputs)).permute(0, 2, 1) + return torch.reshape(ret, (ret.shape[0], 1, -1)) + + def analysis(self, inputs): + ''' + :param inputs: [batchsize,channel,raw_wav],value:[0,1] + :return: + ''' + inputs = F.pad(inputs, ((0, self.pad_samples))) + ret = None + for i in range(inputs.size()[1]): # channels + if ret is None: + ret = self.__analysis_channel(inputs[:, i : i + 1, :]) + else: + ret = torch.cat( + (ret, self.__analysis_channel(inputs[:, i : i + 1, :])), dim=1 + ) + return ret + + def synthesis(self, data): + ''' + :param data: [batchsize,self.N*K,raw_wav_sub],value:[0,1] + :return: + ''' + ret = None + # data = F.pad(data,((0,self.pad_samples//self.N))) + for i in range(data.size()[1]): # channels + if i % self.N == 0: + if ret is None: + ret = self.__systhesis_channel(data[:, i : i + self.N, :]) + else: + new = self.__systhesis_channel(data[:, i : i + self.N, :]) + ret = torch.cat((ret, new), dim=1) + ret = ret[..., : -self.pad_samples] + return ret + + def forward(self, inputs): + return self.ana_conv_filter(self.ana_pad(inputs)) + + +if __name__ == "__main__": + import torch + import numpy as np + import matplotlib.pyplot as plt + from tools.file.wav import * + + pqmf = 
PQMF(N=4, M=64, project_root="/Users/admin/Documents/projects") + + rs = np.random.RandomState(0) + x = torch.tensor(rs.rand(4, 2, 32000), dtype=torch.float32) + + a1 = pqmf.analysis(x) + a2 = pqmf.synthesis(a1) + + print(a2.size(), x.size()) + + plt.subplot(211) + plt.plot(x[0, 0, -500:]) + plt.subplot(212) + plt.plot(a2[0, 0, -500:]) + plt.plot(x[0, 0, -500:] - a2[0, 0, -500:]) + plt.show() + + print(torch.sum(torch.abs(x[...] - a2[...]))) diff --git a/bytesep/models/unet.py b/bytesep/models/unet.py new file mode 100644 index 0000000000000000000000000000000000000000..f1ffb5f0879b13aa37f7a14c2b1ce3a90271fb96 --- /dev/null +++ b/bytesep/models/unet.py @@ -0,0 +1,532 @@ +import math +from typing import Dict, List, NoReturn, Tuple + +import matplotlib.pyplot as plt +import numpy as np +import pytorch_lightning as pl +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.optim.lr_scheduler import LambdaLR +from torchlibrosa.stft import ISTFT, STFT, magphase + +from bytesep.models.pytorch_modules import Base, Subband, act, init_bn, init_layer + + +class ConvBlock(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Tuple, + activation: str, + momentum: float, + ): + r"""Convolutional block.""" + super(ConvBlock, self).__init__() + + self.activation = activation + padding = (kernel_size[0] // 2, kernel_size[1] // 2) + + self.conv1 = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=(1, 1), + dilation=(1, 1), + padding=padding, + bias=False, + ) + + self.bn1 = nn.BatchNorm2d(out_channels, momentum=momentum) + + self.conv2 = nn.Conv2d( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=(1, 1), + dilation=(1, 1), + padding=padding, + bias=False, + ) + + self.bn2 = nn.BatchNorm2d(out_channels, momentum=momentum) + + self.init_weights() + + def init_weights(self) -> NoReturn: + r"""Initialize weights.""" + init_layer(self.conv1) + init_layer(self.conv2) + init_bn(self.bn1) + init_bn(self.bn2) + + def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: + r"""Forward data into the module. + + Args: + input_tensor: (batch_size, in_feature_maps, time_steps, freq_bins) + + Returns: + output_tensor: (batch_size, out_feature_maps, time_steps, freq_bins) + """ + x = act(self.bn1(self.conv1(input_tensor)), self.activation) + x = act(self.bn2(self.conv2(x)), self.activation) + output_tensor = x + + return output_tensor + + +class EncoderBlock(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Tuple, + downsample: Tuple, + activation: str, + momentum: float, + ): + r"""Encoder block.""" + super(EncoderBlock, self).__init__() + + self.conv_block = ConvBlock( + in_channels, out_channels, kernel_size, activation, momentum + ) + self.downsample = downsample + + def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: + r"""Forward data into the module. 
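Applies the two-layer conv block and then average-pools with kernel size self.downsample,
returning both the pooled tensor (fed to the next encoder stage) and the pre-pooling
tensor (used as the skip connection by the matching decoder block).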
+ + Args: + input_tensor: (batch_size, in_feature_maps, time_steps, freq_bins) + + Returns: + encoder_pool: (batch_size, out_feature_maps, downsampled_time_steps, downsampled_freq_bins) + encoder: (batch_size, out_feature_maps, time_steps, freq_bins) + """ + encoder_tensor = self.conv_block(input_tensor) + # encoder: (batch_size, out_feature_maps, time_steps, freq_bins) + + encoder_pool = F.avg_pool2d(encoder_tensor, kernel_size=self.downsample) + # encoder_pool: (batch_size, out_feature_maps, downsampled_time_steps, downsampled_freq_bins) + + return encoder_pool, encoder_tensor + + +class DecoderBlock(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Tuple, + upsample: Tuple, + activation: str, + momentum: float, + ): + r"""Decoder block.""" + super(DecoderBlock, self).__init__() + + self.kernel_size = kernel_size + self.stride = upsample + self.activation = activation + + self.conv1 = torch.nn.ConvTranspose2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=self.stride, + stride=self.stride, + padding=(0, 0), + bias=False, + dilation=(1, 1), + ) + + self.bn1 = nn.BatchNorm2d(out_channels, momentum=momentum) + + self.conv_block2 = ConvBlock( + out_channels * 2, out_channels, kernel_size, activation, momentum + ) + + self.init_weights() + + def init_weights(self): + r"""Initialize weights.""" + init_layer(self.conv1) + init_bn(self.bn1) + + def forward( + self, input_tensor: torch.Tensor, concat_tensor: torch.Tensor + ) -> torch.Tensor: + r"""Forward data into the module. + + Args: + torch_tensor: (batch_size, in_feature_maps, downsampled_time_steps, downsampled_freq_bins) + concat_tensor: (batch_size, in_feature_maps, time_steps, freq_bins) + + Returns: + output_tensor: (batch_size, out_feature_maps, time_steps, freq_bins) + """ + x = act(self.bn1(self.conv1(input_tensor)), self.activation) + # (batch_size, in_feature_maps, time_steps, freq_bins) + + x = torch.cat((x, concat_tensor), dim=1) + # (batch_size, in_feature_maps * 2, time_steps, freq_bins) + + output_tensor = self.conv_block2(x) + # output_tensor: (batch_size, out_feature_maps, time_steps, freq_bins) + + return output_tensor + + +class UNet(nn.Module, Base): + def __init__(self, input_channels: int, target_sources_num: int): + r"""UNet.""" + super(UNet, self).__init__() + + self.input_channels = input_channels + self.target_sources_num = target_sources_num + + window_size = 2048 + hop_size = 441 + center = True + pad_mode = "reflect" + window = "hann" + activation = "leaky_relu" + momentum = 0.01 + + self.subbands_num = 1 + + assert ( + self.subbands_num == 1 + ), "Using subbands_num > 1 on spectrogram \ + will lead to unexpected performance sometimes. Suggest to use \ + subband method on waveform." 
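# With subbands_num fixed to 1, the Subband analysis/synthesis branches in forward()
# are skipped and this model operates on the full-band spectrogram; the waveform-domain
# subband model is UNetSubbandTime in bytesep/models/unet_subbandtime.py.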
+ + self.K = 3 # outputs: |M|, cos∠M, sin∠M + self.downsample_ratio = 2 ** 6 # This number equals 2^{#encoder_blcoks} + + self.stft = STFT( + n_fft=window_size, + hop_length=hop_size, + win_length=window_size, + window=window, + center=center, + pad_mode=pad_mode, + freeze_parameters=True, + ) + + self.istft = ISTFT( + n_fft=window_size, + hop_length=hop_size, + win_length=window_size, + window=window, + center=center, + pad_mode=pad_mode, + freeze_parameters=True, + ) + + self.bn0 = nn.BatchNorm2d(window_size // 2 + 1, momentum=momentum) + + self.subband = Subband(subbands_num=self.subbands_num) + + self.encoder_block1 = EncoderBlock( + in_channels=input_channels * self.subbands_num, + out_channels=32, + kernel_size=(3, 3), + downsample=(2, 2), + activation=activation, + momentum=momentum, + ) + self.encoder_block2 = EncoderBlock( + in_channels=32, + out_channels=64, + kernel_size=(3, 3), + downsample=(2, 2), + activation=activation, + momentum=momentum, + ) + self.encoder_block3 = EncoderBlock( + in_channels=64, + out_channels=128, + kernel_size=(3, 3), + downsample=(2, 2), + activation=activation, + momentum=momentum, + ) + self.encoder_block4 = EncoderBlock( + in_channels=128, + out_channels=256, + kernel_size=(3, 3), + downsample=(2, 2), + activation=activation, + momentum=momentum, + ) + self.encoder_block5 = EncoderBlock( + in_channels=256, + out_channels=384, + kernel_size=(3, 3), + downsample=(2, 2), + activation=activation, + momentum=momentum, + ) + self.encoder_block6 = EncoderBlock( + in_channels=384, + out_channels=384, + kernel_size=(3, 3), + downsample=(2, 2), + activation=activation, + momentum=momentum, + ) + self.conv_block7 = ConvBlock( + in_channels=384, + out_channels=384, + kernel_size=(3, 3), + activation=activation, + momentum=momentum, + ) + self.decoder_block1 = DecoderBlock( + in_channels=384, + out_channels=384, + kernel_size=(3, 3), + upsample=(2, 2), + activation=activation, + momentum=momentum, + ) + self.decoder_block2 = DecoderBlock( + in_channels=384, + out_channels=384, + kernel_size=(3, 3), + upsample=(2, 2), + activation=activation, + momentum=momentum, + ) + self.decoder_block3 = DecoderBlock( + in_channels=384, + out_channels=256, + kernel_size=(3, 3), + upsample=(2, 2), + activation=activation, + momentum=momentum, + ) + self.decoder_block4 = DecoderBlock( + in_channels=256, + out_channels=128, + kernel_size=(3, 3), + upsample=(2, 2), + activation=activation, + momentum=momentum, + ) + self.decoder_block5 = DecoderBlock( + in_channels=128, + out_channels=64, + kernel_size=(3, 3), + upsample=(2, 2), + activation=activation, + momentum=momentum, + ) + + self.decoder_block6 = DecoderBlock( + in_channels=64, + out_channels=32, + kernel_size=(3, 3), + upsample=(2, 2), + activation=activation, + momentum=momentum, + ) + + self.after_conv_block1 = ConvBlock( + in_channels=32, + out_channels=32, + kernel_size=(3, 3), + activation=activation, + momentum=momentum, + ) + + self.after_conv2 = nn.Conv2d( + in_channels=32, + out_channels=target_sources_num + * input_channels + * self.K + * self.subbands_num, + kernel_size=(1, 1), + stride=(1, 1), + padding=(0, 0), + bias=True, + ) + + self.init_weights() + + def init_weights(self): + r"""Initialize weights.""" + init_bn(self.bn0) + init_layer(self.after_conv2) + + def feature_maps_to_wav( + self, + input_tensor: torch.Tensor, + sp: torch.Tensor, + sin_in: torch.Tensor, + cos_in: torch.Tensor, + audio_length: int, + ) -> torch.Tensor: + r"""Convert feature maps to waveform. 
+ + Args: + input_tensor: (batch_size, target_sources_num * input_channels * self.K, time_steps, freq_bins) + sp: (batch_size, target_sources_num * input_channels, time_steps, freq_bins) + sin_in: (batch_size, target_sources_num * input_channels, time_steps, freq_bins) + cos_in: (batch_size, target_sources_num * input_channels, time_steps, freq_bins) + + Outputs: + waveform: (batch_size, target_sources_num * input_channels, segment_samples) + """ + batch_size, _, time_steps, freq_bins = input_tensor.shape + + x = input_tensor.reshape( + batch_size, + self.target_sources_num, + self.input_channels, + self.K, + time_steps, + freq_bins, + ) + # x: (batch_size, target_sources_num, input_channles, K, time_steps, freq_bins) + + mask_mag = torch.sigmoid(x[:, :, :, 0, :, :]) + _mask_real = torch.tanh(x[:, :, :, 1, :, :]) + _mask_imag = torch.tanh(x[:, :, :, 2, :, :]) + _, mask_cos, mask_sin = magphase(_mask_real, _mask_imag) + # mask_cos, mask_sin: (batch_size, target_sources_num, input_channles, time_steps, freq_bins) + + # Y = |Y|cos∠Y + j|Y|sin∠Y + # = |Y|cos(∠X + ∠M) + j|Y|sin(∠X + ∠M) + # = |Y|(cos∠X cos∠M - sin∠X sin∠M) + j|Y|(sin∠X cos∠M + cos∠X sin∠M) + out_cos = ( + cos_in[:, None, :, :, :] * mask_cos - sin_in[:, None, :, :, :] * mask_sin + ) + out_sin = ( + sin_in[:, None, :, :, :] * mask_cos + cos_in[:, None, :, :, :] * mask_sin + ) + # out_cos: (batch_size, target_sources_num, input_channles, time_steps, freq_bins) + # out_sin: (batch_size, target_sources_num, input_channles, time_steps, freq_bins) + + # Calculate |Y|. + out_mag = F.relu_(sp[:, None, :, :, :] * mask_mag) + # out_mag: (batch_size, target_sources_num, input_channles, time_steps, freq_bins) + + # Calculate Y_{real} and Y_{imag} for ISTFT. + out_real = out_mag * out_cos + out_imag = out_mag * out_sin + # out_real, out_imag: (batch_size, target_sources_num, input_channles, time_steps, freq_bins) + + # Reformat shape to (n, 1, time_steps, freq_bins) for ISTFT. + shape = ( + batch_size * self.target_sources_num * self.input_channels, + 1, + time_steps, + freq_bins, + ) + out_real = out_real.reshape(shape) + out_imag = out_imag.reshape(shape) + + # ISTFT. + x = self.istft(out_real, out_imag, audio_length) + # (batch_size * target_sources_num * input_channels, segments_num) + + # Reshape. + waveform = x.reshape( + batch_size, self.target_sources_num * self.input_channels, audio_length + ) + # (batch_size, target_sources_num * input_channels, segments_num) + + return waveform + + def forward(self, input_dict: Dict) -> Dict: + r"""Forward data into the module. + + Args: + input_dict: dict, e.g., { + waveform: (batch_size, input_channels, segment_samples), + ..., + } + + Outputs: + output_dict: dict, e.g., { + 'waveform': (batch_size, input_channels, segment_samples), + ..., + } + """ + mixtures = input_dict['waveform'] + # (batch_size, input_channels, segment_samples) + + mag, cos_in, sin_in = self.wav_to_spectrogram_phase(mixtures) + # mag, cos_in, sin_in: (batch_size, input_channels, time_steps, freq_bins) + + # Batch normalize on individual frequency bins. + x = mag.transpose(1, 3) + x = self.bn0(x) + x = x.transpose(1, 3) + # x: (batch_size, input_channels, time_steps, freq_bins) + + # Pad spectrogram to be evenly divided by downsample ratio. 
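# For example, with downsample_ratio = 2 ** 6 = 64, an input of, say, 301 STFT frames
# is padded to ceil(301 / 64) * 64 = 320 frames (pad_len = 19); the padding is removed
# again after the decoder.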
+ origin_len = x.shape[2] + pad_len = ( + int(np.ceil(x.shape[2] / self.downsample_ratio)) * self.downsample_ratio + - origin_len + ) + x = F.pad(x, pad=(0, 0, 0, pad_len)) + # x: (batch_size, input_channels, padded_time_steps, freq_bins) + + # Let frequency bins be evenly divided by 2, e.g., 1025 -> 1024 + x = x[..., 0 : x.shape[-1] - 1] # (bs, input_channels, T, F) + + if self.subbands_num > 1: + x = self.subband.analysis(x) + # (bs, input_channels, T, F'), where F' = F // subbands_num + + # UNet + (x1_pool, x1) = self.encoder_block1(x) # x1_pool: (bs, 32, T / 2, F' / 2) + (x2_pool, x2) = self.encoder_block2(x1_pool) # x2_pool: (bs, 64, T / 4, F' / 4) + (x3_pool, x3) = self.encoder_block3( + x2_pool + ) # x3_pool: (bs, 128, T / 8, F' / 8) + (x4_pool, x4) = self.encoder_block4( + x3_pool + ) # x4_pool: (bs, 256, T / 16, F' / 16) + (x5_pool, x5) = self.encoder_block5( + x4_pool + ) # x5_pool: (bs, 384, T / 32, F' / 32) + (x6_pool, x6) = self.encoder_block6( + x5_pool + ) # x6_pool: (bs, 384, T / 64, F' / 64) + x_center = self.conv_block7(x6_pool) # (bs, 384, T / 64, F' / 64) + x7 = self.decoder_block1(x_center, x6) # (bs, 384, T / 32, F' / 32) + x8 = self.decoder_block2(x7, x5) # (bs, 384, T / 16, F' / 16) + x9 = self.decoder_block3(x8, x4) # (bs, 256, T / 8, F' / 8) + x10 = self.decoder_block4(x9, x3) # (bs, 128, T / 4, F' / 4) + x11 = self.decoder_block5(x10, x2) # (bs, 64, T / 2, F' / 2) + x12 = self.decoder_block6(x11, x1) # (bs, 32, T, F') + x = self.after_conv_block1(x12) # (bs, 32, T, F') + + x = self.after_conv2(x) + # (batch_size, target_sources_num * input_channles * self.K * subbands_num, T, F') + + if self.subbands_num > 1: + x = self.subband.synthesis(x) + # (batch_size, target_sources_num * input_channles * self.K, T, F) + + # Recover shape + x = F.pad(x, pad=(0, 1)) # Pad frequency, e.g., 1024 -> 1025. 
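# Drop the frames that were padded on for divisibility so the K mask maps line up
# with the original STFT frames passed to feature_maps_to_wav.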
+ + x = x[:, :, 0:origin_len, :] + # (batch_size, target_sources_num * input_channles * self.K, T, F) + + audio_length = mixtures.shape[2] + + separated_audio = self.feature_maps_to_wav(x, mag, sin_in, cos_in, audio_length) + # separated_audio: (batch_size, target_sources_num * input_channels, segments_num) + + output_dict = {'waveform': separated_audio} + + return output_dict diff --git a/bytesep/models/unet_subbandtime.py b/bytesep/models/unet_subbandtime.py new file mode 100644 index 0000000000000000000000000000000000000000..e905b73eafb547d549f11901f79c09c6c8cb30d3 --- /dev/null +++ b/bytesep/models/unet_subbandtime.py @@ -0,0 +1,389 @@ +from typing import Dict + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchlibrosa.stft import ISTFT, STFT, magphase + +from bytesep.models.pytorch_modules import Base, init_bn, init_layer +from bytesep.models.subband_tools.pqmf import PQMF +from bytesep.models.unet import ConvBlock, DecoderBlock, EncoderBlock + + +class UNetSubbandTime(nn.Module, Base): + def __init__(self, input_channels: int, target_sources_num: int): + r"""Subband waveform UNet.""" + super(UNetSubbandTime, self).__init__() + + self.input_channels = input_channels + self.target_sources_num = target_sources_num + + window_size = 512 # 2048 // 4 + hop_size = 110 # 441 // 4 + center = True + pad_mode = "reflect" + window = "hann" + activation = "leaky_relu" + momentum = 0.01 + + self.subbands_num = 4 + self.K = 3 # outputs: |M|, cos∠M, sin∠M + + self.downsample_ratio = 2 ** 6 # This number equals 2^{#encoder_blcoks} + + self.pqmf = PQMF( + N=self.subbands_num, + M=64, + project_root='bytesep/models/subband_tools/filters', + ) + + self.stft = STFT( + n_fft=window_size, + hop_length=hop_size, + win_length=window_size, + window=window, + center=center, + pad_mode=pad_mode, + freeze_parameters=True, + ) + + self.istft = ISTFT( + n_fft=window_size, + hop_length=hop_size, + win_length=window_size, + window=window, + center=center, + pad_mode=pad_mode, + freeze_parameters=True, + ) + + self.bn0 = nn.BatchNorm2d(window_size // 2 + 1, momentum=momentum) + + self.encoder_block1 = EncoderBlock( + in_channels=input_channels * self.subbands_num, + out_channels=32, + kernel_size=(3, 3), + downsample=(2, 2), + activation=activation, + momentum=momentum, + ) + self.encoder_block2 = EncoderBlock( + in_channels=32, + out_channels=64, + kernel_size=(3, 3), + downsample=(2, 2), + activation=activation, + momentum=momentum, + ) + self.encoder_block3 = EncoderBlock( + in_channels=64, + out_channels=128, + kernel_size=(3, 3), + downsample=(2, 2), + activation=activation, + momentum=momentum, + ) + self.encoder_block4 = EncoderBlock( + in_channels=128, + out_channels=256, + kernel_size=(3, 3), + downsample=(2, 2), + activation=activation, + momentum=momentum, + ) + self.encoder_block5 = EncoderBlock( + in_channels=256, + out_channels=384, + kernel_size=(3, 3), + downsample=(2, 2), + activation=activation, + momentum=momentum, + ) + self.encoder_block6 = EncoderBlock( + in_channels=384, + out_channels=384, + kernel_size=(3, 3), + downsample=(2, 2), + activation=activation, + momentum=momentum, + ) + self.conv_block7 = ConvBlock( + in_channels=384, + out_channels=384, + kernel_size=(3, 3), + activation=activation, + momentum=momentum, + ) + self.decoder_block1 = DecoderBlock( + in_channels=384, + out_channels=384, + kernel_size=(3, 3), + upsample=(2, 2), + activation=activation, + momentum=momentum, + ) + self.decoder_block2 = DecoderBlock( + in_channels=384, + 
out_channels=384, + kernel_size=(3, 3), + upsample=(2, 2), + activation=activation, + momentum=momentum, + ) + self.decoder_block3 = DecoderBlock( + in_channels=384, + out_channels=256, + kernel_size=(3, 3), + upsample=(2, 2), + activation=activation, + momentum=momentum, + ) + self.decoder_block4 = DecoderBlock( + in_channels=256, + out_channels=128, + kernel_size=(3, 3), + upsample=(2, 2), + activation=activation, + momentum=momentum, + ) + self.decoder_block5 = DecoderBlock( + in_channels=128, + out_channels=64, + kernel_size=(3, 3), + upsample=(2, 2), + activation=activation, + momentum=momentum, + ) + + self.decoder_block6 = DecoderBlock( + in_channels=64, + out_channels=32, + kernel_size=(3, 3), + upsample=(2, 2), + activation=activation, + momentum=momentum, + ) + + self.after_conv_block1 = ConvBlock( + in_channels=32, + out_channels=32, + kernel_size=(3, 3), + activation=activation, + momentum=momentum, + ) + + self.after_conv2 = nn.Conv2d( + in_channels=32, + out_channels=target_sources_num + * input_channels + * self.K + * self.subbands_num, + kernel_size=(1, 1), + stride=(1, 1), + padding=(0, 0), + bias=True, + ) + + self.init_weights() + + def init_weights(self): + r"""Initialize weights.""" + init_bn(self.bn0) + init_layer(self.after_conv2) + + def feature_maps_to_wav( + self, + input_tensor: torch.Tensor, + sp: torch.Tensor, + sin_in: torch.Tensor, + cos_in: torch.Tensor, + audio_length: int, + ) -> torch.Tensor: + r"""Convert feature maps to waveform. + + Args: + input_tensor: (batch_size, target_sources_num * input_channels * self.K, time_steps, freq_bins) + sp: (batch_size, target_sources_num * input_channels, time_steps, freq_bins) + sin_in: (batch_size, target_sources_num * input_channels, time_steps, freq_bins) + cos_in: (batch_size, target_sources_num * input_channels, time_steps, freq_bins) + + Outputs: + waveform: (batch_size, target_sources_num * input_channels, segment_samples) + """ + batch_size, _, time_steps, freq_bins = input_tensor.shape + + x = input_tensor.reshape( + batch_size, + self.target_sources_num, + self.input_channels, + self.K, + time_steps, + freq_bins, + ) + # x: (batch_size, target_sources_num, input_channles, K, time_steps, freq_bins) + + mask_mag = torch.sigmoid(x[:, :, :, 0, :, :]) + _mask_real = torch.tanh(x[:, :, :, 1, :, :]) + _mask_imag = torch.tanh(x[:, :, :, 2, :, :]) + _, mask_cos, mask_sin = magphase(_mask_real, _mask_imag) + # mask_cos, mask_sin: (batch_size, target_sources_num, input_channles, time_steps, freq_bins) + + # Y = |Y|cos∠Y + j|Y|sin∠Y + # = |Y|cos(∠X + ∠M) + j|Y|sin(∠X + ∠M) + # = |Y|(cos∠X cos∠M - sin∠X sin∠M) + j|Y|(sin∠X cos∠M + cos∠X sin∠M) + out_cos = ( + cos_in[:, None, :, :, :] * mask_cos - sin_in[:, None, :, :, :] * mask_sin + ) + out_sin = ( + sin_in[:, None, :, :, :] * mask_cos + cos_in[:, None, :, :, :] * mask_sin + ) + # out_cos: (batch_size, target_sources_num, input_channles, time_steps, freq_bins) + # out_sin: (batch_size, target_sources_num, input_channles, time_steps, freq_bins) + + # Calculate |Y|. + out_mag = F.relu_(sp[:, None, :, :, :] * mask_mag) + # out_mag: (batch_size, target_sources_num, input_channles, time_steps, freq_bins) + + # Calculate Y_{real} and Y_{imag} for ISTFT. + out_real = out_mag * out_cos + out_imag = out_mag * out_sin + # out_real, out_imag: (batch_size, target_sources_num, input_channles, time_steps, freq_bins) + + # Reformat shape to (n, 1, time_steps, freq_bins) for ISTFT. 
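# Sources and channels are folded into the batch axis so the ISTFT sees single-channel
# (n, 1, time_steps, freq_bins) real/imag tensors; the result is reshaped back to
# (batch_size, target_sources_num * input_channels, audio_length) below.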
+ shape = ( + batch_size * self.target_sources_num * self.input_channels, + 1, + time_steps, + freq_bins, + ) + out_real = out_real.reshape(shape) + out_imag = out_imag.reshape(shape) + + # ISTFT. + x = self.istft(out_real, out_imag, audio_length) + # (batch_size * target_sources_num * input_channels, segments_num) + + # Reshape. + waveform = x.reshape( + batch_size, self.target_sources_num * self.input_channels, audio_length + ) + # (batch_size, target_sources_num * input_channels, segments_num) + + return waveform + + def forward(self, input_dict: Dict) -> Dict: + """Forward data into the module. + + Args: + input_dict: dict, e.g., { + waveform: (batch_size, input_channels, segment_samples), + ..., + } + + Outputs: + output_dict: dict, e.g., { + 'waveform': (batch_size, input_channels, segment_samples), + ..., + } + """ + mixtures = input_dict['waveform'] + # (batch_size, input_channels, segment_samples) + + if self.subbands_num > 1: + subband_x = self.pqmf.analysis(mixtures) + # -- subband_x: (batch_size, input_channels * subbands_num, segment_samples) + # -- subband_x: (batch_size, subbands_num * input_channels, segment_samples) + else: + subband_x = mixtures + + # from IPython import embed; embed(using=False); os._exit(0) + # import soundfile + # soundfile.write(file='_zz.wav', data=subband_x.data.cpu().numpy()[0, 2], samplerate=11025) + + mag, cos_in, sin_in = self.wav_to_spectrogram_phase(subband_x) + # mag, cos_in, sin_in: (batch_size, input_channels * subbands_num, time_steps, freq_bins) + + # Batch normalize on individual frequency bins. + x = mag.transpose(1, 3) + x = self.bn0(x) + x = x.transpose(1, 3) + # (batch_size, input_channels * subbands_num, time_steps, freq_bins) + + # Pad spectrogram to be evenly divided by downsample ratio. + origin_len = x.shape[2] + pad_len = ( + int(np.ceil(x.shape[2] / self.downsample_ratio)) * self.downsample_ratio + - origin_len + ) + x = F.pad(x, pad=(0, 0, 0, pad_len)) + # x: (batch_size, input_channels * subbands_num, padded_time_steps, freq_bins) + + # Let frequency bins be evenly divided by 2, e.g., 257 -> 256 + x = x[..., 0 : x.shape[-1] - 1] # (bs, input_channels, T, F) + # x: (batch_size, input_channels * subbands_num, padded_time_steps, freq_bins) + + # UNet + (x1_pool, x1) = self.encoder_block1(x) # x1_pool: (bs, 32, T / 2, F' / 2) + (x2_pool, x2) = self.encoder_block2(x1_pool) # x2_pool: (bs, 64, T / 4, F' / 4) + (x3_pool, x3) = self.encoder_block3( + x2_pool + ) # x3_pool: (bs, 128, T / 8, F' / 8) + (x4_pool, x4) = self.encoder_block4( + x3_pool + ) # x4_pool: (bs, 256, T / 16, F' / 16) + (x5_pool, x5) = self.encoder_block5( + x4_pool + ) # x5_pool: (bs, 384, T / 32, F' / 32) + (x6_pool, x6) = self.encoder_block6( + x5_pool + ) # x6_pool: (bs, 384, T / 64, F' / 64) + x_center = self.conv_block7(x6_pool) # (bs, 384, T / 64, F' / 64) + x7 = self.decoder_block1(x_center, x6) # (bs, 384, T / 32, F' / 32) + x8 = self.decoder_block2(x7, x5) # (bs, 384, T / 16, F' / 16) + x9 = self.decoder_block3(x8, x4) # (bs, 256, T / 8, F' / 8) + x10 = self.decoder_block4(x9, x3) # (bs, 128, T / 4, F' / 4) + x11 = self.decoder_block5(x10, x2) # (bs, 64, T / 2, F' / 2) + x12 = self.decoder_block6(x11, x1) # (bs, 32, T, F') + x = self.after_conv_block1(x12) # (bs, 32, T, F') + + x = self.after_conv2(x) + # (batch_size, subbands_num * target_sources_num * input_channles * self.K, T, F') + + # Recover shape + x = F.pad(x, pad=(0, 1)) # Pad frequency, e.g., 256 -> 257. 
+ + x = x[:, :, 0:origin_len, :] + # (batch_size, subbands_num * target_sources_num * input_channles * self.K, T, F') + + audio_length = subband_x.shape[2] + + # Recover each subband spectrograms to subband waveforms. Then synthesis + # the subband waveforms to a waveform. + C1 = x.shape[1] // self.subbands_num + C2 = mag.shape[1] // self.subbands_num + + separated_subband_audio = torch.cat( + [ + self.feature_maps_to_wav( + input_tensor=x[:, j * C1 : (j + 1) * C1, :, :], + sp=mag[:, j * C2 : (j + 1) * C2, :, :], + sin_in=sin_in[:, j * C2 : (j + 1) * C2, :, :], + cos_in=cos_in[:, j * C2 : (j + 1) * C2, :, :], + audio_length=audio_length, + ) + for j in range(self.subbands_num) + ], + dim=1, + ) + # (batch_size, subbands_num * target_sources_num * input_channles, segment_samples) + + if self.subbands_num > 1: + separated_audio = self.pqmf.synthesis(separated_subband_audio) + # (batch_size, target_sources_num * input_channles, segment_samples) + else: + separated_audio = separated_subband_audio + + output_dict = {'waveform': separated_audio} + + return output_dict diff --git a/bytesep/optimizers/__init__.py b/bytesep/optimizers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/bytesep/optimizers/lr_schedulers.py b/bytesep/optimizers/lr_schedulers.py new file mode 100644 index 0000000000000000000000000000000000000000..018dafa9275bcfe03bea92c64aee25a5db996b8f --- /dev/null +++ b/bytesep/optimizers/lr_schedulers.py @@ -0,0 +1,20 @@ +def get_lr_lambda(step, warm_up_steps: int, reduce_lr_steps: int): + r"""Get lr_lambda for LambdaLR. E.g., + + .. code-block: python + lr_lambda = lambda step: get_lr_lambda(step, warm_up_steps=1000, reduce_lr_steps=10000) + + from torch.optim.lr_scheduler import LambdaLR + LambdaLR(optimizer, lr_lambda) + + Args: + warm_up_steps: int, steps for warm up + reduce_lr_steps: int, reduce learning rate by 0.9 every #reduce_lr_steps steps + + Returns: + learning rate: float + """ + if step <= warm_up_steps: + return step / warm_up_steps + else: + return 0.9 ** (step // reduce_lr_steps) diff --git a/bytesep/plot_results/__init__.py b/bytesep/plot_results/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/bytesep/plot_results/musdb18.py b/bytesep/plot_results/musdb18.py new file mode 100644 index 0000000000000000000000000000000000000000..eb91faa60b79f0f34aba1bb4810c2be7be8438f3 --- /dev/null +++ b/bytesep/plot_results/musdb18.py @@ -0,0 +1,198 @@ +import argparse +import os +import pickle + +import matplotlib.pyplot as plt +import numpy as np + + +def load_sdrs(workspace, task_name, filename, config, gpus, source_type): + + stat_path = os.path.join( + workspace, + "statistics", + task_name, + filename, + "config={},gpus={}".format(config, gpus), + "statistics.pkl", + ) + + stat_dict = pickle.load(open(stat_path, 'rb')) + + median_sdrs = [e['median_sdr_dict'][source_type] for e in stat_dict['test']] + + return median_sdrs + + +def plot_statistics(args): + + # arguments & parameters + workspace = args.workspace + select = args.select + task_name = "musdb18" + filename = "train" + + # paths + fig_path = os.path.join('results', task_name, "sdr_{}.pdf".format(select)) + os.makedirs(os.path.dirname(fig_path), exist_ok=True) + + linewidth = 1 + lines = [] + fig, ax = plt.subplots(1, 1, figsize=(8, 6)) + + if select == '1a': + sdrs = load_sdrs( + workspace, + task_name, + filename, + 
config='vocals-accompaniment,unet', + gpus=1, + source_type="vocals", + ) + (line,) = ax.plot(sdrs, label='UNet,l1_wav', linewidth=linewidth) + lines.append(line) + ylim = 15 + + elif select == '1b': + sdrs = load_sdrs( + workspace, + task_name, + filename, + config='accompaniment-vocals,unet', + gpus=1, + source_type="accompaniment", + ) + (line,) = ax.plot(sdrs, label='UNet,l1_wav', linewidth=linewidth) + lines.append(line) + ylim = 20 + + if select == '1c': + sdrs = load_sdrs( + workspace, + task_name, + filename, + config='vocals-accompaniment,unet', + gpus=1, + source_type="vocals", + ) + (line,) = ax.plot(sdrs, label='UNet,l1_wav', linewidth=linewidth) + lines.append(line) + + sdrs = load_sdrs( + workspace, + task_name, + filename, + config='vocals-accompaniment,resunet', + gpus=2, + source_type="vocals", + ) + (line,) = ax.plot(sdrs, label='ResUNet_ISMIR2021,l1_wav', linewidth=linewidth) + lines.append(line) + + sdrs = load_sdrs( + workspace, + task_name, + filename, + config='vocals-accompaniment,unet_subbandtime', + gpus=1, + source_type="vocals", + ) + (line,) = ax.plot(sdrs, label='unet_subband,l1_wav', linewidth=linewidth) + lines.append(line) + + sdrs = load_sdrs( + workspace, + task_name, + filename, + config='vocals-accompaniment,resunet_subbandtime', + gpus=1, + source_type="vocals", + ) + (line,) = ax.plot(sdrs, label='resunet_subband,l1_wav', linewidth=linewidth) + lines.append(line) + + ylim = 15 + + elif select == '1d': + sdrs = load_sdrs( + workspace, + task_name, + filename, + config='accompaniment-vocals,unet', + gpus=1, + source_type="accompaniment", + ) + (line,) = ax.plot(sdrs, label='UNet,l1_wav', linewidth=linewidth) + lines.append(line) + + sdrs = load_sdrs( + workspace, + task_name, + filename, + config='accompaniment-vocals,resunet', + gpus=2, + source_type="accompaniment", + ) + (line,) = ax.plot(sdrs, label='ResUNet_ISMIR2021,l1_wav', linewidth=linewidth) + lines.append(line) + + # sdrs = load_sdrs( + # workspace, + # task_name, + # filename, + # config='accompaniment-vocals,unet_subbandtime', + # gpus=1, + # source_type="accompaniment", + # ) + # (line,) = ax.plot(sdrs, label='UNet_subbtandtime,l1_wav', linewidth=linewidth) + # lines.append(line) + + sdrs = load_sdrs( + workspace, + task_name, + filename, + config='accompaniment-vocals,resunet_subbandtime', + gpus=1, + source_type="accompaniment", + ) + (line,) = ax.plot( + sdrs, label='ResUNet_subbtandtime,l1_wav', linewidth=linewidth + ) + lines.append(line) + + ylim = 20 + + else: + raise Exception('Error!') + + eval_every_iterations = 10000 + total_ticks = 50 + ticks_freq = 10 + + ax.set_ylim(0, ylim) + ax.set_xlim(0, total_ticks) + ax.xaxis.set_ticks(np.arange(0, total_ticks + 1, ticks_freq)) + ax.xaxis.set_ticklabels( + np.arange( + 0, + total_ticks * eval_every_iterations + 1, + ticks_freq * eval_every_iterations, + ) + ) + ax.yaxis.set_ticks(np.arange(ylim + 1)) + ax.yaxis.set_ticklabels(np.arange(ylim + 1)) + ax.grid(color='b', linestyle='solid', linewidth=0.3) + plt.legend(handles=lines, loc=4) + + plt.savefig(fig_path) + print('Save figure to {}'.format(fig_path)) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--workspace', type=str, required=True) + parser.add_argument('--select', type=str, required=True) + + args = parser.parse_args() + + plot_statistics(args) diff --git a/bytesep/plot_results/plot_vctk-musdb18.py b/bytesep/plot_results/plot_vctk-musdb18.py new file mode 100644 index 
0000000000000000000000000000000000000000..b7cc52af1e20b8f051bbe0dd8fd30c60a0dbf587 --- /dev/null +++ b/bytesep/plot_results/plot_vctk-musdb18.py @@ -0,0 +1,87 @@ +import os +import sys +import numpy as np +import argparse +import h5py +import math +import time +import logging +import pickle +import matplotlib.pyplot as plt + + +def load_sdrs(workspace, task_name, filename, config, gpus): + + stat_path = os.path.join( + workspace, + "statistics", + task_name, + filename, + "config={},gpus={}".format(config, gpus), + "statistics.pkl", + ) + + stat_dict = pickle.load(open(stat_path, 'rb')) + + median_sdrs = [e['sdr'] for e in stat_dict['test']] + + return median_sdrs + + +def plot_statistics(args): + + # arguments & parameters + workspace = args.workspace + select = args.select + task_name = "vctk-musdb18" + filename = "train" + + # paths + fig_path = os.path.join('results', task_name, "sdr_{}.pdf".format(select)) + os.makedirs(os.path.dirname(fig_path), exist_ok=True) + + linewidth = 1 + lines = [] + fig, ax = plt.subplots(1, 1, figsize=(8, 6)) + ylim = 30 + expand = 1 + + if select == '1a': + sdrs = load_sdrs(workspace, task_name, filename, config='unet', gpus=1) + (line,) = ax.plot(sdrs, label='UNet,l1_wav', linewidth=linewidth) + lines.append(line) + + else: + raise Exception('Error!') + + eval_every_iterations = 10000 + total_ticks = 50 + ticks_freq = 10 + + ax.set_ylim(0, ylim) + ax.set_xlim(0, total_ticks) + ax.xaxis.set_ticks(np.arange(0, total_ticks + 1, ticks_freq)) + ax.xaxis.set_ticklabels( + np.arange( + 0, + total_ticks * eval_every_iterations + 1, + ticks_freq * eval_every_iterations, + ) + ) + ax.yaxis.set_ticks(np.arange(ylim + 1)) + ax.yaxis.set_ticklabels(np.arange(ylim + 1)) + ax.grid(color='b', linestyle='solid', linewidth=0.3) + plt.legend(handles=lines, loc=4) + + plt.savefig(fig_path) + print('Save figure to {}'.format(fig_path)) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--workspace', type=str, required=True) + parser.add_argument('--select', type=str, required=True) + + args = parser.parse_args() + + plot_statistics(args) diff --git a/bytesep/train.py b/bytesep/train.py new file mode 100644 index 0000000000000000000000000000000000000000..bf4f6fb4d815bb791b7d578aca055124e495bb93 --- /dev/null +++ b/bytesep/train.py @@ -0,0 +1,299 @@ +import argparse +import logging +import os +import pathlib +from functools import partial +from typing import List, NoReturn + +import pytorch_lightning as pl +from pytorch_lightning.plugins import DDPPlugin + +from bytesep.callbacks import get_callbacks +from bytesep.data.augmentors import Augmentor +from bytesep.data.batch_data_preprocessors import ( + get_batch_data_preprocessor_class, +) +from bytesep.data.data_modules import DataModule, Dataset +from bytesep.data.samplers import SegmentSampler +from bytesep.losses import get_loss_function +from bytesep.models.lightning_modules import ( + LitSourceSeparation, + get_model_class, +) +from bytesep.optimizers.lr_schedulers import get_lr_lambda +from bytesep.utils import ( + create_logging, + get_pitch_shift_factor, + read_yaml, + check_configs_gramma, +) + + +def get_dirs( + workspace: str, task_name: str, filename: str, config_yaml: str, gpus: int +) -> List[str]: + r"""Get directories. 
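Creates (if missing) the checkpoint, log, TensorBoard and statistics directories, and
returns the checkpoint and log directories together with a TensorBoard logger and the
statistics path; all paths encode the task name, script filename, config stem and the
number of GPUs.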
+ + Args: + workspace: str + task_name, str, e.g., 'musdb18' + filenmae: str + config_yaml: str + gpus: int, e.g., 0 for cpu and 8 for training with 8 gpu cards + + Returns: + checkpoints_dir: str + logs_dir: str + logger: pl.loggers.TensorBoardLogger + statistics_path: str + """ + + # save checkpoints dir + checkpoints_dir = os.path.join( + workspace, + "checkpoints", + task_name, + filename, + "config={},gpus={}".format(pathlib.Path(config_yaml).stem, gpus), + ) + os.makedirs(checkpoints_dir, exist_ok=True) + + # logs dir + logs_dir = os.path.join( + workspace, + "logs", + task_name, + filename, + "config={},gpus={}".format(pathlib.Path(config_yaml).stem, gpus), + ) + os.makedirs(logs_dir, exist_ok=True) + + # loggings + create_logging(logs_dir, filemode='w') + logging.info(args) + + # tensorboard logs dir + tb_logs_dir = os.path.join(workspace, "tensorboard_logs") + os.makedirs(tb_logs_dir, exist_ok=True) + + experiment_name = os.path.join(task_name, filename, pathlib.Path(config_yaml).stem) + logger = pl.loggers.TensorBoardLogger(save_dir=tb_logs_dir, name=experiment_name) + + # statistics path + statistics_path = os.path.join( + workspace, + "statistics", + task_name, + filename, + "config={},gpus={}".format(pathlib.Path(config_yaml).stem, gpus), + "statistics.pkl", + ) + os.makedirs(os.path.dirname(statistics_path), exist_ok=True) + + return checkpoints_dir, logs_dir, logger, statistics_path + + +def _get_data_module( + workspace: str, config_yaml: str, num_workers: int, distributed: bool +) -> DataModule: + r"""Create data_module. Mini-batch data can be obtained by: + + code-block:: python + + data_module.setup() + for batch_data_dict in data_module.train_dataloader(): + print(batch_data_dict.keys()) + break + + Args: + workspace: str + config_yaml: str + num_workers: int, e.g., 0 for non-parallel and 8 for using cpu cores + for preparing data in parallel + distributed: bool + + Returns: + data_module: DataModule + """ + + configs = read_yaml(config_yaml) + input_source_types = configs['train']['input_source_types'] + indexes_path = os.path.join(workspace, configs['train']['indexes_dict']) + sample_rate = configs['train']['sample_rate'] + segment_seconds = configs['train']['segment_seconds'] + mixaudio_dict = configs['train']['augmentations']['mixaudio'] + augmentations = configs['train']['augmentations'] + max_pitch_shift = max( + [ + augmentations['pitch_shift'][source_type] + for source_type in input_source_types + ] + ) + batch_size = configs['train']['batch_size'] + steps_per_epoch = configs['train']['steps_per_epoch'] + + segment_samples = int(segment_seconds * sample_rate) + ex_segment_samples = int(segment_samples * get_pitch_shift_factor(max_pitch_shift)) + + # sampler + train_sampler = SegmentSampler( + indexes_path=indexes_path, + segment_samples=ex_segment_samples, + mixaudio_dict=mixaudio_dict, + batch_size=batch_size, + steps_per_epoch=steps_per_epoch, + ) + + # augmentor + augmentor = Augmentor(augmentations=augmentations) + + # dataset + train_dataset = Dataset(augmentor, segment_samples) + + # data module + data_module = DataModule( + train_sampler=train_sampler, + train_dataset=train_dataset, + num_workers=num_workers, + distributed=distributed, + ) + + return data_module + + +def train(args) -> NoReturn: + r"""Train & evaluate and save checkpoints. 
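Reads the config yaml, builds the data module, batch data preprocessor, model, loss
function and callbacks, wraps them in LitSourceSeparation with a warm-up plus step-decay
learning rate schedule, and trains with a pytorch_lightning DDP Trainer.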
+ + Args: + workspace: str, directory of workspace + gpus: int + config_yaml: str, path of config file for training + """ + + # arugments & parameters + workspace = args.workspace + gpus = args.gpus + config_yaml = args.config_yaml + filename = args.filename + + num_workers = 8 + distributed = True if gpus > 1 else False + evaluate_device = "cuda" if gpus > 0 else "cpu" + + # Read config file. + configs = read_yaml(config_yaml) + check_configs_gramma(configs) + task_name = configs['task_name'] + target_source_types = configs['train']['target_source_types'] + target_sources_num = len(target_source_types) + channels = configs['train']['channels'] + batch_data_preprocessor_type = configs['train']['batch_data_preprocessor'] + model_type = configs['train']['model_type'] + loss_type = configs['train']['loss_type'] + optimizer_type = configs['train']['optimizer_type'] + learning_rate = float(configs['train']['learning_rate']) + precision = configs['train']['precision'] + early_stop_steps = configs['train']['early_stop_steps'] + warm_up_steps = configs['train']['warm_up_steps'] + reduce_lr_steps = configs['train']['reduce_lr_steps'] + + # paths + checkpoints_dir, logs_dir, logger, statistics_path = get_dirs( + workspace, task_name, filename, config_yaml, gpus + ) + + # training data module + data_module = _get_data_module( + workspace=workspace, + config_yaml=config_yaml, + num_workers=num_workers, + distributed=distributed, + ) + + # batch data preprocessor + BatchDataPreprocessor = get_batch_data_preprocessor_class( + batch_data_preprocessor_type=batch_data_preprocessor_type + ) + + batch_data_preprocessor = BatchDataPreprocessor( + target_source_types=target_source_types + ) + + # model + Model = get_model_class(model_type=model_type) + model = Model(input_channels=channels, target_sources_num=target_sources_num) + + # loss function + loss_function = get_loss_function(loss_type=loss_type) + + # callbacks + callbacks = get_callbacks( + task_name=task_name, + config_yaml=config_yaml, + workspace=workspace, + checkpoints_dir=checkpoints_dir, + statistics_path=statistics_path, + logger=logger, + model=model, + evaluate_device=evaluate_device, + ) + # callbacks = [] + + # learning rate reduce function + lr_lambda = partial( + get_lr_lambda, warm_up_steps=warm_up_steps, reduce_lr_steps=reduce_lr_steps + ) + + # pytorch-lightning model + pl_model = LitSourceSeparation( + batch_data_preprocessor=batch_data_preprocessor, + model=model, + optimizer_type=optimizer_type, + loss_function=loss_function, + learning_rate=learning_rate, + lr_lambda=lr_lambda, + ) + + # trainer + trainer = pl.Trainer( + checkpoint_callback=False, + gpus=gpus, + callbacks=callbacks, + max_steps=early_stop_steps, + accelerator="ddp", + sync_batchnorm=True, + precision=precision, + replace_sampler_ddp=False, + plugins=[DDPPlugin(find_unused_parameters=True)], + profiler='simple', + ) + + # Fit, evaluate, and save checkpoints. + trainer.fit(pl_model, data_module) + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser(description="") + subparsers = parser.add_subparsers(dest="mode") + + parser_train = subparsers.add_parser("train") + parser_train.add_argument( + "--workspace", type=str, required=True, help="Directory of workspace." 
+ ) + parser_train.add_argument("--gpus", type=int, required=True) + parser_train.add_argument( + "--config_yaml", + type=str, + required=True, + help="Path of config file for training.", + ) + + args = parser.parse_args() + args.filename = pathlib.Path(__file__).stem + + if args.mode == "train": + train(args) + + else: + raise Exception("Error argument!") diff --git a/bytesep/utils.py b/bytesep/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4a38928bd5b00521d32b67c484e5561ff2ead439 --- /dev/null +++ b/bytesep/utils.py @@ -0,0 +1,189 @@ +import datetime +import logging +import os +import pickle +from typing import Dict, NoReturn + +import librosa +import numpy as np +import yaml + + +def create_logging(log_dir: str, filemode: str) -> logging: + r"""Create logging to write out log files. + + Args: + logs_dir, str, directory to write out logs + filemode: str, e.g., "w" + + Returns: + logging + """ + os.makedirs(log_dir, exist_ok=True) + i1 = 0 + + while os.path.isfile(os.path.join(log_dir, "{:04d}.log".format(i1))): + i1 += 1 + + log_path = os.path.join(log_dir, "{:04d}.log".format(i1)) + logging.basicConfig( + level=logging.DEBUG, + format="%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s", + datefmt="%a, %d %b %Y %H:%M:%S", + filename=log_path, + filemode=filemode, + ) + + # Print to console + console = logging.StreamHandler() + console.setLevel(logging.INFO) + formatter = logging.Formatter("%(name)-12s: %(levelname)-8s %(message)s") + console.setFormatter(formatter) + logging.getLogger("").addHandler(console) + + return logging + + +def load_audio( + audio_path: str, + mono: bool, + sample_rate: float, + offset: float = 0.0, + duration: float = None, +) -> np.array: + r"""Load audio. + + Args: + audio_path: str + mono: bool + sample_rate: float + """ + audio, _ = librosa.core.load( + audio_path, sr=sample_rate, mono=mono, offset=offset, duration=duration + ) + # (audio_samples,) | (channels_num, audio_samples) + + if audio.ndim == 1: + audio = audio[None, :] + # (1, audio_samples,) + + return audio + + +def load_random_segment( + audio_path: str, random_state, segment_seconds: float, mono: bool, sample_rate: int +) -> np.array: + r"""Randomly select an audio segment from a recording.""" + + duration = librosa.get_duration(filename=audio_path) + + start_time = random_state.uniform(0.0, duration - segment_seconds) + + audio = load_audio( + audio_path=audio_path, + mono=mono, + sample_rate=sample_rate, + offset=start_time, + duration=segment_seconds, + ) + # (channels_num, audio_samples) + + return audio + + +def float32_to_int16(x: np.float32) -> np.int16: + + x = np.clip(x, a_min=-1, a_max=1) + + return (x * 32767.0).astype(np.int16) + + +def int16_to_float32(x: np.int16) -> np.float32: + + return (x / 32767.0).astype(np.float32) + + +def read_yaml(config_yaml: str): + + with open(config_yaml, "r") as fr: + configs = yaml.load(fr, Loader=yaml.FullLoader) + + return configs + + +def check_configs_gramma(configs: Dict) -> NoReturn: + r"""Check if the gramma of the config dictionary for training is legal.""" + input_source_types = configs['train']['input_source_types'] + + for augmentation_type in configs['train']['augmentations'].keys(): + augmentation_dict = configs['train']['augmentations'][augmentation_type] + + for source_type in augmentation_dict.keys(): + if source_type not in input_source_types: + error_msg = ( + "The source type '{}'' in configs['train']['augmentations']['{}'] " + "must be one of input_source_types {}".format( + 
source_type, augmentation_type, input_source_types + ) + ) + raise Exception(error_msg) + + +def magnitude_to_db(x: float) -> float: + eps = 1e-10 + return 20.0 * np.log10(max(x, eps)) + + +def db_to_magnitude(x: float) -> float: + return 10.0 ** (x / 20) + + +def get_pitch_shift_factor(shift_pitch: float) -> float: + r"""The factor of the audio length to be scaled.""" + return 2 ** (shift_pitch / 12) + + +class StatisticsContainer(object): + def __init__(self, statistics_path): + self.statistics_path = statistics_path + + self.backup_statistics_path = "{}_{}.pkl".format( + os.path.splitext(self.statistics_path)[0], + datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"), + ) + + self.statistics_dict = {"train": [], "test": []} + + def append(self, steps, statistics, split): + statistics["steps"] = steps + self.statistics_dict[split].append(statistics) + + def dump(self): + pickle.dump(self.statistics_dict, open(self.statistics_path, "wb")) + pickle.dump(self.statistics_dict, open(self.backup_statistics_path, "wb")) + logging.info(" Dump statistics to {}".format(self.statistics_path)) + logging.info(" Dump statistics to {}".format(self.backup_statistics_path)) + + ''' + def load_state_dict(self, resume_steps): + self.statistics_dict = pickle.load(open(self.statistics_path, "rb")) + + resume_statistics_dict = {"train": [], "test": []} + + for key in self.statistics_dict.keys(): + for statistics in self.statistics_dict[key]: + if statistics["steps"] <= resume_steps: + resume_statistics_dict[key].append(statistics) + + self.statistics_dict = resume_statistics_dict + ''' + + +def calculate_sdr(ref: np.array, est: np.array) -> float: + s_true = ref + s_artif = est - ref + sdr = 10.0 * ( + np.log10(np.clip(np.mean(s_true ** 2), 1e-8, np.inf)) + - np.log10(np.clip(np.mean(s_artif ** 2), 1e-8, np.inf)) + ) + return sdr diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..cf56f232d3071ece8bc520188b846c25f319a3c1 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,21 @@ +[tool.black] +line-length = 88 +target-version = ['py37'] +skip-string-normalization = true +include = '\.pyi?$' +exclude = ''' +( + /( + \.eggs # exclude a few common directories in the + | \.git # root of the project + | \.hg + | \.mypy_cache + | \.tox + | \.venv + | _build + | buck-out + | build + | dist + )/ +) +''' \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..1f3c378f499e9e4f7a25427fc8ab90dbbacd92b6 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,11 @@ +torch==1.7.1 +h5py==2.10.0 +librosa==0.8.1 +numpy==1.20.3 +numba==0.54.0 +musdb==0.4.0 +pytorch_lightning==1.2.1 +torchlibrosa==0.0.9 +matplotlib==3.3.4 +museval==0.4.0 +inplace-abn==1.1.0 \ No newline at end of file diff --git a/scripts/0_download_datasets/instruments.sh b/scripts/0_download_datasets/instruments.sh new file mode 100755 index 0000000000000000000000000000000000000000..a848adbe45957923c47bc3047c33958a1421c8f6 --- /dev/null +++ b/scripts/0_download_datasets/instruments.sh @@ -0,0 +1,43 @@ +#!/bin/bash + +echo "The dataset link is created internally by kqq" + +# The downloaded instruments dataset looks like: +# ./datasets/instruments +# ├── violin_solo +# │ └── v0.1 +# │ ├── mp3s (12 files) +# │ │ ├── 0jXXWBt5URw.mp3 +# │ │ └── ... +# │ ├── README.txt +# │ └── validation.csv +# ├── basson_solo +# │ └── ... +# ├── cello_solo +# │ └── ... +# ├── clarinet_solo +# │ └── ... +# ├── flute_solo +# │ └── ...
+# ├── harp_solo +# │ └── ... +# ├── horn_solo +# │ └── ... +# ├── oboe_solo +# │ └── ... +# ├── saxophone_solo +# │ └── ... +# ├── string_quartet +# │ └── ... +# ├── symphony_solo +# │ └── ... +# ├── timpani_solo +# │ └── ... +# ├── trombone_solo +# │ └── ... +# ├── trumpet_solo +# │ └── ... +# ├── tuba_solo +# │ └── ... +# └── viola_solo +# └── ... \ No newline at end of file diff --git a/scripts/0_download_datasets/maestro.sh b/scripts/0_download_datasets/maestro.sh new file mode 100755 index 0000000000000000000000000000000000000000..be7f5a78d642cb46a954c1175196b479a2e9f95d --- /dev/null +++ b/scripts/0_download_datasets/maestro.sh @@ -0,0 +1,30 @@ +#!/bin/bash + +echo "The dataset link is at https://magenta.tensorflow.org/datasets/maestro" + +# The downloaded MAESTRO dataset looks like: +# ./datasets/maestro +# ├── 2004 (264 files) +# │ └── ... +# ├── 2006 (230 files) +# │ └── ... +# ├── 2008 (294 files) +# │ └── ... +# ├── 2009 (250 files) +# │ └── ... +# ├── 2011 (326 files) +# │ └── ... +# ├── 2013 (254 files) +# │ └── ... +# ├── 2014 (210 files) +# │ └── ... +# ├── 2015 (258 files) +# │ └── ... +# ├── 2017 (280 files) +# │ └── ... +# ├── 2018 (198 files) +# │ └── ... +# ├── LICENSE +# ├── maestro-v2.0.0.csv +# ├── maestro-v2.0.0.json +# └── README \ No newline at end of file diff --git a/scripts/0_download_datasets/musdb18.sh b/scripts/0_download_datasets/musdb18.sh new file mode 100755 index 0000000000000000000000000000000000000000..58ba01624b45a2f2d75cba4e3f1612833fc68d4b --- /dev/null +++ b/scripts/0_download_datasets/musdb18.sh @@ -0,0 +1,26 @@ +#!/bin/bash +MUSDB18_DATASET_DIR=${1:-"./datasets/musdb18"} # The first argument is dataset directory. + +echo "MUSDB18_DATASET_DIR=${MUSDB18_DATASET_DIR}" + +# Set up paths. +mkdir -p $MUSDB18_DATASET_DIR +cd $MUSDB18_DATASET_DIR + +# Download dataset from Zenodo. +echo "The dataset link is at https://zenodo.org/record/1117372" + +wget -O "musdb18.zip" "https://zenodo.org/record/1117372/files/musdb18.zip?download=1" + +# Unzip dataset. +unzip "musdb18.zip" + +# The downloaded MUSDB18 dataset looks like: +# ./datasets/musdb18 +# ├── train (100 files) +# │ ├── 'A Classic Education - NightOwl.stem.mp4' +# │ └── ... +# ├── test (50 files) +# │ ├── 'Al James - Schoolboy Facination.stem.mp4' +# │ └── ... +# └── README.md \ No newline at end of file diff --git a/scripts/0_download_datasets/vctk.sh b/scripts/0_download_datasets/vctk.sh new file mode 100755 index 0000000000000000000000000000000000000000..e4da624a6ff0bed67ca7905d7d3c3a9b6f8e7d38 --- /dev/null +++ b/scripts/0_download_datasets/vctk.sh @@ -0,0 +1,23 @@ +#!/bin/bash + +echo "The dataset link is at http://www.udialogue.org/download/VCTK-Corpus.tar.gz" + +# The downloaded VCTK dataset looks like: +# ./datasets/vctk +# └── wav48 +# ├── train (100 speakers) +# │ ├── p225 (231 files) +# │ │ ├── p225_001_mic1.flac.wav +# │ │ └── ... +# │ ├── p226 (356 files) +# │ │ ├── p226_001_mic1.flac.wav +# │ │ └── ... +# │ └── ... +# └── test (8 speakers) +# ├── p360 (424 files) +# │ ├── p360_001_mic1.flac.wav +# │ └── ... +# ├── p226 (424 files) +# │ ├── p361_001_mic1.flac.wav +# │ └── ... +# └── ... 
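For illustration only, here is a minimal sketch of how the `load_audio` and `calculate_sdr` helpers defined in `bytesep/utils.py` above can be combined to score a separated estimate against a reference; the two file paths are placeholders, not files shipped with the repository.

from bytesep.utils import load_audio, calculate_sdr

# Placeholder paths; any matching pair of reference/estimate audio files works here.
ref = load_audio("ref_vocals.wav", mono=True, sample_rate=44100)  # (1, audio_samples)
est = load_audio("est_vocals.wav", mono=True, sample_rate=44100)  # (1, audio_samples)

# Trim to the shorter signal so the two arrays line up sample by sample.
length = min(ref.shape[-1], est.shape[-1])
sdr = calculate_sdr(ref[:, :length], est[:, :length])
print("SDR: {:.2f} dB".format(sdr))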
\ No newline at end of file diff --git a/scripts/0_download_datasets/voicebank-demand.sh b/scripts/0_download_datasets/voicebank-demand.sh new file mode 100755 index 0000000000000000000000000000000000000000..ab87f267c0b95cbd44220c8bc23e82a0f1fae448 --- /dev/null +++ b/scripts/0_download_datasets/voicebank-demand.sh @@ -0,0 +1,18 @@ +#!/bin/bash + +echo "The dataset link is at https://datashare.ed.ac.uk/handle/10283/2791" + +# The downloaded Voicebank-DEMAND dataset looks like: +# ./datasets/voicebank-demand +# ├── clean_trainset_wav (11572 files) +# │ ├── p226_001.wav +# │ └── ... +# ├── noisy_trainset_wav (11572 files) +# │ ├── p226_001.wav +# │ └── ... +# ├── clean_testset_wav (824 files) +# │ ├── p232_001.wav +# │ └── ... +# └── noisy_testset_wav (824 files) +# ├── p232_001.wav +# └── ... \ No newline at end of file diff --git a/scripts/1_pack_audios_to_hdf5s/instruments_solo/symphony/sr=44100,chn=2.sh b/scripts/1_pack_audios_to_hdf5s/instruments_solo/symphony/sr=44100,chn=2.sh new file mode 100755 index 0000000000000000000000000000000000000000..b3ba173d0a84ec895cfa6c131c75fa469191eb96 --- /dev/null +++ b/scripts/1_pack_audios_to_hdf5s/instruments_solo/symphony/sr=44100,chn=2.sh @@ -0,0 +1,25 @@ +#!/bin/bash +INSTRUMENTS_SOLO_DATASET_DIR=${1:-"./datasets/instruments_solo"} # The first argument is dataset directory. +WORKSPACE=${2:-"./workspaces/bytesep"} # The second argument is workspace directory. + +echo "INSTRUMENTS_SOLO_DATASET_DIR=${INSTRUMENTS_SOLO_DATASET_DIR}" +echo "WORKSPACE=${WORKSPACE}" + +# Users can change the following settings. +SAMPLE_RATE=44100 +CHANNELS=2 + +INSTRUMENT="symphony" + +# Paths +SUB_DATASET_DIR="${INSTRUMENTS_SOLO_DATASET_DIR}/${INSTRUMENT}_solo/v0.1" + +HDF5S_DIR="${WORKSPACE}/hdf5s/instruments_solo/${INSTRUMENT}/sr=${SAMPLE_RATE}_chn=${CHANNELS}/train" + +python3 bytesep/dataset_creation/pack_audios_to_hdf5s/instruments_solo.py \ + --dataset_dir=$SUB_DATASET_DIR \ + --split="train" \ + --source_type=$INSTRUMENT \ + --hdf5s_dir=$HDF5S_DIR \ + --sample_rate=$SAMPLE_RATE \ + --channels=$CHANNELS \ No newline at end of file diff --git a/scripts/1_pack_audios_to_hdf5s/instruments_solo/violin/sr=44100,chn=2.sh b/scripts/1_pack_audios_to_hdf5s/instruments_solo/violin/sr=44100,chn=2.sh new file mode 100755 index 0000000000000000000000000000000000000000..a19ffa39548062d491ba43eebf8bbcba729da422 --- /dev/null +++ b/scripts/1_pack_audios_to_hdf5s/instruments_solo/violin/sr=44100,chn=2.sh @@ -0,0 +1,25 @@ +#!/bin/bash +INSTRUMENTS_SOLO_DATASET_DIR=${1:-"./datasets/instruments_solo"} # The first argument is dataset directory. +WORKSPACE=${2:-"./workspaces/bytesep"} # The second argument is workspace directory. + +echo "INSTRUMENTS_SOLO_DATASET_DIR=${INSTRUMENTS_SOLO_DATASET_DIR}" +echo "WORKSPACE=${WORKSPACE}" + +# Users can change the following settings.
+SAMPLE_RATE=44100 +CHANNELS=2 + +INSTRUMENT="violin" + +# Paths +SUB_DATASET_DIR="${INSTRUMENTS_SOLO_DATASET_DIR}/${INSTRUMENT}_solo/v0.1" + +HDF5S_DIR="${WORKSPACE}/hdf5s/instruments_solo/${INSTRUMENT}/sr=${SAMPLE_RATE}_chn=${CHANNELS}/train" + +python3 bytesep/dataset_creation/pack_audios_to_hdf5s/instruments_solo.py \ + --dataset_dir=$SUB_DATASET_DIR \ + --split="train" \ + --source_type=$INSTRUMENT \ + --hdf5s_dir=$HDF5S_DIR \ + --sample_rate=$SAMPLE_RATE \ + --channels=$CHANNELS \ No newline at end of file diff --git a/scripts/1_pack_audios_to_hdf5s/maestro/sr=44100,chn=2.sh b/scripts/1_pack_audios_to_hdf5s/maestro/sr=44100,chn=2.sh new file mode 100755 index 0000000000000000000000000000000000000000..05c239fb85749261920f85f4ccf70641f9f67546 --- /dev/null +++ b/scripts/1_pack_audios_to_hdf5s/maestro/sr=44100,chn=2.sh @@ -0,0 +1,20 @@ +#!/bin/bash +DATASET_DIR=${1:-"./datasets/maestro"} # The first argument is dataset directory. +WORKSPACE=${2:-"./workspaces/bytesep"} # The second argument is workspace directory. + +echo "DATASET_DIR=${DATASET_DIR}" +echo "WORKSPACE=${WORKSPACE}" + +# Users can change the following settings. +SAMPLE_RATE=44100 +CHANNELS=2 + +# Paths +HDF5S_DIR="${WORKSPACE}/hdf5s/maestro/sr=${SAMPLE_RATE}_chn=${CHANNELS}/train" + +python3 bytesep/dataset_creation/pack_audios_to_hdf5s/maestro.py \ + --dataset_dir=$DATASET_DIR \ + --split="train" \ + --hdf5s_dir=$HDF5S_DIR \ + --sample_rate=$SAMPLE_RATE \ + --channels=$CHANNELS \ No newline at end of file diff --git a/scripts/1_pack_audios_to_hdf5s/musdb18/sr=44100,chn=2.sh b/scripts/1_pack_audios_to_hdf5s/musdb18/sr=44100,chn=2.sh new file mode 100755 index 0000000000000000000000000000000000000000..8519fa3092f66291d416c70f9ca320b4b5b12d5a --- /dev/null +++ b/scripts/1_pack_audios_to_hdf5s/musdb18/sr=44100,chn=2.sh @@ -0,0 +1,24 @@ +#!/bin/bash +MUSDB18_DATASET_DIR=${1:-"./datasets/musdb18"} # The first argument is dataset directory. +WORKSPACE=${2:-"./workspaces/bytesep"} # The second argument is workspace directory. + +echo "MUSDB18_DATASET_DIR=${MUSDB18_DATASET_DIR}" +echo "WORKSPACE=${WORKSPACE}" + +# Users can change the following settings. +SAMPLE_RATE=44100 +CHANNELS=2 + +# Paths +PARENT_HDF5S_DIR="${WORKSPACE}/hdf5s/musdb18/sr=${SAMPLE_RATE}_chn=${CHANNELS}" + +# Pack train subset 100 pieces into hdf5 files. +HDF5S_DIR="${PARENT_HDF5S_DIR}/train/full_train" + +python3 bytesep/dataset_creation/pack_audios_to_hdf5s/musdb18.py \ + --dataset_dir=$MUSDB18_DATASET_DIR \ + --subset="train" \ + --split="" \ + --hdf5s_dir=$HDF5S_DIR \ + --sample_rate=$SAMPLE_RATE \ + --channels=$CHANNELS \ No newline at end of file diff --git a/scripts/1_pack_audios_to_hdf5s/vctk/sr=44100,chn=2.sh b/scripts/1_pack_audios_to_hdf5s/vctk/sr=44100,chn=2.sh new file mode 100755 index 0000000000000000000000000000000000000000..71eac148ffaf44878df6692e92bb442614c30ce4 --- /dev/null +++ b/scripts/1_pack_audios_to_hdf5s/vctk/sr=44100,chn=2.sh @@ -0,0 +1,21 @@ +#!/bin/bash +DATASET_DIR=${1:-"./datasets/vctk"} # The first argument is dataset directory. +WORKSPACE=${2:-"./workspaces/bytesep"} # The second argument is workspace directory. + +echo "DATASET_DIR=${DATASET_DIR}" +echo "WORKSPACE=${WORKSPACE}" + +# Users can change the following settings. 
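For reference, each packing script in this step takes the dataset directory and the workspace as optional positional arguments (the defaults shown above), so a typical invocation from the repository root simply spells out those defaults:

bash scripts/1_pack_audios_to_hdf5s/musdb18/sr=44100,chn=2.sh ./datasets/musdb18 ./workspaces/bytesep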
+SAMPLE_RATE=44100 +CHANNELS=2 + +# Paths +HDF5S_DIR="${WORKSPACE}/hdf5s/vctk/sr=${SAMPLE_RATE}_chn=${CHANNELS}/train" + +python3 bytesep/dataset_creation/pack_audios_to_hdf5s/vctk.py \ + --dataset_dir=$DATASET_DIR \ + --split="train" \ + --hdf5s_dir=$HDF5S_DIR \ + --sample_rate=$SAMPLE_RATE \ + --channels=$CHANNELS + \ No newline at end of file diff --git a/scripts/1_pack_audios_to_hdf5s/voicebank-demand/sr=44100,chn=1.sh b/scripts/1_pack_audios_to_hdf5s/voicebank-demand/sr=44100,chn=1.sh new file mode 100755 index 0000000000000000000000000000000000000000..b6864ddc299ee2149a5f52e4ed0ad543c207fb33 --- /dev/null +++ b/scripts/1_pack_audios_to_hdf5s/voicebank-demand/sr=44100,chn=1.sh @@ -0,0 +1,23 @@ +#!/bin/bash +DATASET_DIR=${1:-"./datasets/voicebank-demand"} # The first argument is dataset directory. +WORKSPACE=${2:-"./workspaces/bytesep"} # The second argument is workspace directory. + +echo "DATASET_DIR=${DATASET_DIR}" +echo "WORKSPACE=${WORKSPACE}" + +# Users can change the following settings. +SAMPLE_RATE=44100 +CHANNELS=1 + +# Paths +PARENT_HDF5S_DIR="${WORKSPACE}/hdf5s/voicebank-demand/sr=${SAMPLE_RATE}_chn=${CHANNELS}" + +# Pack train subset 100 pieces into hdf5 files. +HDF5S_DIR="${PARENT_HDF5S_DIR}/train" + +python3 bytesep/dataset_creation/pack_audios_to_hdf5s/voicebank-demand.py \ + --dataset_dir=$DATASET_DIR \ + --split="train" \ + --hdf5s_dir=$HDF5S_DIR \ + --sample_rate=$SAMPLE_RATE \ + --channels=$CHANNELS \ No newline at end of file diff --git a/scripts/2_create_indexes/musdb18/configs/vocals-accompaniment,sr=44100,chn=2.yaml b/scripts/2_create_indexes/musdb18/configs/vocals-accompaniment,sr=44100,chn=2.yaml new file mode 100755 index 0000000000000000000000000000000000000000..f099f2be1526712b826fe47e0cdb3545bf50d38d --- /dev/null +++ b/scripts/2_create_indexes/musdb18/configs/vocals-accompaniment,sr=44100,chn=2.yaml @@ -0,0 +1,16 @@ +--- +sample_rate: 44100 +segment_seconds: 3.0 +train: + source_types: + vocals: + musdb18: + hdf5s_directory: "hdf5s/musdb18/sr=44100_chn=2/train/full_train" + key_in_hdf5: "vocals" + hop_seconds: 0.1 + accompaniment: + musdb18: + hdf5s_directory: "hdf5s/musdb18/sr=44100_chn=2/train/full_train" + key_in_hdf5: "accompaniment" + hop_seconds: 0.1 + indexes: "indexes/musdb18/sr=44100_chn=2/train/full_train/vocals-accompaniment.pkl" diff --git a/scripts/2_create_indexes/musdb18/configs/vocals-bass-drums-other,sr=44100,chn=2.yaml b/scripts/2_create_indexes/musdb18/configs/vocals-bass-drums-other,sr=44100,chn=2.yaml new file mode 100755 index 0000000000000000000000000000000000000000..0c6403da562992490d650fad4f77169c6af4978f --- /dev/null +++ b/scripts/2_create_indexes/musdb18/configs/vocals-bass-drums-other,sr=44100,chn=2.yaml @@ -0,0 +1,26 @@ +--- +sample_rate: 44100 +segment_seconds: 3.0 +train: + source_types: + vocals: + musdb18: + hdf5s_directory: "hdf5s/musdb18/sr=44100_chn=2/train/full_train" + key_in_hdf5: "vocals" + hop_seconds: 0.1 + bass: + musdb18: + hdf5s_directory: "hdf5s/musdb18/sr=44100_chn=2/train/full_train" + key_in_hdf5: "bass" + hop_seconds: 0.1 + drums: + musdb18: + hdf5s_directory: "hdf5s/musdb18/sr=44100_chn=2/train/full_train" + key_in_hdf5: "drums" + hop_seconds: 0.1 + other: + musdb18: + hdf5s_directory: "hdf5s/musdb18/sr=44100_chn=2/train/full_train" + key_in_hdf5: "other" + hop_seconds: 0.1 + indexes: "indexes/musdb18/sr=44100_chn=2/train/full_train/vocals-bass-drums-other.pkl" diff --git a/scripts/2_create_indexes/musdb18/create_indexes.sh b/scripts/2_create_indexes/musdb18/create_indexes.sh new file mode 100755 index 
0000000000000000000000000000000000000000..fc571ebd1971ce44b973b878a83ac54ebfb47948 --- /dev/null +++ b/scripts/2_create_indexes/musdb18/create_indexes.sh @@ -0,0 +1,18 @@ +#!/bin/bash +WORKSPACE=${1:-"./workspaces/bytesep"} # Default workspace directory + +echo "WORKSPACE=${WORKSPACE}" + +# --- Create indexes for vocals and accompaniment --- +INDEXES_CONFIG_YAML="scripts/2_create_indexes/musdb18/configs/vocals-accompaniment,sr=44100,chn=2.yaml" + +python3 bytesep/dataset_creation/create_indexes/create_indexes.py \ + --workspace=$WORKSPACE \ + --config_yaml=$INDEXES_CONFIG_YAML + +# --- Create indexes for vocals, bass, drums, and other --- +INDEXES_CONFIG_YAML="scripts/2_create_indexes/musdb18/configs/vocals-bass-drums-other,sr=44100,chn=2.yaml" + +python3 bytesep/dataset_creation/create_indexes/create_indexes.py \ + --workspace=$WORKSPACE \ + --config_yaml=$INDEXES_CONFIG_YAML diff --git a/scripts/2_create_indexes/piano-symphony/configs/piano-symphony,sr=44100,chn=2.yaml b/scripts/2_create_indexes/piano-symphony/configs/piano-symphony,sr=44100,chn=2.yaml new file mode 100755 index 0000000000000000000000000000000000000000..f8231d1e4cb7c34b932a44c81a520daec59c790e --- /dev/null +++ b/scripts/2_create_indexes/piano-symphony/configs/piano-symphony,sr=44100,chn=2.yaml @@ -0,0 +1,16 @@ +--- +sample_rate: 44100 +segment_seconds: 3.0 +train: + source_types: + piano: + maestro: + hdf5s_directory: "hdf5s/maestro/sr=44100_chn=2/train" + key_in_hdf5: "piano" + hop_seconds: 3.0 + symphony: + instruments_solo: + hdf5s_directory: "hdf5s/instruments_solo/symphony/sr=44100_chn=2/train" + key_in_hdf5: "symphony" + hop_seconds: 0.1 + indexes: "indexes/piano-symphony/sr=44100_chn=2/train/piano-symphony.pkl" diff --git a/scripts/2_create_indexes/piano-symphony/create_indexes.sh b/scripts/2_create_indexes/piano-symphony/create_indexes.sh new file mode 100755 index 0000000000000000000000000000000000000000..eff1369ac25ff22c70cfbacb77eeb77ac5377e7e --- /dev/null +++ b/scripts/2_create_indexes/piano-symphony/create_indexes.sh @@ -0,0 +1,12 @@ +#!/bin/bash +WORKSPACE=${1:-"./workspaces/bytesep"} # Default workspace directory + +echo "WORKSPACE=${WORKSPACE}" + +# Users can modify the following config file. +INDEXES_CONFIG_YAML="scripts/2_create_indexes/piano-symphony/configs/piano-symphony,sr=44100,chn=2.yaml" + +# Create indexes for training. 
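As an illustration of what `create_indexes.py` consumes, the sketch below parses one of the index configs above with the `read_yaml` helper from `bytesep/utils.py` and walks its `source_types` entries. The workspace path is the default used throughout these scripts, and the `hdf5s_directory` values are relative to that workspace, matching the output paths of the packing scripts in step 1.

import os

from bytesep.utils import read_yaml

workspace = "./workspaces/bytesep"
config_yaml = "scripts/2_create_indexes/musdb18/configs/vocals-accompaniment,sr=44100,chn=2.yaml"

configs = read_yaml(config_yaml)
print(configs['sample_rate'], configs['segment_seconds'])  # 44100 3.0

# Each source type maps dataset names to the HDF5 directory, key, and hop used for indexing.
for source_type, datasets in configs['train']['source_types'].items():
    for dataset_name, attrs in datasets.items():
        hdf5s_dir = os.path.join(workspace, attrs['hdf5s_directory'])
        print(source_type, dataset_name, hdf5s_dir, attrs['key_in_hdf5'], attrs['hop_seconds'])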
+python3 bytesep/dataset_creation/create_indexes/create_indexes.py \ + --workspace=$WORKSPACE \ + --config_yaml=$INDEXES_CONFIG_YAML diff --git a/scripts/2_create_indexes/vctk-musdb18/configs/speech-accompaniment,sr=44100,chn=2.yaml b/scripts/2_create_indexes/vctk-musdb18/configs/speech-accompaniment,sr=44100,chn=2.yaml new file mode 100755 index 0000000000000000000000000000000000000000..ad6ad5d1e49a2f0649100abfb3b4d967f39b6d30 --- /dev/null +++ b/scripts/2_create_indexes/vctk-musdb18/configs/speech-accompaniment,sr=44100,chn=2.yaml @@ -0,0 +1,16 @@ +--- +sample_rate: 44100 +segment_seconds: 3.0 +train: + source_types: + speech: + vctk: + hdf5s_directory: "hdf5s/vctk/sr=44100_chn=2/train" + key_in_hdf5: "speech" + hop_seconds: 0.1 + accompaniment: + musdb18: + hdf5s_directory: "hdf5s/musdb18/sr=44100_chn=2/train/full_train" + key_in_hdf5: "accompaniment" + hop_seconds: 0.1 + indexes: "indexes/vctk-musdb18/sr=44100_chn=2/train/speech-accompaniment.pkl" diff --git a/scripts/2_create_indexes/vctk-musdb18/configs/speech-music,sr=44100,chn=2.yaml b/scripts/2_create_indexes/vctk-musdb18/configs/speech-music,sr=44100,chn=2.yaml new file mode 100755 index 0000000000000000000000000000000000000000..5f30ffc944eb625f21cc482997f54deac497af66 --- /dev/null +++ b/scripts/2_create_indexes/vctk-musdb18/configs/speech-music,sr=44100,chn=2.yaml @@ -0,0 +1,16 @@ +--- +sample_rate: 44100 +segment_seconds: 3.0 +train: + source_types: + speech: + vctk: + hdf5s_directory: "hdf5s/vctk/sr=44100_chn=2/train" + key_in_hdf5: "speech" + hop_seconds: 0.1 + music: + musdb18: + hdf5s_directory: "hdf5s/musdb18/sr=44100_chn=2/train/full_train" + key_in_hdf5: "mixture" + hop_seconds: 0.1 + indexes: "indexes/vctk-musdb18/sr=44100_chn=2/train/speech-music.pkl" diff --git a/scripts/2_create_indexes/vctk-musdb18/create_indexes.sh b/scripts/2_create_indexes/vctk-musdb18/create_indexes.sh new file mode 100755 index 0000000000000000000000000000000000000000..e2a85230b2745cedb2c98a34ed303082bb1ec48a --- /dev/null +++ b/scripts/2_create_indexes/vctk-musdb18/create_indexes.sh @@ -0,0 +1,12 @@ +#!/bin/bash +WORKSPACE=${1:-"./workspaces/bytesep"} # Default workspace directory + +echo "WORKSPACE=${WORKSPACE}" + +# Users can modify the following config file. +INDEXES_CONFIG_YAML="scripts/2_create_indexes/vctk-musdb18/configs/speech-accompaniment,sr=44100,chn=2.yaml" + +# Create indexes for training. 
+python3 bytesep/dataset_creation/create_indexes/create_indexes.py \ + --workspace=$WORKSPACE \ + --config_yaml=$INDEXES_CONFIG_YAML diff --git a/scripts/2_create_indexes/violin-piano/configs/violin-piano,sr=44100,chn=2.yaml b/scripts/2_create_indexes/violin-piano/configs/violin-piano,sr=44100,chn=2.yaml new file mode 100755 index 0000000000000000000000000000000000000000..1d73d365f13fd7bf5ac9968b2803fb2be1947406 --- /dev/null +++ b/scripts/2_create_indexes/violin-piano/configs/violin-piano,sr=44100,chn=2.yaml @@ -0,0 +1,16 @@ +--- +sample_rate: 44100 +segment_seconds: 3.0 +train: + source_types: + violin: + instruments_solo: + hdf5s_directory: "hdf5s/instruments_solo/violin/sr=44100_chn=2/train" + key_in_hdf5: "violin" + hop_seconds: 0.1 + piano: + maestro: + hdf5s_directory: "hdf5s/maestro/sr=44100_chn=2/train" + key_in_hdf5: "piano" + hop_seconds: 3.0 + indexes: "indexes/violin-piano/sr=44100_chn=2/train/violin-piano.pkl" diff --git a/scripts/2_create_indexes/violin-piano/create_indexes.sh b/scripts/2_create_indexes/violin-piano/create_indexes.sh new file mode 100755 index 0000000000000000000000000000000000000000..f1532a524bdae6b5316451bb0172b3b44c8f4843 --- /dev/null +++ b/scripts/2_create_indexes/violin-piano/create_indexes.sh @@ -0,0 +1,12 @@ +#!/bin/bash +WORKSPACE=${1:-"./workspaces/bytesep"} # Default workspace directory + +echo "WORKSPACE=${WORKSPACE}" + +# Users can modify the following config file. +INDEXES_CONFIG_YAML="scripts/2_create_indexes/violin-piano/configs/violin-piano,sr=44100,chn=2.yaml" + +# Create indexes for training. +python3 bytesep/dataset_creation/create_indexes/create_indexes.py \ + --workspace=$WORKSPACE \ + --config_yaml=$INDEXES_CONFIG_YAML diff --git a/scripts/2_create_indexes/voicebank-demand/configs/speech-noise,sr=44100,chn=1.yaml b/scripts/2_create_indexes/voicebank-demand/configs/speech-noise,sr=44100,chn=1.yaml new file mode 100755 index 0000000000000000000000000000000000000000..37bfe825b6a003f19e3bdf8d0a6bdac819042d9a --- /dev/null +++ b/scripts/2_create_indexes/voicebank-demand/configs/speech-noise,sr=44100,chn=1.yaml @@ -0,0 +1,16 @@ +--- +sample_rate: 44100 +segment_seconds: 3.0 +train: + source_types: + speech: + voicebank-demand: + hdf5s_directory: "hdf5s/voicebank-demand/sr=44100_chn=1/train" + key_in_hdf5: "speech" + hop_seconds: 3. + noise: + voicebank-demand: + hdf5s_directory: "hdf5s/voicebank-demand/sr=44100_chn=1/train" + key_in_hdf5: "noise" + hop_seconds: 3. + indexes: "indexes/voicebank-demand/sr=44100_chn=1/train/speech-noise.pkl" diff --git a/scripts/2_create_indexes/voicebank-demand/create_indexes.sh b/scripts/2_create_indexes/voicebank-demand/create_indexes.sh new file mode 100755 index 0000000000000000000000000000000000000000..f29c64f0a8b1d8a251061d336a15c953a468eb7f --- /dev/null +++ b/scripts/2_create_indexes/voicebank-demand/create_indexes.sh @@ -0,0 +1,12 @@ +#!/bin/bash +WORKSPACE=${1:-"./workspaces/bytesep"} # Default workspace directory + +echo "WORKSPACE=${WORKSPACE}" + +# Users can modify the following config file. +INDEXES_CONFIG_YAML="scripts/2_create_indexes/voicebank-demand/configs/speech-noise,sr=44100,chn=1.yaml" + +# Create indexes for training.
+python3 bytesep/dataset_creation/create_indexes/create_indexes.py \ + --workspace=$WORKSPACE \ + --config_yaml=$INDEXES_CONFIG_YAML diff --git a/scripts/3_create_evaluation_audios/musdb18/create_evaluation_audios.sh b/scripts/3_create_evaluation_audios/musdb18/create_evaluation_audios.sh new file mode 100755 index 0000000000000000000000000000000000000000..2f444e4060978a3e15e8b9d26fa96454147b0517 --- /dev/null +++ b/scripts/3_create_evaluation_audios/musdb18/create_evaluation_audios.sh @@ -0,0 +1,14 @@ +#!/bin/bash +MUSDB18_DATASET_DIR=${1:-"./datasets/musdb18"} # The first argument is dataset directory. +WORKSPACE=${2:-"./workspaces/bytesep"} # The first argument is workspace directory. + +# Get absolute path +MUSDB18_DATASET_DIR=`readlink -f $MUSDB18_DATASET_DIR` + +# Evaluation audios directory +EVALUATION_AUDIOS_DIR="${WORKSPACE}/evaluation_audios/musdb18" + +mkdir -p `dirname $EVALUATION_AUDIOS_DIR` + +# Create link +ln -s $MUSDB18_DATASET_DIR $EVALUATION_AUDIOS_DIR \ No newline at end of file diff --git a/scripts/3_create_evaluation_audios/piano-symphony/create_evaluation_audios.sh b/scripts/3_create_evaluation_audios/piano-symphony/create_evaluation_audios.sh new file mode 100755 index 0000000000000000000000000000000000000000..517ea426871105ffaf6bd9d0b5de2db5cb869b00 --- /dev/null +++ b/scripts/3_create_evaluation_audios/piano-symphony/create_evaluation_audios.sh @@ -0,0 +1,19 @@ +#!/bin/bash +PIANO_DATASET_DIR=${1:-"./datasets/maestro"} +SYMPHONY_DATASET_DIR=${2:-"./datasets/instruments_solo/symphony_solo/v0.1"} +WORKSPACE=${3:-"./workspaces/bytesep"} + +SAMPLE_RATE=44100 +CHANNELS=2 +EVALUATION_SEGMENTS_NUM=100 + +EVLUATION_AUDIOS_DIR="${WORKSPACE}/evaluation_audios/piano-symphony" + +python3 bytesep/dataset_creation/create_evaluation_audios/piano-symphony.py \ + --piano_dataset_dir=$PIANO_DATASET_DIR \ + --symphony_dataset_dir=$SYMPHONY_DATASET_DIR \ + --evaluation_audios_dir=$EVLUATION_AUDIOS_DIR \ + --sample_rate=$SAMPLE_RATE \ + --channels=$CHANNELS \ + --evaluation_segments_num=$EVALUATION_SEGMENTS_NUM + \ No newline at end of file diff --git a/scripts/3_create_evaluation_audios/vctk-musdb18/create_evaluation_audios.sh b/scripts/3_create_evaluation_audios/vctk-musdb18/create_evaluation_audios.sh new file mode 100755 index 0000000000000000000000000000000000000000..b12a57c6e2ddafe7e9db2d9240b58d00898b2c8a --- /dev/null +++ b/scripts/3_create_evaluation_audios/vctk-musdb18/create_evaluation_audios.sh @@ -0,0 +1,19 @@ +#!/bin/bash +VCTK_DATASET_DIR=${1:-"./datasets/vctk"} +MUSDB18_DATASET_DIR=${2:-"./datasets/musdb18"} +WORKSPACE=${3:-"./workspaces/bytesep"} + +SAMPLE_RATE=44100 +CHANNELS=2 +EVALUATION_SEGMENTS_NUM=100 + +EVLUATION_AUDIOS_DIR="${WORKSPACE}/evaluation_audios/vctk-musdb18" + +python3 bytesep/dataset_creation/create_evaluation_audios/vctk-musdb18.py \ + --vctk_dataset_dir=$VCTK_DATASET_DIR \ + --musdb18_dataset_dir=$MUSDB18_DATASET_DIR \ + --evaluation_audios_dir=$EVLUATION_AUDIOS_DIR \ + --sample_rate=$SAMPLE_RATE \ + --channels=$CHANNELS \ + --evaluation_segments_num=$EVALUATION_SEGMENTS_NUM + \ No newline at end of file diff --git a/scripts/3_create_evaluation_audios/violin-piano/create_evaluation_audios.sh b/scripts/3_create_evaluation_audios/violin-piano/create_evaluation_audios.sh new file mode 100755 index 0000000000000000000000000000000000000000..6ed036ff64b7d6d8cf5fbdaf67a493da64fb3fc5 --- /dev/null +++ b/scripts/3_create_evaluation_audios/violin-piano/create_evaluation_audios.sh @@ -0,0 +1,19 @@ +#!/bin/bash 
+VIOLIN_DATASET_DIR=${1:-"./datasets/instruments_solo/violin_solo/v0.1"} +PIANO_DATASET_DIR=${2:-"./datasets/maestro"} +WORKSPACE=${3:-"./workspaces/bytesep"} + +SAMPLE_RATE=44100 +CHANNELS=2 +EVALUATION_SEGMENTS_NUM=100 + +EVLUATION_AUDIOS_DIR="${WORKSPACE}/evaluation_audios/violin-piano" + +python3 bytesep/dataset_creation/create_evaluation_audios/violin-piano.py \ + --violin_dataset_dir=$VIOLIN_DATASET_DIR \ + --piano_dataset_dir=$PIANO_DATASET_DIR \ + --evaluation_audios_dir=$EVLUATION_AUDIOS_DIR \ + --sample_rate=$SAMPLE_RATE \ + --channels=$CHANNELS \ + --evaluation_segments_num=$EVALUATION_SEGMENTS_NUM + \ No newline at end of file diff --git a/scripts/3_create_evaluation_audios/voicebank-demand/create_evaluation_audios.sh b/scripts/3_create_evaluation_audios/voicebank-demand/create_evaluation_audios.sh new file mode 100755 index 0000000000000000000000000000000000000000..df6791197a68b9941ffd4546cd38c36c8c8f2893 --- /dev/null +++ b/scripts/3_create_evaluation_audios/voicebank-demand/create_evaluation_audios.sh @@ -0,0 +1,14 @@ +#!/bin/bash +VOICEBANK_DEMAND_DATASET_DIR=${1:-"./datasets/voicebank-demand"} # The first argument is dataset directory. +WORKSPACE=${2:-"./workspaces/bytesep"} # The first argument is workspace directory. + +# Get absolute path +VOICEBANK_DEMAND_DATASET_DIR=`readlink -f $VOICEBANK_DEMAND_DATASET_DIR` + +# Evaluation audios directory +EVALUATION_AUDIOS_DIR="${WORKSPACE}/evaluation_audios/voicebank-demand" + +mkdir -p `dirname $EVALUATION_AUDIOS_DIR` + +# Create link +ln -s $VOICEBANK_DEMAND_DATASET_DIR $EVALUATION_AUDIOS_DIR \ No newline at end of file diff --git a/scripts/4_train/musdb18/configs/accompaniment-vocals,resunet.yaml b/scripts/4_train/musdb18/configs/accompaniment-vocals,resunet.yaml new file mode 100644 index 0000000000000000000000000000000000000000..688dd9e8f91ef7284c96388f984e48f1eed75a23 --- /dev/null +++ b/scripts/4_train/musdb18/configs/accompaniment-vocals,resunet.yaml @@ -0,0 +1,46 @@ +--- +task_name: musdb18 +train: + input_source_types: + - vocals + - accompaniment + target_source_types: + - accompaniment + indexes_dict: "indexes/musdb18/sr=44100_chn=2/train/full_train/vocals-accompaniment.pkl" + sample_rate: 44100 + channels: 2 + segment_seconds: 3.0 + model_type: ResUNet143_DecouplePlusInplaceABN_ISMIR2021 + loss_type: l1_wav + mini_data: False + augmentations: + mixaudio: + vocals: 2 + accompaniment: 2 + pitch_shift: + vocals: 0 + accompaniment: 0 + magnitude_scale: + vocals: + lower_db: 0 + higher_db: 0 + accompaniment: + lower_db: 0 + higher_db: 0 + batch_data_preprocessor: BasicBatchDataPreprocessor + evaluation_callback: Musdb18EvaluationCallback + learning_rate: 1e-3 + batch_size: 16 + precision: 32 + steps_per_epoch: 10000 # Every 10000 steps is called an `epoch`. + evaluate_step_frequency: 10000 # Evaluate every #evaluate_step_frequency steps. + save_step_frequency: 50000 # Save every #save_step_frequency steps. 
+ early_stop_steps: 500001 + warm_up_steps: 1000 + reduce_lr_steps: 15000 + random_seed: 1234 + resume_checkpoint: "" + +evaluate: + segment_seconds: 30.0 + batch_size: 1 \ No newline at end of file diff --git a/scripts/4_train/musdb18/configs/accompaniment-vocals,resunet_ismir2021.yaml b/scripts/4_train/musdb18/configs/accompaniment-vocals,resunet_ismir2021.yaml new file mode 100755 index 0000000000000000000000000000000000000000..60d7e8e93659b4d972fd900f662db2f2b1a80958 --- /dev/null +++ b/scripts/4_train/musdb18/configs/accompaniment-vocals,resunet_ismir2021.yaml @@ -0,0 +1,53 @@ +--- +task_name: musdb18 +train: + input_source_types: + - vocals + - accompaniment + target_source_types: + - accompaniment + indexes_dict: "indexes/musdb18/sr=44100_chn=2/train/full_train/vocals-accompaniment.pkl" + sample_rate: 44100 + channels: 2 + segment_seconds: 3.0 + model_type: ResUNet143_DecouplePlusInplaceABN_ISMIR2021 + loss_type: l1_wav + optimizer_type: Adam + mini_data: False + augmentations: + mixaudio: + vocals: 2 + accompaniment: 2 + pitch_shift: + vocals: 0 + accompaniment: 0 + magnitude_scale: + vocals: + lower_db: 0 + higher_db: 0 + accompaniment: + lower_db: 0 + higher_db: 0 + swap_channel: + vocals: False + accompaniment: False + flip_axis: + vocals: False + accompaniment: False + batch_data_preprocessor: BasicBatchDataPreprocessor + evaluation_callback: Musdb18EvaluationCallback + learning_rate: 1e-3 + batch_size: 16 + precision: 32 + steps_per_epoch: 10000 # Every 10000 steps is called an `epoch`. + evaluate_step_frequency: 10000 # Evaluate every #evaluate_step_frequency steps. + save_step_frequency: 50000 # Save every #save_step_frequency steps. + early_stop_steps: 500001 + warm_up_steps: 1000 + reduce_lr_steps: 15000 + random_seed: 1234 + resume_checkpoint: "" + +evaluate: + segment_seconds: 30.0 + batch_size: 1 \ No newline at end of file diff --git a/scripts/4_train/musdb18/configs/accompaniment-vocals,resunet_subbandtime.yaml b/scripts/4_train/musdb18/configs/accompaniment-vocals,resunet_subbandtime.yaml new file mode 100755 index 0000000000000000000000000000000000000000..2746ed0b90bdb5be6ff466b9e32ad3d4a5ebbdd6 --- /dev/null +++ b/scripts/4_train/musdb18/configs/accompaniment-vocals,resunet_subbandtime.yaml @@ -0,0 +1,53 @@ +--- +task_name: musdb18 +train: + input_source_types: + - vocals + - accompaniment + target_source_types: + - accompaniment + indexes_dict: "indexes/musdb18/sr=44100_chn=2/train/full_train/vocals-accompaniment.pkl" + sample_rate: 44100 + channels: 2 + segment_seconds: 3.0 + model_type: ResUNet143_Subbandtime + loss_type: l1_wav + optimizer_type: Adam + mini_data: False + augmentations: + mixaudio: + vocals: 2 + accompaniment: 2 + pitch_shift: + vocals: 0 + accompaniment: 0 + magnitude_scale: + vocals: + lower_db: 0 + higher_db: 0 + accompaniment: + lower_db: 0 + higher_db: 0 + swap_channel: + vocals: False + accompaniment: False + flip_axis: + vocals: False + accompaniment: False + batch_data_preprocessor: BasicBatchDataPreprocessor + evaluation_callback: Musdb18EvaluationCallback + learning_rate: 1e-3 + batch_size: 16 + precision: 32 + steps_per_epoch: 10000 # Every 10000 steps is called an `epoch`. + evaluate_step_frequency: 10000 # Evaluate every #evaluate_step_frequency steps. + save_step_frequency: 50000 # Save every #save_step_frequency steps. 
+ early_stop_steps: 500001 + warm_up_steps: 1000 + reduce_lr_steps: 15000 + random_seed: 1234 + resume_checkpoint: "" + +evaluate: + segment_seconds: 30.0 + batch_size: 1 \ No newline at end of file diff --git a/scripts/4_train/musdb18/configs/accompaniment-vocals,unet.yaml b/scripts/4_train/musdb18/configs/accompaniment-vocals,unet.yaml new file mode 100644 index 0000000000000000000000000000000000000000..925d19569ff7edbfbea6fd014de67f7da3a39d86 --- /dev/null +++ b/scripts/4_train/musdb18/configs/accompaniment-vocals,unet.yaml @@ -0,0 +1,46 @@ +--- +task_name: musdb18 +train: + input_source_types: + - vocals + - accompaniment + target_source_types: + - accompaniment + indexes_dict: "indexes/musdb18/sr=44100_chn=2/train/full_train/vocals-accompaniment.pkl" + sample_rate: 44100 + channels: 2 + segment_seconds: 3.0 + model_type: UNet + loss_type: l1_wav + mini_data: False + augmentations: + mixaudio: + vocals: 2 + accompaniment: 2 + pitch_shift: + vocals: 0 + accompaniment: 0 + magnitude_scale: + vocals: + lower_db: 0 + higher_db: 0 + accompaniment: + lower_db: 0 + higher_db: 0 + batch_data_preprocessor: BasicBatchDataPreprocessor + evaluation_callback: Musdb18EvaluationCallback + learning_rate: 1e-3 + batch_size: 16 + precision: 32 + steps_per_epoch: 10000 # Every 10000 steps is called an `epoch`. + evaluate_step_frequency: 10000 # Evaluate every #evaluate_step_frequency steps. + save_step_frequency: 50000 # Save every #save_step_frequency steps. + early_stop_steps: 500001 + warm_up_steps: 1000 + reduce_lr_steps: 15000 + random_seed: 1234 + resume_checkpoint: "" + +evaluate: + segment_seconds: 30.0 + batch_size: 1 \ No newline at end of file diff --git a/scripts/4_train/musdb18/configs/vocals-accompaniment,resunet.yaml b/scripts/4_train/musdb18/configs/vocals-accompaniment,resunet.yaml new file mode 100644 index 0000000000000000000000000000000000000000..469d83f8e808df66783588ec52ba81a987e053e7 --- /dev/null +++ b/scripts/4_train/musdb18/configs/vocals-accompaniment,resunet.yaml @@ -0,0 +1,46 @@ +--- +task_name: musdb18 +train: + input_source_types: + - vocals + - accompaniment + target_source_types: + - vocals + indexes_dict: "indexes/musdb18/sr=44100_chn=2/train/full_train/vocals-accompaniment.pkl" + sample_rate: 44100 + channels: 2 + segment_seconds: 3.0 + model_type: ResUNet143_DecouplePlusInplaceABN_ISMIR2021 + loss_type: l1_wav + mini_data: False + augmentations: + mixaudio: + vocals: 2 + accompaniment: 2 + pitch_shift: + vocals: 0 + accompaniment: 0 + magnitude_scale: + vocals: + lower_db: 0 + higher_db: 0 + accompaniment: + lower_db: 0 + higher_db: 0 + batch_data_preprocessor: BasicBatchDataPreprocessor + evaluation_callback: Musdb18EvaluationCallback + learning_rate: 1e-3 + batch_size: 16 + precision: 32 + steps_per_epoch: 10000 # Every 10000 steps is called an `epoch`. + evaluate_step_frequency: 10000 # Evaluate every #evaluate_step_frequency steps. + save_step_frequency: 50000 # Save every #save_step_frequency steps. 
+ early_stop_steps: 500001 + warm_up_steps: 1000 + reduce_lr_steps: 15000 + random_seed: 1234 + resume_checkpoint: "" + +evaluate: + segment_seconds: 30.0 + batch_size: 1 \ No newline at end of file diff --git a/scripts/4_train/musdb18/configs/vocals-accompaniment,resunet_ismir2021.yaml b/scripts/4_train/musdb18/configs/vocals-accompaniment,resunet_ismir2021.yaml new file mode 100755 index 0000000000000000000000000000000000000000..3a16e5c7620459af0fee89f428e96a257b2af20d --- /dev/null +++ b/scripts/4_train/musdb18/configs/vocals-accompaniment,resunet_ismir2021.yaml @@ -0,0 +1,53 @@ +--- +task_name: musdb18 +train: + input_source_types: + - vocals + - accompaniment + target_source_types: + - vocals + indexes_dict: "indexes/musdb18/sr=44100_chn=2/train/full_train/vocals-accompaniment.pkl" + sample_rate: 44100 + channels: 2 + segment_seconds: 3.0 + model_type: ResUNet143_DecouplePlusInplaceABN_ISMIR2021 + loss_type: l1_wav + optimizer_type: Adam + mini_data: False + augmentations: + mixaudio: + vocals: 2 + accompaniment: 2 + pitch_shift: + vocals: 0 + accompaniment: 0 + magnitude_scale: + vocals: + lower_db: 0 + higher_db: 0 + accompaniment: + lower_db: 0 + higher_db: 0 + swap_channel: + vocals: False + accompaniment: False + flip_axis: + vocals: False + accompaniment: False + batch_data_preprocessor: BasicBatchDataPreprocessor + evaluation_callback: Musdb18EvaluationCallback + learning_rate: 1e-3 + batch_size: 16 + precision: 32 + steps_per_epoch: 10000 # Every 10000 steps is called an `epoch`. + evaluate_step_frequency: 10000 # Evaluate every #evaluate_step_frequency steps. + save_step_frequency: 50000 # Save every #save_step_frequency steps. + early_stop_steps: 500001 + warm_up_steps: 1000 + reduce_lr_steps: 15000 + random_seed: 1234 + resume_checkpoint: "" + +evaluate: + segment_seconds: 30.0 + batch_size: 1 \ No newline at end of file diff --git a/scripts/4_train/musdb18/configs/vocals-accompaniment,resunet_subbandtime.yaml b/scripts/4_train/musdb18/configs/vocals-accompaniment,resunet_subbandtime.yaml new file mode 100755 index 0000000000000000000000000000000000000000..11e253d3eaf5b7b044a01f1cce343542f86e261e --- /dev/null +++ b/scripts/4_train/musdb18/configs/vocals-accompaniment,resunet_subbandtime.yaml @@ -0,0 +1,53 @@ +--- +task_name: musdb18 +train: + input_source_types: + - vocals + - accompaniment + target_source_types: + - vocals + indexes_dict: "indexes/musdb18/sr=44100_chn=2/train/full_train/vocals-accompaniment.pkl" + sample_rate: 44100 + channels: 2 + segment_seconds: 3.0 + model_type: ResUNet143_Subbandtime + loss_type: l1_wav + optimizer_type: Adam + mini_data: False + augmentations: + mixaudio: + vocals: 2 + accompaniment: 2 + pitch_shift: + vocals: 0 + accompaniment: 0 + magnitude_scale: + vocals: + lower_db: 0 + higher_db: 0 + accompaniment: + lower_db: 0 + higher_db: 0 + swap_channel: + vocals: False + accompaniment: False + flip_axis: + vocals: False + accompaniment: False + batch_data_preprocessor: BasicBatchDataPreprocessor + evaluation_callback: Musdb18EvaluationCallback + learning_rate: 1e-3 + batch_size: 16 + precision: 32 + steps_per_epoch: 10000 # Every 10000 steps is called an `epoch`. + evaluate_step_frequency: 10000 # Evaluate every #evaluate_step_frequency steps. + save_step_frequency: 50000 # Save every #save_step_frequency steps. 
+ early_stop_steps: 500001 + warm_up_steps: 1000 + reduce_lr_steps: 15000 + random_seed: 1234 + resume_checkpoint: "" + +evaluate: + segment_seconds: 30.0 + batch_size: 1 \ No newline at end of file diff --git a/scripts/4_train/musdb18/configs/vocals-accompaniment,unet.yaml b/scripts/4_train/musdb18/configs/vocals-accompaniment,unet.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ea374caec71226b5b8348a55efdc0df53a88fcaa --- /dev/null +++ b/scripts/4_train/musdb18/configs/vocals-accompaniment,unet.yaml @@ -0,0 +1,46 @@ +--- +task_name: musdb18 +train: + input_source_types: + - vocals + - accompaniment + target_source_types: + - vocals + indexes_dict: "indexes/musdb18/sr=44100_chn=2/train/full_train/vocals-accompaniment.pkl" + sample_rate: 44100 + channels: 2 + segment_seconds: 3.0 + model_type: UNet + loss_type: l1_wav + mini_data: False + augmentations: + mixaudio: + vocals: 2 + accompaniment: 2 + pitch_shift: + vocals: 0 + accompaniment: 0 + magnitude_scale: + vocals: + lower_db: 0 + higher_db: 0 + accompaniment: + lower_db: 0 + higher_db: 0 + batch_data_preprocessor: BasicBatchDataPreprocessor + evaluation_callback: Musdb18EvaluationCallback + learning_rate: 1e-3 + batch_size: 16 + precision: 32 + steps_per_epoch: 10000 # Every 10000 steps is called an `epoch`. + evaluate_step_frequency: 10000 # Evaluate every #evaluate_step_frequency steps. + save_step_frequency: 50000 # Save every #save_step_frequency steps. + early_stop_steps: 500001 + warm_up_steps: 1000 + reduce_lr_steps: 15000 + random_seed: 1234 + resume_checkpoint: "" + +evaluate: + segment_seconds: 30.0 + batch_size: 1 \ No newline at end of file diff --git a/scripts/4_train/musdb18/configs/vocals-accompaniment,unet_subbandtime.yaml b/scripts/4_train/musdb18/configs/vocals-accompaniment,unet_subbandtime.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e48ddbc5fcd2fed3c44d2ddca84408fea867a367 --- /dev/null +++ b/scripts/4_train/musdb18/configs/vocals-accompaniment,unet_subbandtime.yaml @@ -0,0 +1,46 @@ +--- +task_name: musdb18 +train: + input_source_types: + - vocals + - accompaniment + target_source_types: + - vocals + indexes_dict: "indexes/musdb18/sr=44100_chn=2/train/full_train/vocals-accompaniment.pkl" + sample_rate: 44100 + channels: 2 + segment_seconds: 3.0 + model_type: UNetSubbandTime + loss_type: l1_wav + mini_data: False + augmentations: + mixaudio: + vocals: 2 + accompaniment: 2 + pitch_shift: + vocals: 0 + accompaniment: 0 + magnitude_scale: + vocals: + lower_db: 0 + higher_db: 0 + accompaniment: + lower_db: 0 + higher_db: 0 + batch_data_preprocessor: BasicBatchDataPreprocessor + evaluation_callback: Musdb18EvaluationCallback + learning_rate: 1e-3 + batch_size: 16 + precision: 32 + steps_per_epoch: 10000 # Every 10000 steps is called an `epoch`. + evaluate_step_frequency: 10000 # Evaluate every #evaluate_step_frequency steps. + save_step_frequency: 50000 # Save every #save_step_frequency steps. 
+ early_stop_steps: 500001 + warm_up_steps: 1000 + reduce_lr_steps: 15000 + random_seed: 1234 + resume_checkpoint: "" + +evaluate: + segment_seconds: 30.0 + batch_size: 1 \ No newline at end of file diff --git a/scripts/4_train/musdb18/configs/vocals-bass-drums-other,unet.yaml b/scripts/4_train/musdb18/configs/vocals-bass-drums-other,unet.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e50434e9ed180b341c97661a1917ab75479a9706 --- /dev/null +++ b/scripts/4_train/musdb18/configs/vocals-bass-drums-other,unet.yaml @@ -0,0 +1,61 @@ +--- +task_name: musdb18 +train: + input_source_types: + - vocals + - bass + - drums + - other + target_source_types: + - vocals + - bass + - drums + - other + indexes_dict: "indexes/musdb18/sr=44100_chn=2/train/full_train/vocals-bass-drums-other.pkl" + sample_rate: 44100 + channels: 2 + segment_seconds: 3.0 + model_type: UNet + loss_type: l1_wav + mini_data: False + augmentations: + mixaudio: + vocals: 1 + bass: 1 + drums: 1 + other: 1 + pitch_shift: + vocals: 0 + bass: 0 + drums: 0 + other: 0 + magnitude_scale: + vocals: + lower_db: 0 + higher_db: 0 + bass: + lower_db: 0 + higher_db: 0 + drums: + lower_db: 0 + higher_db: 0 + other: + lower_db: 0 + higher_db: 0 + batch_data_preprocessor: BasicBatchDataPreprocessor + evaluation_callback: Musdb18EvaluationCallback + learning_rate: 1e-3 + batch_size: 16 + precision: 32 + steps_per_epoch: 10000 # Every 10000 steps is called an `epoch`. + evaluate_step_frequency: 20000 # Evaluate every #evaluate_step_frequency steps. + save_step_frequency: 50000 # Save every #save_step_frequency steps. + early_stop_steps: 500001 + warm_up_steps: 1000 + reduce_lr_steps: 15000 + random_seed: 1234 + resume_checkpoint: "" + +evaluate: + segment_seconds: 30.0 + batch_size: 1 \ No newline at end of file diff --git a/scripts/4_train/musdb18/train.sh b/scripts/4_train/musdb18/train.sh new file mode 100755 index 0000000000000000000000000000000000000000..b453042b2750fccab709710ed31199e1eb694427 --- /dev/null +++ b/scripts/4_train/musdb18/train.sh @@ -0,0 +1,13 @@ +#!/bin/bash +WORKSPACE=${1:-"./workspaces/bytesep"} # The first argument is workspace directory. + +echo "WORKSPACE=${WORKSPACE}" + +# Users can modify the following config file. +TRAIN_CONFIG_YAML="scripts/4_train/musdb18/configs/vocals-accompaniment,unet.yaml" + +# Train & evaluate & save checkpoints. 
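The `warm_up_steps` and `reduce_lr_steps` fields in the configs above are handed by `bytesep/train.py` to `get_lr_lambda`, whose implementation is not included in this patch. Purely to illustrate how such a lambda plugs into `torch.optim.lr_scheduler.LambdaLR`, a hypothetical warm-up-then-step-decay version could look like the sketch below; the linear warm-up and the 0.9 decay factor are assumptions, not the repository's actual schedule.

from functools import partial

import torch
from torch.optim.lr_scheduler import LambdaLR


def warmup_then_decay(step, warm_up_steps, reduce_lr_steps):
    # Hypothetical schedule: ramp up linearly, then scale by 0.9 every reduce_lr_steps steps.
    if step < warm_up_steps:
        return step / max(1, warm_up_steps)
    return 0.9 ** (step // reduce_lr_steps)


lr_lambda = partial(warmup_then_decay, warm_up_steps=1000, reduce_lr_steps=15000)

model = torch.nn.Linear(1, 1)  # stand-in module
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = LambdaLR(optimizer, lr_lambda)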
+CUDA_VISIBLE_DEVICES=0 python3 bytesep/train.py train \ + --workspace=$WORKSPACE \ + --gpus=1 \ + --config_yaml=$TRAIN_CONFIG_YAML \ No newline at end of file diff --git a/scripts/4_train/vctk-musdb18/configs/speech-accompaniment,unet.yaml b/scripts/4_train/vctk-musdb18/configs/speech-accompaniment,unet.yaml new file mode 100755 index 0000000000000000000000000000000000000000..3daeff03904270e5c173ed23090f9839b6a647a7 --- /dev/null +++ b/scripts/4_train/vctk-musdb18/configs/speech-accompaniment,unet.yaml @@ -0,0 +1,46 @@ +--- +task_name: vctk-musdb18 +train: + input_source_types: + - speech + - accompaniment + target_source_types: + - speech + indexes_dict: "indexes/vctk-musdb18/sr=44100_chn=2/train/speech-accompaniment.pkl" + sample_rate: 44100 + channels: 2 + segment_seconds: 3.0 + model_type: UNet + loss_type: l1_wav + mini_data: False + mixaudio: + speech: 1 + accompaniment: 1 + augmentation: + pitch_shift: + speech: 0 + accompaniment: 0 + magnitude_scale: + speech: + lower_db: -10 + higher_db: 10 + accompaniment: + lower_db: -30 + higher_db: 10 + batch_data_preprocessor: BasicBatchDataPreprocessor + evaluation_callback: Musdb18EvaluationCallback + learning_rate: 1e-3 + batch_size: 16 + precision: 32 + steps_per_epoch: 10000 # Every 10000 steps is called an `epoch`. + evaluate_step_frequency: 10000 # Evaluate every #evaluate_step_frequency steps. + save_step_frequency: 50000 # Save every #save_step_frequency steps. + early_stop_steps: 500001 + warm_up_steps: 1000 + reduce_lr_steps: 15000 + random_seed: 1234 + resume_checkpoint: "" + +evaluate: + segment_seconds: 30.0 + batch_size: 1 \ No newline at end of file diff --git a/scripts/4_train/vctk-musdb18/train.sh b/scripts/4_train/vctk-musdb18/train.sh new file mode 100755 index 0000000000000000000000000000000000000000..e64648c63f465981aa5fdea48a983ba78fe22259 --- /dev/null +++ b/scripts/4_train/vctk-musdb18/train.sh @@ -0,0 +1,13 @@ +#!/bin/bash +WORKSPACE=${1:-"./workspaces/bytesep"} # The first argument is workspace directory. + +echo "WORKSPACE=${WORKSPACE}" + +# Users can modify the following config file. +TRAIN_CONFIG_YAML="scripts/4_train/vctk-musdb18/configs/speech-accompaniment,unet.yaml" + +# Train & evaluate & save checkpoints. +CUDA_VISIBLE_DEVICES=0 python3 bytesep/train.py train \ + --workspace=$WORKSPACE \ + --gpus=1 \ + --config_yaml=$TRAIN_CONFIG_YAML \ No newline at end of file diff --git a/scripts/4_train/violin-piano/configs/violin-piano,unet.yaml b/scripts/4_train/violin-piano/configs/violin-piano,unet.yaml new file mode 100755 index 0000000000000000000000000000000000000000..0b9603ecd7dad503787f7410e068f1d16a1e843f --- /dev/null +++ b/scripts/4_train/violin-piano/configs/violin-piano,unet.yaml @@ -0,0 +1,40 @@ +--- +task_name: violin-piano +train: + input_source_types: + - violin + - piano + target_source_types: + - violin + indexes_dict: "indexes/violin-piano/sr=44100_chn=2/train/violin-piano.pkl" + sample_rate: 44100 + channels: 2 + segment_seconds: 3.0 + model_type: UNet + loss_type: l1_wav + mini_data: False + mixaudio: + violin: 1 + piano: 1 + augmentation: + pitch_shift: 0 + magnitude_scale: + lower_db: 0 + higher_db: 0 + batch_data_preprocessor: BasicBatchDataPreprocessor + evaluation_callback: Musdb18EvaluationCallback + learning_rate: 1e-3 + batch_size: 16 + precision: 32 + steps_per_epoch: 10000 # Every 10000 steps is called an `epoch`. + evaluate_step_frequency: 10000 # Evaluate every #evaluate_step_frequency steps. + save_step_frequency: 50000 # Save every #save_step_frequency steps. 
+ early_stop_steps: 500001 + warm_up_steps: 1000 + reduce_lr_steps: 15000 + random_seed: 1234 + resume_checkpoint: "" + +evaluate: + batch_size: 1 + segment_seconds: 30. \ No newline at end of file diff --git a/scripts/4_train/violin-piano/train.sh b/scripts/4_train/violin-piano/train.sh new file mode 100755 index 0000000000000000000000000000000000000000..42fec4bdfd1b2bf38d2af4a28f55d80d3d0158ae --- /dev/null +++ b/scripts/4_train/violin-piano/train.sh @@ -0,0 +1,13 @@ +#!/bin/bash +WORKSPACE=${1:-"./workspaces/bytesep"} # The first argument is workspace directory. + +echo "WORKSPACE=${WORKSPACE}" + +# Users can modify the following config file. +TRAIN_CONFIG_YAML="scripts/4_train/violin-piano/configs/violin-piano,unet.yaml" + +# Train & evaluate & save checkpoints. +CUDA_VISIBLE_DEVICES=0 python3 bytesep/train.py train \ + --workspace=$WORKSPACE \ + --gpus=1 \ + --config_yaml=$TRAIN_CONFIG_YAML \ No newline at end of file diff --git a/scripts/4_train/voicebank-demand/configs/speech-noise,unet.yaml b/scripts/4_train/voicebank-demand/configs/speech-noise,unet.yaml new file mode 100755 index 0000000000000000000000000000000000000000..e54b362b9568c0592a2a83017529518a4d793aa9 --- /dev/null +++ b/scripts/4_train/voicebank-demand/configs/speech-noise,unet.yaml @@ -0,0 +1,53 @@ +--- +task_name: voicebank-demand +train: + input_source_types: + - speech + - noise + target_source_types: + - speech + indexes_dict: "indexes/voicebank-demand/sr=44100_chn=1/train/speech-noise.pkl" + sample_rate: 44100 + channels: 1 + segment_seconds: 3.0 + model_type: UNet + loss_type: l1_wav + optimizer_type: Adam + mini_data: False + augmentations: + mixaudio: + speech: 1 + noise: 1 + pitch_shift: + speech: 0 + noise: 0 + magnitude_scale: + speech: + lower_db: 0 + higher_db: 0 + noise: + lower_db: 0 + higher_db: 0 + swap_channel: + speech: False + noise: False + flip_axis: + speech: False + noise: False + batch_data_preprocessor: BasicBatchDataPreprocessor + evaluation_callback: Default + learning_rate: 1e-3 + batch_size: 16 + precision: 32 + steps_per_epoch: 10000 # Every 10000 steps is called an `epoch`. + evaluate_step_frequency: 10000 # Evaluate every #evaluate_step_frequency steps. + save_step_frequency: 50000 # Save every #save_step_frequency steps. + early_stop_steps: 500001 + warm_up_steps: 1000 + reduce_lr_steps: 15000 + random_seed: 1234 + resume_checkpoint: "" + +evaluate: + segment_seconds: 30.0 + batch_size: 1 \ No newline at end of file diff --git a/scripts/4_train/voicebank-demand/train.sh b/scripts/4_train/voicebank-demand/train.sh new file mode 100755 index 0000000000000000000000000000000000000000..4968af845ee9aaa4b252d0d434dae3aadac82e0a --- /dev/null +++ b/scripts/4_train/voicebank-demand/train.sh @@ -0,0 +1,13 @@ +#!/bin/bash +WORKSPACE=${1:-"./workspaces/bytesep"} # The first argument is workspace directory. + +echo "WORKSPACE=${WORKSPACE}" + +# Users can modify the following config file. +TRAIN_CONFIG_YAML="scripts/4_train/voicebank-demand/configs/speech-noise,unet.yaml" + +# Train & evaluate & save checkpoints. 
+CUDA_VISIBLE_DEVICES=0 python3 bytesep/train.py train \ + --workspace=$WORKSPACE \ + --gpus=1 \ + --config_yaml=$TRAIN_CONFIG_YAML \ No newline at end of file diff --git a/scripts/5_inference/musdb18/inference.sh b/scripts/5_inference/musdb18/inference.sh new file mode 100755 index 0000000000000000000000000000000000000000..21ecd5a30731343ee9b74e181ef4602b528a87d4 --- /dev/null +++ b/scripts/5_inference/musdb18/inference.sh @@ -0,0 +1,17 @@ +#!/bin/bash +WORKSPACE=${1:-"./workspaces/bytesep"} # The first argument is workspace directory. + +echo "WORKSPACE=${WORKSPACE}" + +# Users can modify the following config file. +TRAIN_CONFIG_YAML="scripts/4_train/musdb18/configs/vocals-accompaniment,unet.yaml" + +CHECKPOINT_PATH="${WORKSPACE}/checkpoints/musdb18/train/config=vocals-accompaniment,unet,gpus=1/step=300000.pth" + +# Inference +CUDA_VISIBLE_DEVICES=0 python3 bytesep/inference.py \ + --config_yaml=$TRAIN_CONFIG_YAML \ + --checkpoint_path=$CHECKPOINT_PATH \ + --audio_path="resources/vocals_accompaniment_10s.mp3" \ + --output_path="sep_results/vocals_accompaniment_10s_sep_vocals.mp3" + \ No newline at end of file diff --git a/scripts/apply-black.sh b/scripts/apply-black.sh new file mode 100755 index 0000000000000000000000000000000000000000..db35f6dd4af7f573770b8614f6dd3448a41909d9 --- /dev/null +++ b/scripts/apply-black.sh @@ -0,0 +1,3 @@ +#!/bin/bash +python3 -m black bytesep + diff --git a/separate_scripts/download_checkpoints.sh b/separate_scripts/download_checkpoints.sh new file mode 100755 index 0000000000000000000000000000000000000000..6d2f3742d92139f2bfdb4e6070980db9af3cead6 --- /dev/null +++ b/separate_scripts/download_checkpoints.sh @@ -0,0 +1,18 @@ +#!/bin/bash + +ZENODO_DIR="https://zenodo.org/record/5507029/files" +DOWNLOADED_CHECKPOINT_DIR="./downloaded_checkpoints" + +mkdir -p $DOWNLOADED_CHECKPOINT_DIR + +MODEL_NAME="resunet143_ismir2021_vocals_8.9dB_350k_steps.pth" +wget -O "${DOWNLOADED_CHECKPOINT_DIR}/${MODEL_NAME}" "${ZENODO_DIR}/${MODEL_NAME}?download=1" + +MODEL_NAME="resunet143_ismir2021_accompaniment_16.8dB_350k_steps.pth" +wget -O "${DOWNLOADED_CHECKPOINT_DIR}/${MODEL_NAME}" "${ZENODO_DIR}/${MODEL_NAME}?download=1" + +MODEL_NAME="resunet143_subbtandtime_vocals_8.8dB_350k_steps.pth" +wget -O "${DOWNLOADED_CHECKPOINT_DIR}/${MODEL_NAME}" "${ZENODO_DIR}/${MODEL_NAME}?download=1" + +MODEL_NAME="resunet143_subbtandtime_accompaniment_16.4dB_350k_steps.pth" +wget -O "${DOWNLOADED_CHECKPOINT_DIR}/${MODEL_NAME}" "${ZENODO_DIR}/${MODEL_NAME}?download=1" \ No newline at end of file diff --git a/separate_scripts/separate.py b/separate_scripts/separate.py new file mode 100644 index 0000000000000000000000000000000000000000..54561575ac70307e6804072a93d28539ad152c42 --- /dev/null +++ b/separate_scripts/separate.py @@ -0,0 +1,67 @@ +import argparse +import time + +import librosa +import soundfile + +from bytesep.inference import SeparatorWrapper + +sample_rate = 44100 # Must be 44100 when using the downloaded checkpoints. + + +def separate(args): + + audio_path = args.audio_path + source_type = args.source_type + device = "cuda" # "cuda" | "cpu" + + # Load audio. + audio, fs = librosa.load(audio_path, sr=sample_rate, mono=False) + + if audio.ndim == 1: + audio = audio[None, :] + # (2, segment_samples) + + # separator + separator = SeparatorWrapper( + source_type=source_type, + model=None, + checkpoint_path=None, + device=device, + ) + + t1 = time.time() + + # Separate. 
+ sep_wav = separator.separate(audio) + + sep_time = time.time() - t1 + + # Write out audio + sep_audio_path = 'sep_{}.wav'.format(source_type) + + soundfile.write(file=sep_audio_path, data=sep_wav.T, samplerate=sample_rate) + + print("Write out to {}".format(sep_audio_path)) + print("Time: {:.3f}".format(sep_time)) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument( + '--audio_path', + type=str, + default="resources/vocals_accompaniment_10s.mp3", + help="Audio path", + ) + parser.add_argument( + '--source_type', + type=str, + choices=['vocals', 'accompaniment'], + default="accompaniment", + help="Source type to be separated.", + ) + + args = parser.parse_args() + + separate(args) diff --git a/separate_scripts/separate_accompaniment.sh b/separate_scripts/separate_accompaniment.sh new file mode 100755 index 0000000000000000000000000000000000000000..1953cf152275e99e0f72e376228faa52d69520e4 --- /dev/null +++ b/separate_scripts/separate_accompaniment.sh @@ -0,0 +1,21 @@ +#!/bin/bash +AUDIO_PATH=${1:-"./resources/vocals_accompaniment_10s.mp3"} # The path of audio to be separated. +OUTPUT_PATH=${2:-"./sep_results/sep_accompaniment.mp3"} # The path to write out separated audio. + +MODEL_NAME="resunet_subbandtime" # "resunet_ismir2021" | "resunet_subbandtime" + +if [ $MODEL_NAME = "resunet_ismir2021" ]; then + CHECKPOINT_PATH="./downloaded_checkpoints/resunet143_ismir2021_accompaniment_16.8dB_350k_steps.pth" + TRAIN_CONFIG_YAML="./scripts/4_train/musdb18/configs/accompaniment-vocals,resunet_ismir2021.yaml" + +elif [ $MODEL_NAME = "resunet_subbandtime" ]; then + TRAIN_CONFIG_YAML="./scripts/4_train/musdb18/configs/accompaniment-vocals,resunet_subbandtime.yaml" + CHECKPOINT_PATH="./downloaded_checkpoints/resunet143_subbtandtime_accompaniment_16.4dB_350k_steps.pth" +fi + +# Inference +CUDA_VISIBLE_DEVICES=0 python3 bytesep/inference.py \ + --config_yaml=$TRAIN_CONFIG_YAML \ + --checkpoint_path=$CHECKPOINT_PATH \ + --audio_path=$AUDIO_PATH \ + --output_path=$OUTPUT_PATH diff --git a/separate_scripts/separate_vocals.sh b/separate_scripts/separate_vocals.sh new file mode 100755 index 0000000000000000000000000000000000000000..be445a415fabcfab04a3f5b73b27493e99d85227 --- /dev/null +++ b/separate_scripts/separate_vocals.sh @@ -0,0 +1,21 @@ +#!/bin/bash +AUDIO_PATH=${1:-"./resources/vocals_accompaniment_10s.mp3"} # The path of audio to be separated. +OUTPUT_PATH=${2:-"./sep_results/sep_vocals.mp3"} # The path to write out separated audio.
+ +MODEL_NAME="resunet_subbandtime" # "resunet_ismir2021" | "resunet_subbandtime" + +if [ $MODEL_NAME = "resunet_ismir2021" ]; then + TRAIN_CONFIG_YAML="./scripts/4_train/musdb18/configs/vocals-accompaniment,resunet_ismir2021.yaml" + CHECKPOINT_PATH="./downloaded_checkpoints/resunet143_ismir2021_vocals_8.9dB_350k_steps.pth" + +elif [ $MODEL_NAME = "resunet_subbandtime" ]; then + TRAIN_CONFIG_YAML="./scripts/4_train/musdb18/configs/vocals-accompaniment,resunet_subbandtime.yaml" + CHECKPOINT_PATH="./downloaded_checkpoints/resunet143_subbtandtime_vocals_8.8dB_350k_steps.pth" +fi + +# Inference +CUDA_VISIBLE_DEVICES=0 python3 bytesep/inference.py \ + --config_yaml=$TRAIN_CONFIG_YAML \ + --checkpoint_path=$CHECKPOINT_PATH \ + --audio_path=$AUDIO_PATH \ + --output_path=$OUTPUT_PATH diff --git a/setup.py b/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..637b7518000f8a6e2ef894d8e581b2f1ae9f7338 --- /dev/null +++ b/setup.py @@ -0,0 +1,25 @@ +from setuptools import setup + +setup( + name='bytesep', + version='0.0.1', + description='Music source separation', + author='ByteDance', + url="https://github.com/bytedance/music_source_separation", + license='Apache 2.0', + packages=['bytesep'], + include_package_data=True, + install_requires=[ + 'torch==1.7.1', + 'librosa==0.8.0', # specify the version! + 'museval==0.4.0', + 'h5py==2.10.0', + 'pytorch_lightning==1.2.1', + 'numpy==1.18.5', + 'torchlibrosa==0.0.9', + 'matplotlib==3.3.4', + 'musdb==0.4.0', + 'inplace-abn==1.1.0' + ], + zip_safe=False +)
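Taken together, the files added above suggest the following quick start, assuming the repository root as working directory and that the Zenodo checkpoints land where `SeparatorWrapper` expects them; apart from a standard editable install, the commands only use scripts that appear in this patch.

pip install -e .                                  # install the bytesep package defined in setup.py
bash separate_scripts/download_checkpoints.sh     # fetch the released checkpoints from Zenodo
python3 separate_scripts/separate.py \
    --audio_path="resources/vocals_accompaniment_10s.mp3" \
    --source_type="vocals"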