diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..c6c63e41b9efb1831406a3956f42a478f03121bd --- /dev/null +++ b/app.py @@ -0,0 +1,22 @@ +import gradio as gr +import os +DESCRIPTION = """ +# audio sep +being made +""" + +theme = gr.themes.Base( + font=[gr.themes.GoogleFont('Libre Franklin'), gr.themes.GoogleFont('Public Sans'), 'system-ui', 'sans-serif'], +) +with gr.Blocks(css="footer{display:none !important}", theme=theme) as demo: + gr.Markdown(DESCRIPTION) + gr.DuplicateButton( + value="Duplicate Space for private use", + elem_id="duplicate-button", + visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1", + ) + + + + +demo.queue(max_size=20, api_open=False).launch(show_api=False) diff --git a/inference.py b/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..721d1e4d221a0a627b2dc4c79397f85949b861da --- /dev/null +++ b/inference.py @@ -0,0 +1,113 @@ +# coding: utf-8 +__author__ = 'Roman Solovyev (ZFTurbo): https://github.com/ZFTurbo/' + +import argparse +import time +import librosa +from tqdm import tqdm +import sys +import os +import glob +import torch +import numpy as np +import soundfile as sf +import torch.nn as nn +from utils import demix_track, demix_track_demucs, get_model_from_config + +import warnings +warnings.filterwarnings("ignore") + + +def run_folder(model, args, config, device, verbose=False): + start_time = time.time() + model.eval() + all_mixtures_path = glob.glob(args.input_folder + '/*.*') + print('Total files found: {}'.format(len(all_mixtures_path))) + + instruments = config.training.instruments + if config.training.target_instrument is not None: + instruments = [config.training.target_instrument] + + if not os.path.isdir(args.store_dir): + os.mkdir(args.store_dir) + + if not verbose: + all_mixtures_path = tqdm(all_mixtures_path) + + for path in all_mixtures_path: + if not verbose: + all_mixtures_path.set_postfix({'track': os.path.basename(path)}) + try: + # mix, sr = sf.read(path) + mix, sr = librosa.load(path, sr=44100, mono=False) + mix = mix.T + except Exception as e: + print('Cannot read track: {}'.format(path)) + print('Error message: {}'.format(str(e))) + continue + + # Convert mono to stereo if needed + if len(mix.shape) == 1: + mix = np.stack([mix, mix], axis=-1) + + mixture = torch.tensor(mix.T, dtype=torch.float32) + if args.model_type == 'htdemucs': + res = demix_track_demucs(config, model, mixture, device) + else: + res = demix_track(config, model, mixture, device) + for instr in instruments: + sf.write("{}/{}_{}.wav".format(args.store_dir, os.path.basename(path)[:-4], instr), res[instr].T, sr, subtype='FLOAT') + + if 'vocals' in instruments and args.extract_instrumental: + instrum_file_name = "{}/{}_{}.wav".format(args.store_dir, os.path.basename(path)[:-4], 'instrumental') + sf.write(instrum_file_name, mix - res['vocals'].T, sr, subtype='FLOAT') + + time.sleep(1) + print("Elapsed time: {:.2f} sec".format(time.time() - start_time)) + + +def proc_folder(args): + parser = argparse.ArgumentParser() + parser.add_argument("--model_type", type=str, default='mdx23c', help="One of mdx23c, htdemucs, segm_models, mel_band_roformer, bs_roformer, swin_upernet, bandit") + parser.add_argument("--config_path", type=str, help="path to config file") + parser.add_argument("--start_check_point", type=str, default='', help="Initial checkpoint with valid weights") + parser.add_argument("--input_folder", type=str, help="folder with mixtures to process") + parser.add_argument("--store_dir", default="", 
type=str, help="path to store results as wav file") + parser.add_argument("--device_ids", nargs='+', type=int, default=0, help='list of gpu ids') + parser.add_argument("--extract_instrumental", action='store_true', help="invert vocals to get instrumental if provided") + if args is None: + args = parser.parse_args() + else: + args = parser.parse_args(args) + + torch.backends.cudnn.benchmark = True + + model, config = get_model_from_config(args.model_type, args.config_path) + if args.start_check_point != '': + print('Start from checkpoint: {}'.format(args.start_check_point)) + state_dict = torch.load(args.start_check_point) + if args.model_type == 'htdemucs': + # Fix for htdemucs pretrained models + if 'state' in state_dict: + state_dict = state_dict['state'] + model.load_state_dict(state_dict) + print("Instruments: {}".format(config.training.instruments)) + + if torch.cuda.is_available(): + device_ids = args.device_ids + if type(device_ids)==int: + device = torch.device(f'cuda:{device_ids}') + model = model.to(device) + else: + device = torch.device(f'cuda:{device_ids[0]}') + model = nn.DataParallel(model, device_ids=device_ids).to(device) + else: + device = 'cpu' + print('CUDA is not available. Run inference on CPU. It will be very slow...') + model = model.to(device) + + run_folder(model, args, config, device, verbose=False) + + +if __name__ == "__main__": + proc_folder(None) diff --git a/models/bandit/core/__init__.py b/models/bandit/core/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..695925757232c84598a2ec7208e73a17fae53b7e --- /dev/null +++ b/models/bandit/core/__init__.py @@ -0,0 +1,744 @@ +import os.path +from collections import defaultdict +from itertools import chain, combinations +from typing import ( + Any, + Dict, + Iterator, + Mapping, Optional, + Tuple, Type, + TypedDict +) + +import pytorch_lightning as pl +import torch +import torchaudio as ta +import torchmetrics as tm +from asteroid import losses as asteroid_losses +# from deepspeed.ops.adam import DeepSpeedCPUAdam +# from geoopt import optim as gooptim +from pytorch_lightning.utilities.types import STEP_OUTPUT +from torch import nn, optim +from torch.optim import lr_scheduler +from torch.optim.lr_scheduler import LRScheduler + +from models.bandit.core import loss, metrics as metrics_, model +from models.bandit.core.data._types import BatchedDataDict +from models.bandit.core.data.augmentation import BaseAugmentor, StemAugmentor +from models.bandit.core.utils import audio as audio_ +from models.bandit.core.utils.audio import BaseFader + +# from pandas.io.json._normalize import nested_to_record + +ConfigDict = TypedDict('ConfigDict', {'name': str, 'kwargs': Dict[str, Any]}) + + +class SchedulerConfigDict(ConfigDict): + monitor: str + + +OptimizerSchedulerConfigDict = TypedDict( + 'OptimizerSchedulerConfigDict', + {"optimizer": ConfigDict, "scheduler": SchedulerConfigDict}, + total=False +) + + +class LRSchedulerReturnDict(TypedDict, total=False): + scheduler: LRScheduler + monitor: str + + +class ConfigureOptimizerReturnDict(TypedDict, total=False): + optimizer: torch.optim.Optimizer + lr_scheduler: LRSchedulerReturnDict + + +OutputType = Dict[str, Any] +MetricsType = Dict[str, torch.Tensor] + + +def get_optimizer_class(name: str) -> Type[optim.Optimizer]: + + if name == "DeepSpeedCPUAdam": + return DeepSpeedCPUAdam + + for module in [optim, gooptim]: + if name in module.__dict__: + return module.__dict__[name] + + raise NameError + + +def parse_optimizer_config( + config: 
OptimizerSchedulerConfigDict, + parameters: Iterator[nn.Parameter] +) -> ConfigureOptimizerReturnDict: + optim_class = get_optimizer_class(config["optimizer"]["name"]) + optimizer = optim_class(parameters, **config["optimizer"]["kwargs"]) + + optim_dict: ConfigureOptimizerReturnDict = { + "optimizer": optimizer, + } + + if "scheduler" in config: + + lr_scheduler_class_ = config["scheduler"]["name"] + lr_scheduler_class = lr_scheduler.__dict__[lr_scheduler_class_] + lr_scheduler_dict: LRSchedulerReturnDict = { + "scheduler": lr_scheduler_class( + optimizer, + **config["scheduler"]["kwargs"] + ) + } + + if lr_scheduler_class_ == "ReduceLROnPlateau": + lr_scheduler_dict["monitor"] = config["scheduler"]["monitor"] + + optim_dict["lr_scheduler"] = lr_scheduler_dict + + return optim_dict + + +def parse_model_config(config: ConfigDict) -> Any: + name = config["name"] + + for module in [model]: + if name in module.__dict__: + return module.__dict__[name](**config["kwargs"]) + + raise NameError + + +_LEGACY_LOSS_NAMES = ["HybridL1Loss"] + + +def _parse_legacy_loss_config(config: ConfigDict) -> nn.Module: + name = config["name"] + + if name == "HybridL1Loss": + return loss.TimeFreqL1Loss(**config["kwargs"]) + + raise NameError + + +def parse_loss_config(config: ConfigDict) -> nn.Module: + name = config["name"] + + if name in _LEGACY_LOSS_NAMES: + return _parse_legacy_loss_config(config) + + for module in [loss, nn.modules.loss, asteroid_losses]: + if name in module.__dict__: + # print(config["kwargs"]) + return module.__dict__[name](**config["kwargs"]) + + raise NameError + + +def get_metric(config: ConfigDict) -> tm.Metric: + name = config["name"] + + for module in [tm, metrics_]: + if name in module.__dict__: + return module.__dict__[name](**config["kwargs"]) + raise NameError + + +def parse_metric_config(config: Dict[str, ConfigDict]) -> tm.MetricCollection: + metrics = {} + + for metric in config: + metrics[metric] = get_metric(config[metric]) + + return tm.MetricCollection(metrics) + + +def parse_fader_config(config: ConfigDict) -> BaseFader: + name = config["name"] + + for module in [audio_]: + if name in module.__dict__: + return module.__dict__[name](**config["kwargs"]) + + raise NameError + + +class LightningSystem(pl.LightningModule): + _VOX_STEMS = ["speech", "vocals"] + _BG_STEMS = ["background", "effects", "mne"] + + def __init__( + self, + config: Dict, + loss_adjustment: float = 1.0, + attach_fader: bool = False + ) -> None: + super().__init__() + self.optimizer_config = config["optimizer"] + self.model = parse_model_config(config["model"]) + self.loss = parse_loss_config(config["loss"]) + self.metrics = nn.ModuleDict( + { + stem: parse_metric_config(config["metrics"]["dev"]) + for stem in self.model.stems + } + ) + + self.metrics.disallow_fsdp = True + + self.test_metrics = nn.ModuleDict( + { + stem: parse_metric_config(config["metrics"]["test"]) + for stem in self.model.stems + } + ) + + self.test_metrics.disallow_fsdp = True + + self.fs = config["model"]["kwargs"]["fs"] + + self.fader_config = config["inference"]["fader"] + if attach_fader: + self.fader = parse_fader_config(config["inference"]["fader"]) + else: + self.fader = None + + self.augmentation: Optional[BaseAugmentor] + if config.get("augmentation", None) is not None: + self.augmentation = StemAugmentor(**config["augmentation"]) + else: + self.augmentation = None + + self.predict_output_path: Optional[str] = None + self.loss_adjustment = loss_adjustment + + self.val_prefix = None + self.test_prefix = None + + + def 
configure_optimizers(self) -> Any: + return parse_optimizer_config( + self.optimizer_config, + self.trainer.model.parameters() + ) + + def compute_loss(self, batch: BatchedDataDict, output: OutputType) -> Dict[ + str, torch.Tensor]: + return {"loss": self.loss(output, batch)} + + def update_metrics( + self, + batch: BatchedDataDict, + output: OutputType, + mode: str + ) -> None: + + if mode == "test": + metrics = self.test_metrics + else: + metrics = self.metrics + + for stem, metric in metrics.items(): + + if stem == "mne:+": + stem = "mne" + + # print(f"matching for {stem}") + if mode == "train": + metric.update( + output["audio"][stem],#.cpu(), + batch["audio"][stem],#.cpu() + ) + else: + if stem not in batch["audio"]: + matched = False + if stem in self._VOX_STEMS: + for bstem in self._VOX_STEMS: + if bstem in batch["audio"]: + batch["audio"][stem] = batch["audio"][bstem] + matched = True + break + elif stem in self._BG_STEMS: + for bstem in self._BG_STEMS: + if bstem in batch["audio"]: + batch["audio"][stem] = batch["audio"][bstem] + matched = True + break + else: + matched = True + + # print(batch["audio"].keys()) + + if matched: + # print(f"matched {stem}!") + if stem == "mne" and "mne" not in output["audio"]: + output["audio"]["mne"] = output["audio"]["music"] + output["audio"]["effects"] + + metric.update( + output["audio"][stem],#.cpu(), + batch["audio"][stem],#.cpu(), + ) + + # print(metric.compute()) + def compute_metrics(self, mode: str="dev") -> Dict[ + str, torch.Tensor]: + + if mode == "test": + metrics = self.test_metrics + else: + metrics = self.metrics + + metric_dict = {} + + for stem, metric in metrics.items(): + md = metric.compute() + metric_dict.update( + {f"{stem}/{k}": v for k, v in md.items()} + ) + + self.log_dict(metric_dict, prog_bar=True, logger=False) + + return metric_dict + + def reset_metrics(self, test_mode: bool = False) -> None: + + if test_mode: + metrics = self.test_metrics + else: + metrics = self.metrics + + for _, metric in metrics.items(): + metric.reset() + + + def forward(self, batch: BatchedDataDict) -> Any: + batch, output = self.model(batch) + + + return batch, output + + def common_step(self, batch: BatchedDataDict, mode: str) -> Any: + batch, output = self.forward(batch) + # print(batch) + # print(output) + loss_dict = self.compute_loss(batch, output) + + with torch.no_grad(): + self.update_metrics(batch, output, mode=mode) + + if mode == "train": + self.log("loss", loss_dict["loss"], prog_bar=True) + + return output, loss_dict + + + def training_step(self, batch: BatchedDataDict) -> Dict[str, Any]: + + if self.augmentation is not None: + with torch.no_grad(): + batch = self.augmentation(batch) + + _, loss_dict = self.common_step(batch, mode="train") + + with torch.inference_mode(): + self.log_dict_with_prefix( + loss_dict, + "train", + batch_size=batch["audio"]["mixture"].shape[0] + ) + + loss_dict["loss"] *= self.loss_adjustment + + return loss_dict + + def on_train_batch_end( + self, outputs: STEP_OUTPUT, batch: BatchedDataDict, batch_idx: int + ) -> None: + + metric_dict = self.compute_metrics() + self.log_dict_with_prefix(metric_dict, "train") + self.reset_metrics() + + def validation_step( + self, + batch: BatchedDataDict, + batch_idx: int, + dataloader_idx: int = 0 + ) -> Dict[str, Any]: + + with torch.inference_mode(): + curr_val_prefix = f"val{dataloader_idx}" if dataloader_idx > 0 else "val" + + if curr_val_prefix != self.val_prefix: + # print(f"Switching to validation dataloader {dataloader_idx}") + if self.val_prefix is not 
None: + self._on_validation_epoch_end() + self.val_prefix = curr_val_prefix + _, loss_dict = self.common_step(batch, mode="val") + + self.log_dict_with_prefix( + loss_dict, + self.val_prefix, + batch_size=batch["audio"]["mixture"].shape[0], + prog_bar=True, + add_dataloader_idx=False + ) + + return loss_dict + + def on_validation_epoch_end(self) -> None: + self._on_validation_epoch_end() + + def _on_validation_epoch_end(self) -> None: + metric_dict = self.compute_metrics() + self.log_dict_with_prefix(metric_dict, self.val_prefix, prog_bar=True, + add_dataloader_idx=False) + # self.logger.save() + # print(self.val_prefix, "Validation metrics:", metric_dict) + self.reset_metrics() + + + def old_predtest_step( + self, + batch: BatchedDataDict, + batch_idx: int, + dataloader_idx: int = 0 + ) -> Tuple[BatchedDataDict, OutputType]: + + audio_batch = batch["audio"]["mixture"] + track_batch = batch.get("track", ["" for _ in range(len(audio_batch))]) + + output_list_of_dicts = [ + self.fader( + audio[None, ...], + lambda a: self.test_forward(a, track) + ) + for audio, track in zip(audio_batch, track_batch) + ] + + output_dict_of_lists = defaultdict(list) + + for output_dict in output_list_of_dicts: + for stem, audio in output_dict.items(): + output_dict_of_lists[stem].append(audio) + + output = { + "audio": { + stem: torch.concat(output_list, dim=0) + for stem, output_list in output_dict_of_lists.items() + } + } + + return batch, output + + def predtest_step( + self, + batch: BatchedDataDict, + batch_idx: int = -1, + dataloader_idx: int = 0 + ) -> Tuple[BatchedDataDict, OutputType]: + + if getattr(self.model, "bypass_fader", False): + batch, output = self.model(batch) + else: + audio_batch = batch["audio"]["mixture"] + output = self.fader( + audio_batch, + lambda a: self.test_forward(a, "", batch=batch) + ) + + return batch, output + + def test_forward( + self, + audio: torch.Tensor, + track: str = "", + batch: BatchedDataDict = None + ) -> torch.Tensor: + + if self.fader is None: + self.attach_fader() + + cond = batch.get("condition", None) + + if cond is not None and cond.shape[0] == 1: + cond = cond.repeat(audio.shape[0], 1) + + _, output = self.forward( + {"audio": {"mixture": audio}, + "track": track, + "condition": cond, + } + ) # TODO: support track properly + + return output["audio"] + + def on_test_epoch_start(self) -> None: + self.attach_fader(force_reattach=True) + + def test_step( + self, + batch: BatchedDataDict, + batch_idx: int, + dataloader_idx: int = 0 + ) -> Any: + curr_test_prefix = f"test{dataloader_idx}" + + # print(batch["audio"].keys()) + + if curr_test_prefix != self.test_prefix: + # print(f"Switching to test dataloader {dataloader_idx}") + if self.test_prefix is not None: + self._on_test_epoch_end() + self.test_prefix = curr_test_prefix + + with torch.inference_mode(): + _, output = self.predtest_step(batch, batch_idx, dataloader_idx) + # print(output) + self.update_metrics(batch, output, mode="test") + + return output + + def on_test_epoch_end(self) -> None: + self._on_test_epoch_end() + + def _on_test_epoch_end(self) -> None: + metric_dict = self.compute_metrics(mode="test") + self.log_dict_with_prefix(metric_dict, self.test_prefix, prog_bar=True, + add_dataloader_idx=False) + # self.logger.save() + # print(self.test_prefix, "Test metrics:", metric_dict) + self.reset_metrics() + + def predict_step( + self, + batch: BatchedDataDict, + batch_idx: int = 0, + dataloader_idx: int = 0, + include_track_name: Optional[bool] = None, + get_no_vox_combinations: bool = True, + 
get_residual: bool = False, + treat_batch_as_channels: bool = False, + fs: Optional[int] = None, + ) -> Any: + assert self.predict_output_path is not None + + batch_size = batch["audio"]["mixture"].shape[0] + + if include_track_name is None: + include_track_name = batch_size > 1 + + with torch.inference_mode(): + batch, output = self.predtest_step(batch, batch_idx, dataloader_idx) + print('Pred test finished...') + torch.cuda.empty_cache() + metric_dict = {} + + if get_residual: + mixture = batch["audio"]["mixture"] + extracted = sum([output["audio"][stem] for stem in output["audio"]]) + residual = mixture - extracted + print(extracted.shape, mixture.shape, residual.shape) + + output["audio"]["residual"] = residual + + if get_no_vox_combinations: + no_vox_stems = [ + stem for stem in output["audio"] if + stem not in self._VOX_STEMS + ] + no_vox_combinations = chain.from_iterable( + combinations(no_vox_stems, r) for r in + range(2, len(no_vox_stems) + 1) + ) + + for combination in no_vox_combinations: + combination_ = list(combination) + output["audio"]["+".join(combination_)] = sum( + [output["audio"][stem] for stem in combination_] + ) + + if treat_batch_as_channels: + for stem in output["audio"]: + output["audio"][stem] = output["audio"][stem].reshape( + 1, -1, output["audio"][stem].shape[-1] + ) + batch_size = 1 + + for b in range(batch_size): + print("!!", b) + for stem in output["audio"]: + print(f"Saving audio for {stem} to {self.predict_output_path}") + track_name = batch["track"][b].split("/")[-1] + + if batch.get("audio", {}).get(stem, None) is not None: + self.test_metrics[stem].reset() + metrics = self.test_metrics[stem]( + batch["audio"][stem][[b], ...], + output["audio"][stem][[b], ...] + ) + snr = metrics["snr"] + sisnr = metrics["sisnr"] + sdr = metrics["sdr"] + metric_dict[stem] = metrics + print( + track_name, + f"snr={snr:2.2f} dB", + f"sisnr={sisnr:2.2f}", + f"sdr={sdr:2.2f} dB", + ) + filename = f"{stem} - snr={snr:2.2f}dB - sdr={sdr:2.2f}dB.wav" + else: + filename = f"{stem}.wav" + + if include_track_name: + output_dir = os.path.join( + self.predict_output_path, + track_name + ) + else: + output_dir = self.predict_output_path + + os.makedirs(output_dir, exist_ok=True) + + if fs is None: + fs = self.fs + + ta.save( + os.path.join(output_dir, filename), + output["audio"][stem][b, ...].cpu(), + fs, + ) + + return metric_dict + + def get_stems( + self, + batch: BatchedDataDict, + batch_idx: int = 0, + dataloader_idx: int = 0, + include_track_name: Optional[bool] = None, + get_no_vox_combinations: bool = True, + get_residual: bool = False, + treat_batch_as_channels: bool = False, + fs: Optional[int] = None, + ) -> Any: + assert self.predict_output_path is not None + + batch_size = batch["audio"]["mixture"].shape[0] + + if include_track_name is None: + include_track_name = batch_size > 1 + + with torch.inference_mode(): + batch, output = self.predtest_step(batch, batch_idx, dataloader_idx) + torch.cuda.empty_cache() + metric_dict = {} + + if get_residual: + mixture = batch["audio"]["mixture"] + extracted = sum([output["audio"][stem] for stem in output["audio"]]) + residual = mixture - extracted + # print(extracted.shape, mixture.shape, residual.shape) + + output["audio"]["residual"] = residual + + if get_no_vox_combinations: + no_vox_stems = [ + stem for stem in output["audio"] if + stem not in self._VOX_STEMS + ] + no_vox_combinations = chain.from_iterable( + combinations(no_vox_stems, r) for r in + range(2, len(no_vox_stems) + 1) + ) + + for combination in 
no_vox_combinations: + combination_ = list(combination) + output["audio"]["+".join(combination_)] = sum( + [output["audio"][stem] for stem in combination_] + ) + + if treat_batch_as_channels: + for stem in output["audio"]: + output["audio"][stem] = output["audio"][stem].reshape( + 1, -1, output["audio"][stem].shape[-1] + ) + batch_size = 1 + + result = {} + for b in range(batch_size): + for stem in output["audio"]: + track_name = batch["track"][b].split("/")[-1] + + if batch.get("audio", {}).get(stem, None) is not None: + self.test_metrics[stem].reset() + metrics = self.test_metrics[stem]( + batch["audio"][stem][[b], ...], + output["audio"][stem][[b], ...] + ) + snr = metrics["snr"] + sisnr = metrics["sisnr"] + sdr = metrics["sdr"] + metric_dict[stem] = metrics + print( + track_name, + f"snr={snr:2.2f} dB", + f"sisnr={sisnr:2.2f}", + f"sdr={sdr:2.2f} dB", + ) + filename = f"{stem} - snr={snr:2.2f}dB - sdr={sdr:2.2f}dB.wav" + else: + filename = f"{stem}.wav" + + if include_track_name: + output_dir = os.path.join( + self.predict_output_path, + track_name + ) + else: + output_dir = self.predict_output_path + + os.makedirs(output_dir, exist_ok=True) + + if fs is None: + fs = self.fs + + result[stem] = output["audio"][stem][b, ...].cpu().numpy() + + return result + + def load_state_dict( + self, state_dict: Mapping[str, Any], strict: bool = False + ) -> Any: + + return super().load_state_dict(state_dict, strict=False) + + + def set_predict_output_path(self, path: str) -> None: + self.predict_output_path = path + os.makedirs(self.predict_output_path, exist_ok=True) + + self.attach_fader() + + def attach_fader(self, force_reattach=False) -> None: + if self.fader is None or force_reattach: + self.fader = parse_fader_config(self.fader_config) + self.fader.to(self.device) + + + def log_dict_with_prefix( + self, + dict_: Dict[str, torch.Tensor], + prefix: str, + batch_size: Optional[int] = None, + **kwargs: Any + ) -> None: + self.log_dict( + {f"{prefix}/{k}": v for k, v in dict_.items()}, + batch_size=batch_size, + logger=True, + sync_dist=True, + **kwargs, + ) \ No newline at end of file diff --git a/models/bandit/core/data/__init__.py b/models/bandit/core/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..10e9cf50980d80985eaacf96a2ed87c4e84d603f --- /dev/null +++ b/models/bandit/core/data/__init__.py @@ -0,0 +1,2 @@ +from .dnr.datamodule import DivideAndRemasterDataModule +from .musdb.datamodule import MUSDB18DataModule \ No newline at end of file diff --git a/models/bandit/core/data/_types.py b/models/bandit/core/data/_types.py new file mode 100644 index 0000000000000000000000000000000000000000..dd45012700b98c85563826b09d5f6ccc5ddd214e --- /dev/null +++ b/models/bandit/core/data/_types.py @@ -0,0 +1,18 @@ +from typing import Dict, Sequence, TypedDict + +import torch + +AudioDict = Dict[str, torch.Tensor] + +DataDict = TypedDict('DataDict', {'audio': AudioDict, 'track': str}) + +BatchedDataDict = TypedDict( + 'BatchedDataDict', + {'audio': AudioDict, 'track': Sequence[str]} +) + + +class DataDictWithLanguage(TypedDict): + audio: AudioDict + track: str + language: str diff --git a/models/bandit/core/data/augmentation.py b/models/bandit/core/data/augmentation.py new file mode 100644 index 0000000000000000000000000000000000000000..1737c91a512104f87e564b2ed56fccc74e632b1c --- /dev/null +++ b/models/bandit/core/data/augmentation.py @@ -0,0 +1,107 @@ +from abc import ABC +from typing import Any, Dict, Union + +import torch +import torch_audiomentations as tam +from 
torch import nn + +from models.bandit.core.data._types import BatchedDataDict, DataDict + + +class BaseAugmentor(nn.Module, ABC): + def forward(self, item: Union[DataDict, BatchedDataDict]) -> Union[ + DataDict, BatchedDataDict]: + raise NotImplementedError + + +class StemAugmentor(BaseAugmentor): + def __init__( + self, + audiomentations: Dict[str, Dict[str, Any]], + fix_clipping: bool = True, + scaler_margin: float = 0.5, + apply_both_default_and_common: bool = False, + ) -> None: + super().__init__() + + augmentations = {} + + self.has_default = "[default]" in audiomentations + self.has_common = "[common]" in audiomentations + self.apply_both_default_and_common = apply_both_default_and_common + + for stem in audiomentations: + if audiomentations[stem]["name"] == "Compose": + augmentations[stem] = getattr( + tam, + audiomentations[stem]["name"] + )( + [ + getattr(tam, aug["name"])(**aug["kwargs"]) + for aug in + audiomentations[stem]["kwargs"]["transforms"] + ], + **audiomentations[stem]["kwargs"]["kwargs"], + ) + else: + augmentations[stem] = getattr( + tam, + audiomentations[stem]["name"] + )( + **audiomentations[stem]["kwargs"] + ) + + self.augmentations = nn.ModuleDict(augmentations) + self.fix_clipping = fix_clipping + self.scaler_margin = scaler_margin + + def check_and_fix_clipping( + self, item: Union[DataDict, BatchedDataDict] + ) -> Union[DataDict, BatchedDataDict]: + max_abs = [] + + for stem in item["audio"]: + max_abs.append(item["audio"][stem].abs().max().item()) + + if max(max_abs) > 1.0: + scaler = 1.0 / (max(max_abs) + torch.rand( + (1,), + device=item["audio"]["mixture"].device + ) * self.scaler_margin) + + for stem in item["audio"]: + item["audio"][stem] *= scaler + + return item + + def forward(self, item: Union[DataDict, BatchedDataDict]) -> Union[ + DataDict, BatchedDataDict]: + + for stem in item["audio"]: + if stem == "mixture": + continue + + if self.has_common: + item["audio"][stem] = self.augmentations["[common]"]( + item["audio"][stem] + ).samples + + if stem in self.augmentations: + item["audio"][stem] = self.augmentations[stem]( + item["audio"][stem] + ).samples + elif self.has_default: + if not self.has_common or self.apply_both_default_and_common: + item["audio"][stem] = self.augmentations["[default]"]( + item["audio"][stem] + ).samples + + item["audio"]["mixture"] = sum( + [item["audio"][stem] for stem in item["audio"] + if stem != "mixture"] + ) # type: ignore[call-overload, assignment] + + if self.fix_clipping: + item = self.check_and_fix_clipping(item) + + return item diff --git a/models/bandit/core/data/augmented.py b/models/bandit/core/data/augmented.py new file mode 100644 index 0000000000000000000000000000000000000000..4a0807d4366eefc131fd762cae54a5322f36877a --- /dev/null +++ b/models/bandit/core/data/augmented.py @@ -0,0 +1,35 @@ +import warnings +from typing import Dict, Optional, Union + +import torch +from torch import nn +from torch.utils import data + + +class AugmentedDataset(data.Dataset): + def __init__( + self, + dataset: data.Dataset, + augmentation: nn.Module = nn.Identity(), + target_length: Optional[int] = None, + ) -> None: + warnings.warn( + "This class is no longer used. 
Attach augmentation to " + "the LightningSystem instead.", + DeprecationWarning, + ) + + self.dataset = dataset + self.augmentation = augmentation + + self.ds_length: int = len(dataset) # type: ignore[arg-type] + self.length = target_length if target_length is not None else self.ds_length + + def __getitem__(self, index: int) -> Dict[str, Union[str, Dict[str, + torch.Tensor]]]: + item = self.dataset[index % self.ds_length] + item = self.augmentation(item) + return item + + def __len__(self) -> int: + return self.length diff --git a/models/bandit/core/data/base.py b/models/bandit/core/data/base.py new file mode 100644 index 0000000000000000000000000000000000000000..9e77e8f8675c786c9b9bdcda66ad29744daa8755 --- /dev/null +++ b/models/bandit/core/data/base.py @@ -0,0 +1,69 @@ +import os +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Optional + +import numpy as np +import pedalboard as pb +import torch +import torchaudio as ta +from torch.utils import data + +from models.bandit.core.data._types import AudioDict, DataDict + + +class BaseSourceSeparationDataset(data.Dataset, ABC): + def __init__( + self, split: str, + stems: List[str], + files: List[str], + data_path: str, + fs: int, + npy_memmap: bool, + recompute_mixture: bool + ): + self.split = split + self.stems = stems + self.stems_no_mixture = [s for s in stems if s != "mixture"] + self.files = files + self.data_path = data_path + self.fs = fs + self.npy_memmap = npy_memmap + self.recompute_mixture = recompute_mixture + + @abstractmethod + def get_stem( + self, + *, + stem: str, + identifier: Dict[str, Any] + ) -> torch.Tensor: + raise NotImplementedError + + def _get_audio(self, stems, identifier: Dict[str, Any]): + audio = {} + for stem in stems: + audio[stem] = self.get_stem(stem=stem, identifier=identifier) + + return audio + + def get_audio(self, identifier: Dict[str, Any]) -> AudioDict: + + if self.recompute_mixture: + audio = self._get_audio( + self.stems_no_mixture, + identifier=identifier + ) + audio["mixture"] = self.compute_mixture(audio) + return audio + else: + return self._get_audio(self.stems, identifier=identifier) + + @abstractmethod + def get_identifier(self, index: int) -> Dict[str, Any]: + pass + + def compute_mixture(self, audio: AudioDict) -> torch.Tensor: + + return sum( + audio[stem] for stem in audio if stem != "mixture" + ) diff --git a/models/bandit/core/data/dnr/__init__.py b/models/bandit/core/data/dnr/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/bandit/core/data/dnr/datamodule.py b/models/bandit/core/data/dnr/datamodule.py new file mode 100644 index 0000000000000000000000000000000000000000..5bdd77a56c35c79fb473cd34b53b4a4b0a01259f --- /dev/null +++ b/models/bandit/core/data/dnr/datamodule.py @@ -0,0 +1,74 @@ +import os +from typing import Mapping, Optional + +import pytorch_lightning as pl + +from .dataset import ( + DivideAndRemasterDataset, + DivideAndRemasterDeterministicChunkDataset, + DivideAndRemasterRandomChunkDataset, + DivideAndRemasterRandomChunkDatasetWithSpeechReverb +) + + +def DivideAndRemasterDataModule( + data_root: str = "$DATA_ROOT/DnR/v2", + batch_size: int = 2, + num_workers: int = 8, + train_kwargs: Optional[Mapping] = None, + val_kwargs: Optional[Mapping] = None, + test_kwargs: Optional[Mapping] = None, + datamodule_kwargs: Optional[Mapping] = None, + use_speech_reverb: bool = False + # augmentor=None +) -> pl.LightningDataModule: + if train_kwargs is None: + 
train_kwargs = {} + + if val_kwargs is None: + val_kwargs = {} + + if test_kwargs is None: + test_kwargs = {} + + if datamodule_kwargs is None: + datamodule_kwargs = {} + + if num_workers is None: + num_workers = os.cpu_count() + + if num_workers is None: + num_workers = 32 + + num_workers = min(num_workers, 64) + + if use_speech_reverb: + train_cls = DivideAndRemasterRandomChunkDatasetWithSpeechReverb + else: + train_cls = DivideAndRemasterRandomChunkDataset + + train_dataset = train_cls( + data_root, "train", **train_kwargs + ) + + # if augmentor is not None: + # train_dataset = AugmentedDataset(train_dataset, augmentor) + + datamodule = pl.LightningDataModule.from_datasets( + train_dataset=train_dataset, + val_dataset=DivideAndRemasterDeterministicChunkDataset( + data_root, "val", **val_kwargs + ), + test_dataset=DivideAndRemasterDataset( + data_root, + "test", + **test_kwargs + ), + batch_size=batch_size, + num_workers=num_workers, + **datamodule_kwargs + ) + + datamodule.predict_dataloader = datamodule.test_dataloader # type: ignore[method-assign] + + return datamodule diff --git a/models/bandit/core/data/dnr/dataset.py b/models/bandit/core/data/dnr/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..ed076934e8e1ea3c14ffc44d2150a9285459ca6f --- /dev/null +++ b/models/bandit/core/data/dnr/dataset.py @@ -0,0 +1,392 @@ +import os +from abc import ABC +from typing import Any, Dict, List, Optional + +import numpy as np +import pedalboard as pb +import torch +import torchaudio as ta +from torch.utils import data + +from models.bandit.core.data._types import AudioDict, DataDict +from models.bandit.core.data.base import BaseSourceSeparationDataset + + +class DivideAndRemasterBaseDataset(BaseSourceSeparationDataset, ABC): + ALLOWED_STEMS = ["mixture", "speech", "music", "effects", "mne"] + STEM_NAME_MAP = { + "mixture": "mix", + "speech": "speech", + "music": "music", + "effects": "sfx", + } + SPLIT_NAME_MAP = {"train": "tr", "val": "cv", "test": "tt"} + + FULL_TRACK_LENGTH_SECOND = 60 + FULL_TRACK_LENGTH_SAMPLES = FULL_TRACK_LENGTH_SECOND * 44100 + + def __init__( + self, + split: str, + stems: List[str], + files: List[str], + data_path: str, + fs: int = 44100, + npy_memmap: bool = True, + recompute_mixture: bool = False, + ) -> None: + super().__init__( + split=split, + stems=stems, + files=files, + data_path=data_path, + fs=fs, + npy_memmap=npy_memmap, + recompute_mixture=recompute_mixture + ) + + def get_stem( + self, + *, + stem: str, + identifier: Dict[str, Any] + ) -> torch.Tensor: + + if stem == "mne": + return self.get_stem( + stem="music", + identifier=identifier) + self.get_stem( + stem="effects", + identifier=identifier) + + track = identifier["track"] + path = os.path.join(self.data_path, track) + + if self.npy_memmap: + audio = np.load( + os.path.join(path, f"{self.STEM_NAME_MAP[stem]}.npy"), + mmap_mode="r" + ) + else: + # noinspection PyUnresolvedReferences + audio, _ = ta.load( + os.path.join(path, f"{self.STEM_NAME_MAP[stem]}.wav") + ) + + return audio + + def get_identifier(self, index): + return dict(track=self.files[index]) + + def __getitem__(self, index: int) -> DataDict: + identifier = self.get_identifier(index) + audio = self.get_audio(identifier) + + return {"audio": audio, "track": f"{self.split}/{identifier['track']}"} + + +class DivideAndRemasterDataset(DivideAndRemasterBaseDataset): + def __init__( + self, + data_root: str, + split: str, + stems: Optional[List[str]] = None, + fs: int = 44100, + npy_memmap: bool = True, + ) -> None: + + 
if stems is None: + stems = self.ALLOWED_STEMS + self.stems = stems + + data_path = os.path.join(data_root, self.SPLIT_NAME_MAP[split]) + + files = sorted(os.listdir(data_path)) + files = [ + f + for f in files + if (not f.startswith(".")) and os.path.isdir( + os.path.join(data_path, f) + ) + ] + # pprint(list(enumerate(files))) + if split == "train": + assert len(files) == 3406, len(files) + elif split == "val": + assert len(files) == 487, len(files) + elif split == "test": + assert len(files) == 973, len(files) + + self.n_tracks = len(files) + + super().__init__( + data_path=data_path, + split=split, + stems=stems, + files=files, + fs=fs, + npy_memmap=npy_memmap, + ) + + def __len__(self) -> int: + return self.n_tracks + + +class DivideAndRemasterRandomChunkDataset(DivideAndRemasterBaseDataset): + def __init__( + self, + data_root: str, + split: str, + target_length: int, + chunk_size_second: float, + stems: Optional[List[str]] = None, + fs: int = 44100, + npy_memmap: bool = True, + ) -> None: + + if stems is None: + stems = self.ALLOWED_STEMS + self.stems = stems + + data_path = os.path.join(data_root, self.SPLIT_NAME_MAP[split]) + + files = sorted(os.listdir(data_path)) + files = [ + f + for f in files + if (not f.startswith(".")) and os.path.isdir( + os.path.join(data_path, f) + ) + ] + + if split == "train": + assert len(files) == 3406, len(files) + elif split == "val": + assert len(files) == 487, len(files) + elif split == "test": + assert len(files) == 973, len(files) + + self.n_tracks = len(files) + + self.target_length = target_length + self.chunk_size = int(chunk_size_second * fs) + + super().__init__( + data_path=data_path, + split=split, + stems=stems, + files=files, + fs=fs, + npy_memmap=npy_memmap, + ) + + def __len__(self) -> int: + return self.target_length + + def get_identifier(self, index): + return super().get_identifier(index % self.n_tracks) + + def get_stem( + self, + *, + stem: str, + identifier: Dict[str, Any], + chunk_here: bool = False, + ) -> torch.Tensor: + + stem = super().get_stem( + stem=stem, + identifier=identifier + ) + + if chunk_here: + start = np.random.randint( + 0, + self.FULL_TRACK_LENGTH_SAMPLES - self.chunk_size + ) + end = start + self.chunk_size + + stem = stem[:, start:end] + + return stem + + def __getitem__(self, index: int) -> DataDict: + identifier = self.get_identifier(index) + # self.index_lock = index + audio = self.get_audio(identifier) + # self.index_lock = None + + start = np.random.randint( + 0, + self.FULL_TRACK_LENGTH_SAMPLES - self.chunk_size + ) + end = start + self.chunk_size + + audio = { + k: v[:, start:end] for k, v in audio.items() + } + + return {"audio": audio, "track": f"{self.split}/{identifier['track']}"} + + +class DivideAndRemasterDeterministicChunkDataset(DivideAndRemasterBaseDataset): + def __init__( + self, + data_root: str, + split: str, + chunk_size_second: float, + hop_size_second: float, + stems: Optional[List[str]] = None, + fs: int = 44100, + npy_memmap: bool = True, + ) -> None: + + if stems is None: + stems = self.ALLOWED_STEMS + self.stems = stems + + data_path = os.path.join(data_root, self.SPLIT_NAME_MAP[split]) + + files = sorted(os.listdir(data_path)) + files = [ + f + for f in files + if (not f.startswith(".")) and os.path.isdir( + os.path.join(data_path, f) + ) + ] + # pprint(list(enumerate(files))) + if split == "train": + assert len(files) == 3406, len(files) + elif split == "val": + assert len(files) == 487, len(files) + elif split == "test": + assert len(files) == 973, len(files) + + 
self.n_tracks = len(files) + + self.chunk_size = int(chunk_size_second * fs) + self.hop_size = int(hop_size_second * fs) + self.n_chunks_per_track = int( + ( + self.FULL_TRACK_LENGTH_SECOND - chunk_size_second) / hop_size_second + ) + + self.length = self.n_tracks * self.n_chunks_per_track + + super().__init__( + data_path=data_path, + split=split, + stems=stems, + files=files, + fs=fs, + npy_memmap=npy_memmap, + ) + + def get_identifier(self, index): + return super().get_identifier(index % self.n_tracks) + + def __len__(self) -> int: + return self.length + + def __getitem__(self, item: int) -> DataDict: + + index = item % self.n_tracks + chunk = item // self.n_tracks + + data_ = super().__getitem__(index) + + audio = data_["audio"] + + start = chunk * self.hop_size + end = start + self.chunk_size + + for stem in self.stems: + data_["audio"][stem] = audio[stem][:, start:end] + + return data_ + + +class DivideAndRemasterRandomChunkDatasetWithSpeechReverb( + DivideAndRemasterRandomChunkDataset +): + def __init__( + self, + data_root: str, + split: str, + target_length: int, + chunk_size_second: float, + stems: Optional[List[str]] = None, + fs: int = 44100, + npy_memmap: bool = True, + ) -> None: + + if stems is None: + stems = self.ALLOWED_STEMS + + stems_no_mixture = [s for s in stems if s != "mixture"] + + super().__init__( + data_root=data_root, + split=split, + target_length=target_length, + chunk_size_second=chunk_size_second, + stems=stems_no_mixture, + fs=fs, + npy_memmap=npy_memmap, + ) + + self.stems = stems + self.stems_no_mixture = stems_no_mixture + + def __getitem__(self, index: int) -> DataDict: + + data_ = super().__getitem__(index) + + dry = data_["audio"]["speech"][:] + n_samples = dry.shape[-1] + + wet_level = np.random.rand() + + speech = pb.Reverb( + room_size=np.random.rand(), + damping=np.random.rand(), + wet_level=wet_level, + dry_level=(1 - wet_level), + width=np.random.rand() + ).process(dry, self.fs, buffer_size=8192 * 4)[..., :n_samples] + + data_["audio"]["speech"] = speech + + data_["audio"]["mixture"] = sum( + [data_["audio"][s] for s in self.stems_no_mixture] + ) + + return data_ + + def __len__(self) -> int: + return super().__len__() + + +if __name__ == "__main__": + + from pprint import pprint + from tqdm import tqdm + + for split_ in ["train", "val", "test"]: + ds = DivideAndRemasterRandomChunkDatasetWithSpeechReverb( + data_root="$DATA_ROOT/DnR/v2np", + split=split_, + target_length=100, + chunk_size_second=6.0 + ) + + print(split_, len(ds)) + + for track_ in tqdm(ds): # type: ignore + pprint(track_) + track_["audio"] = {k: v.shape for k, v in track_["audio"].items()} + pprint(track_) + # break + + break diff --git a/models/bandit/core/data/dnr/preprocess.py b/models/bandit/core/data/dnr/preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..53cbdd8b4702c3ec7a81f4847485380b0ff3fddf --- /dev/null +++ b/models/bandit/core/data/dnr/preprocess.py @@ -0,0 +1,54 @@ +import glob +import os +from typing import Tuple + +import numpy as np +import torchaudio as ta +from tqdm.contrib.concurrent import process_map + + +def process_one(inputs: Tuple[str, str, int]) -> None: + infile, outfile, target_fs = inputs + + dir = os.path.dirname(outfile) + os.makedirs(dir, exist_ok=True) + + data, fs = ta.load(infile) + + if fs != target_fs: + data = ta.functional.resample(data, fs, target_fs, resampling_method="sinc_interp_kaiser") + fs = target_fs + + data = data.numpy() + data = data.astype(np.float32) + + if os.path.exists(outfile): + data_ = 
np.load(outfile) + if np.allclose(data, data_): + return + + np.save(outfile, data) + + +def preprocess( + data_path: str, + output_path: str, + fs: int +) -> None: + files = glob.glob(os.path.join(data_path, "**", "*.wav"), recursive=True) + print(files) + outfiles = [ + f.replace(data_path, output_path).replace(".wav", ".npy") for f in + files + ] + + os.makedirs(output_path, exist_ok=True) + inputs = list(zip(files, outfiles, [fs] * len(files))) + + process_map(process_one, inputs, chunksize=32) + + +if __name__ == "__main__": + import fire + + fire.Fire() diff --git a/models/bandit/core/data/musdb/__init__.py b/models/bandit/core/data/musdb/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/bandit/core/data/musdb/datamodule.py b/models/bandit/core/data/musdb/datamodule.py new file mode 100644 index 0000000000000000000000000000000000000000..4b8cda32faa3d3f676893ad8c75bb4719bd4fc31 --- /dev/null +++ b/models/bandit/core/data/musdb/datamodule.py @@ -0,0 +1,77 @@ +import os.path +from typing import Mapping, Optional + +import pytorch_lightning as pl + +from models.bandit.core.data.musdb.dataset import ( + MUSDB18BaseDataset, + MUSDB18FullTrackDataset, + MUSDB18SadDataset, + MUSDB18SadOnTheFlyAugmentedDataset +) + + +def MUSDB18DataModule( + data_root: str = "$DATA_ROOT/MUSDB18/HQ", + target_stem: str = "vocals", + batch_size: int = 2, + num_workers: int = 8, + train_kwargs: Optional[Mapping] = None, + val_kwargs: Optional[Mapping] = None, + test_kwargs: Optional[Mapping] = None, + datamodule_kwargs: Optional[Mapping] = None, + use_on_the_fly: bool = True, + npy_memmap: bool = True +) -> pl.LightningDataModule: + if train_kwargs is None: + train_kwargs = {} + + if val_kwargs is None: + val_kwargs = {} + + if test_kwargs is None: + test_kwargs = {} + + if datamodule_kwargs is None: + datamodule_kwargs = {} + + train_dataset: MUSDB18BaseDataset + + if use_on_the_fly: + train_dataset = MUSDB18SadOnTheFlyAugmentedDataset( + data_root=os.path.join(data_root, "saded-np"), + split="train", + target_stem=target_stem, + **train_kwargs + ) + else: + train_dataset = MUSDB18SadDataset( + data_root=os.path.join(data_root, "saded-np"), + split="train", + target_stem=target_stem, + **train_kwargs + ) + + datamodule = pl.LightningDataModule.from_datasets( + train_dataset=train_dataset, + val_dataset=MUSDB18SadDataset( + data_root=os.path.join(data_root, "saded-np"), + split="val", + target_stem=target_stem, + **val_kwargs + ), + test_dataset=MUSDB18FullTrackDataset( + data_root=os.path.join(data_root, "canonical"), + split="test", + **test_kwargs + ), + batch_size=batch_size, + num_workers=num_workers, + **datamodule_kwargs + ) + + datamodule.predict_dataloader = ( # type: ignore[method-assign] + datamodule.test_dataloader + ) + + return datamodule diff --git a/models/bandit/core/data/musdb/dataset.py b/models/bandit/core/data/musdb/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..4374c96995f88dd946801957f95c4c64eeda88f9 --- /dev/null +++ b/models/bandit/core/data/musdb/dataset.py @@ -0,0 +1,280 @@ +import os +from abc import ABC +from typing import List, Optional, Tuple + +import numpy as np +import torch +import torchaudio as ta +from torch.utils import data + +from models.bandit.core.data._types import AudioDict, DataDict +from models.bandit.core.data.base import BaseSourceSeparationDataset + + +class MUSDB18BaseDataset(BaseSourceSeparationDataset, ABC): + + ALLOWED_STEMS = 
["mixture", "vocals", "bass", "drums", "other"] + + def __init__( + self, + split: str, + stems: List[str], + files: List[str], + data_path: str, + fs: int = 44100, + npy_memmap=False, + ) -> None: + super().__init__( + split=split, + stems=stems, + files=files, + data_path=data_path, + fs=fs, + npy_memmap=npy_memmap, + recompute_mixture=False + ) + + def get_stem(self, *, stem: str, identifier) -> torch.Tensor: + track = identifier["track"] + path = os.path.join(self.data_path, track) + # noinspection PyUnresolvedReferences + + if self.npy_memmap: + audio = np.load(os.path.join(path, f"{stem}.wav.npy"), mmap_mode="r") + else: + audio, _ = ta.load(os.path.join(path, f"{stem}.wav")) + + return audio + + def get_identifier(self, index): + return dict(track=self.files[index]) + + def __getitem__(self, index: int) -> DataDict: + identifier = self.get_identifier(index) + audio = self.get_audio(identifier) + + return {"audio": audio, "track": f"{self.split}/{identifier['track']}"} + + +class MUSDB18FullTrackDataset(MUSDB18BaseDataset): + + N_TRAIN_TRACKS = 100 + N_TEST_TRACKS = 50 + VALIDATION_FILES = [ + "Actions - One Minute Smile", + "Clara Berry And Wooldog - Waltz For My Victims", + "Johnny Lokke - Promises & Lies", + "Patrick Talbot - A Reason To Leave", + "Triviul - Angelsaint", + "Alexander Ross - Goodbye Bolero", + "Fergessen - Nos Palpitants", + "Leaf - Summerghost", + "Skelpolu - Human Mistakes", + "Young Griffo - Pennies", + "ANiMAL - Rockshow", + "James May - On The Line", + "Meaxic - Take A Step", + "Traffic Experiment - Sirens", + ] + + def __init__( + self, data_root: str, split: str, stems: Optional[List[ + str]] = None + ) -> None: + + if stems is None: + stems = self.ALLOWED_STEMS + self.stems = stems + + if split == "test": + subset = "test" + elif split in ["train", "val"]: + subset = "train" + else: + raise NameError + + data_path = os.path.join(data_root, subset) + + files = sorted(os.listdir(data_path)) + files = [f for f in files if not f.startswith(".")] + # pprint(list(enumerate(files))) + if subset == "train": + assert len(files) == 100, len(files) + if split == "train": + files = [f for f in files if f not in self.VALIDATION_FILES] + assert len(files) == 100 - len(self.VALIDATION_FILES) + else: + files = [f for f in files if f in self.VALIDATION_FILES] + assert len(files) == len(self.VALIDATION_FILES) + else: + split = "test" + assert len(files) == 50 + + self.n_tracks = len(files) + + super().__init__( + data_path=data_path, + split=split, + stems=stems, + files=files + ) + + def __len__(self) -> int: + return self.n_tracks + +class MUSDB18SadDataset(MUSDB18BaseDataset): + def __init__( + self, + data_root: str, + split: str, + target_stem: str, + stems: Optional[List[str]] = None, + target_length: Optional[int] = None, + npy_memmap=False, + ) -> None: + + if stems is None: + stems = self.ALLOWED_STEMS + + data_path = os.path.join(data_root, target_stem, split) + + files = sorted(os.listdir(data_path)) + files = [f for f in files if not f.startswith(".")] + + super().__init__( + data_path=data_path, + split=split, + stems=stems, + files=files, + npy_memmap=npy_memmap + ) + self.n_segments = len(files) + self.target_stem = target_stem + self.target_length = ( + target_length if target_length is not None else self.n_segments + ) + + def __len__(self) -> int: + return self.target_length + + def __getitem__(self, index: int) -> DataDict: + + index = index % self.n_segments + + return super().__getitem__(index) + + def get_identifier(self, index): + return 
super().get_identifier(index % self.n_segments) + + +class MUSDB18SadOnTheFlyAugmentedDataset(MUSDB18SadDataset): + def __init__( + self, + data_root: str, + split: str, + target_stem: str, + stems: Optional[List[str]] = None, + target_length: int = 20000, + apply_probability: Optional[float] = None, + chunk_size_second: float = 3.0, + random_scale_range_db: Tuple[float, float] = (-10, 10), + drop_probability: float = 0.1, + rescale: bool = True, + ) -> None: + super().__init__(data_root, split, target_stem, stems) + + if apply_probability is None: + apply_probability = ( + target_length - self.n_segments) / target_length + + self.apply_probability = apply_probability + self.drop_probability = drop_probability + self.chunk_size_second = chunk_size_second + self.random_scale_range_db = random_scale_range_db + self.rescale = rescale + + self.chunk_size_sample = int(self.chunk_size_second * self.fs) + self.target_length = target_length + + def __len__(self) -> int: + return self.target_length + + def __getitem__(self, index: int) -> DataDict: + + index = index % self.n_segments + + # if np.random.rand() > self.apply_probability: + # return super().__getitem__(index) + + audio = {} + identifier = self.get_identifier(index) + + # assert self.target_stem in self.stems_no_mixture + for stem in self.stems_no_mixture: + if stem == self.target_stem: + identifier_ = identifier + else: + if np.random.rand() < self.apply_probability: + index_ = np.random.randint(self.n_segments) + identifier_ = self.get_identifier(index_) + else: + identifier_ = identifier + + audio[stem] = self.get_stem(stem=stem, identifier=identifier_) + + # if stem == self.target_stem: + + if self.chunk_size_sample < audio[stem].shape[-1]: + chunk_start = np.random.randint( + audio[stem].shape[-1] - self.chunk_size_sample + ) + else: + chunk_start = 0 + + if np.random.rand() < self.drop_probability: + # db_scale = "-inf" + linear_scale = 0.0 + else: + db_scale = np.random.uniform(*self.random_scale_range_db) + linear_scale = np.power(10, db_scale / 20) + # db_scale = f"{db_scale:+2.1f}" + # print(linear_scale) + audio[stem][..., + chunk_start: chunk_start + self.chunk_size_sample] = ( + linear_scale + * audio[stem][..., + chunk_start: chunk_start + self.chunk_size_sample] + ) + + audio["mixture"] = self.compute_mixture(audio) + + if self.rescale: + max_abs_val = max( + [torch.max(torch.abs(audio[stem])) for stem in self.stems] + ) # type: ignore[type-var] + if max_abs_val > 1: + audio = {k: v / max_abs_val for k, v in audio.items()} + + track = identifier["track"] + + return {"audio": audio, "track": f"{self.split}/{track}"} + +# if __name__ == "__main__": +# +# from pprint import pprint +# from tqdm import tqdm +# +# for split_ in ["train", "val", "test"]: +# ds = MUSDB18SadOnTheFlyAugmentedDataset( +# data_root="$DATA_ROOT/MUSDB18/HQ/saded", +# split=split_, +# target_stem="vocals" +# ) +# +# print(split_, len(ds)) +# +# for track_ in tqdm(ds): +# track_["audio"] = { +# k: v.shape for k, v in track_["audio"].items() +# } +# pprint(track_) diff --git a/models/bandit/core/data/musdb/preprocess.py b/models/bandit/core/data/musdb/preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..6b2f2a692ab5d1247be59af98199f92afc75710a --- /dev/null +++ b/models/bandit/core/data/musdb/preprocess.py @@ -0,0 +1,238 @@ +import glob +import os + +import numpy as np +import torch +import torchaudio as ta +from torch import nn +from torch.nn import functional as F +from tqdm.contrib.concurrent import process_map + +from 
core.data._types import DataDict +from core.data.musdb.dataset import MUSDB18FullTrackDataset +import pyloudnorm as pyln + +class SourceActivityDetector(nn.Module): + def __init__( + self, + analysis_stem: str, + output_path: str, + fs: int = 44100, + segment_length_second: float = 6.0, + hop_length_second: float = 3.0, + n_chunks: int = 10, + chunk_epsilon: float = 1e-5, + energy_threshold_quantile: float = 0.15, + segment_epsilon: float = 1e-3, + salient_proportion_threshold: float = 0.5, + target_lufs: float = -24 + ) -> None: + super().__init__() + + self.fs = fs + self.segment_length = int(segment_length_second * self.fs) + self.hop_length = int(hop_length_second * self.fs) + self.n_chunks = n_chunks + assert self.segment_length % self.n_chunks == 0 + self.chunk_size = self.segment_length // self.n_chunks + self.chunk_epsilon = chunk_epsilon + self.energy_threshold_quantile = energy_threshold_quantile + self.segment_epsilon = segment_epsilon + self.salient_proportion_threshold = salient_proportion_threshold + self.analysis_stem = analysis_stem + + self.meter = pyln.Meter(self.fs) + self.target_lufs = target_lufs + + self.output_path = output_path + + def forward(self, data: DataDict) -> None: + + stem_ = self.analysis_stem if ( + self.analysis_stem != "none") else "mixture" + + x = data["audio"][stem_] + + xnp = x.numpy() + loudness = self.meter.integrated_loudness(xnp.T) + + for stem in data["audio"]: + s = data["audio"][stem] + s = pyln.normalize.loudness(s.numpy().T, loudness, self.target_lufs).T + s = torch.as_tensor(s) + data["audio"][stem] = s + + if x.ndim == 3: + assert x.shape[0] == 1 + x = x[0] + + n_chan, n_samples = x.shape + + n_segments = ( + int( + np.ceil((n_samples - self.segment_length) / self.hop_length) + ) + 1 + ) + + segments = torch.zeros((n_segments, n_chan, self.segment_length)) + for i in range(n_segments): + start = i * self.hop_length + end = start + self.segment_length + end = min(end, n_samples) + + xseg = x[:, start:end] + + if end - start < self.segment_length: + xseg = F.pad( + xseg, + pad=(0, self.segment_length - (end - start)), + value=torch.nan + ) + + segments[i, :, :] = xseg + + chunks = segments.reshape( + (n_segments, n_chan, self.n_chunks, self.chunk_size) + ) + + if self.analysis_stem != "none": + chunk_energies = torch.mean(torch.square(chunks), dim=(1, 3)) + chunk_energies = torch.nan_to_num(chunk_energies, nan=0) + chunk_energies[chunk_energies == 0] = self.chunk_epsilon + + energy_threshold = torch.nanquantile( + chunk_energies, q=self.energy_threshold_quantile + ) + + if energy_threshold < self.segment_epsilon: + energy_threshold = self.segment_epsilon # type: ignore[assignment] + + chunks_above_threshold = chunk_energies > energy_threshold + n_chunks_above_threshold = torch.mean( + chunks_above_threshold.to(torch.float), dim=-1 + ) + + segment_above_threshold = ( + n_chunks_above_threshold > self.salient_proportion_threshold + ) + + if torch.sum(segment_above_threshold) == 0: + return + + else: + segment_above_threshold = torch.ones((n_segments,)) + + for i in range(n_segments): + if not segment_above_threshold[i]: + continue + + outpath = os.path.join( + self.output_path, + self.analysis_stem, + f"{data['track']} - {self.analysis_stem}{i:03d}", + ) + os.makedirs(outpath, exist_ok=True) + + for stem in data["audio"]: + if stem == self.analysis_stem: + segment = torch.nan_to_num(segments[i, :, :], nan=0) + else: + start = i * self.hop_length + end = start + self.segment_length + end = min(n_samples, end) + + segment = 
data["audio"][stem][:, start:end] + + if end - start < self.segment_length: + segment = F.pad( + segment, + (0, self.segment_length - (end - start)) + ) + + assert segment.shape[-1] == self.segment_length, segment.shape + + # ta.save(os.path.join(outpath, f"{stem}.wav"), segment, self.fs) + + np.save(os.path.join(outpath, f"{stem}.wav"), segment) + + +def preprocess( + analysis_stem: str, + output_path: str = "/data/MUSDB18/HQ/saded-np", + fs: int = 44100, + segment_length_second: float = 6.0, + hop_length_second: float = 3.0, + n_chunks: int = 10, + chunk_epsilon: float = 1e-5, + energy_threshold_quantile: float = 0.15, + segment_epsilon: float = 1e-3, + salient_proportion_threshold: float = 0.5, +) -> None: + + sad = SourceActivityDetector( + analysis_stem=analysis_stem, + output_path=output_path, + fs=fs, + segment_length_second=segment_length_second, + hop_length_second=hop_length_second, + n_chunks=n_chunks, + chunk_epsilon=chunk_epsilon, + energy_threshold_quantile=energy_threshold_quantile, + segment_epsilon=segment_epsilon, + salient_proportion_threshold=salient_proportion_threshold, + ) + + for split in ["train", "val", "test"]: + ds = MUSDB18FullTrackDataset( + data_root="/data/MUSDB18/HQ/canonical", + split=split, + ) + + tracks = [] + for i, track in enumerate(tqdm(ds, total=len(ds))): + if i % 32 == 0 and tracks: + process_map(sad, tracks, max_workers=8) + tracks = [] + tracks.append(track) + process_map(sad, tracks, max_workers=8) + +def loudness_norm_one( + inputs +): + infile, outfile, target_lufs = inputs + + audio, fs = ta.load(infile) + audio = audio.mean(dim=0, keepdim=True).numpy().T + + meter = pyln.Meter(fs) + loudness = meter.integrated_loudness(audio) + audio = pyln.normalize.loudness(audio, loudness, target_lufs) + + os.makedirs(os.path.dirname(outfile), exist_ok=True) + np.save(outfile, audio.T) + +def loudness_norm( + data_path: str, + # output_path: str, + target_lufs = -17.0, +): + files = glob.glob( + os.path.join(data_path, "**", "*.wav"), recursive=True + ) + + outfiles = [ + f.replace(".wav", ".npy").replace("saded", "saded-np") for f in files + ] + + files = [(f, o, target_lufs) for f, o in zip(files, outfiles)] + + process_map(loudness_norm_one, files, chunksize=2) + + + +if __name__ == "__main__": + + from tqdm import tqdm + import fire + + fire.Fire() diff --git a/models/bandit/core/data/musdb/validation.yaml b/models/bandit/core/data/musdb/validation.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6bee0bc6f95ca8f15173fe29e7597ce1f4bed3c5 --- /dev/null +++ b/models/bandit/core/data/musdb/validation.yaml @@ -0,0 +1,15 @@ +validation: + - 'Actions - One Minute Smile' + - 'Clara Berry And Wooldog - Waltz For My Victims' + - 'Johnny Lokke - Promises & Lies' + - 'Patrick Talbot - A Reason To Leave' + - 'Triviul - Angelsaint' + - 'Alexander Ross - Goodbye Bolero' + - 'Fergessen - Nos Palpitants' + - 'Leaf - Summerghost' + - 'Skelpolu - Human Mistakes' + - 'Young Griffo - Pennies' + - 'ANiMAL - Rockshow' + - 'James May - On The Line' + - 'Meaxic - Take A Step' + - 'Traffic Experiment - Sirens' \ No newline at end of file diff --git a/models/bandit/core/loss/__init__.py b/models/bandit/core/loss/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5f6f377effa2450aa481d33776376314e48a0ae7 --- /dev/null +++ b/models/bandit/core/loss/__init__.py @@ -0,0 +1,2 @@ +from ._multistem import MultiStemWrapperFromConfig +from ._timefreq import ReImL1Loss, ReImL2Loss, TimeFreqL1Loss, TimeFreqL2Loss, 
TimeFreqSignalNoisePNormRatioLoss diff --git a/models/bandit/core/loss/_complex.py b/models/bandit/core/loss/_complex.py new file mode 100644 index 0000000000000000000000000000000000000000..d90e9ba729bc2cb0c27bdb224fcbafc3bb4475cc --- /dev/null +++ b/models/bandit/core/loss/_complex.py @@ -0,0 +1,34 @@ +from typing import Any + +import torch +from torch import nn +from torch.nn.modules import loss as _loss +from torch.nn.modules.loss import _Loss + + +class ReImLossWrapper(_Loss): + def __init__(self, module: _Loss) -> None: + super().__init__() + self.module = module + + def forward( + self, + preds: torch.Tensor, + target: torch.Tensor + ) -> torch.Tensor: + return self.module( + torch.view_as_real(preds), + torch.view_as_real(target) + ) + + +class ReImL1Loss(ReImLossWrapper): + def __init__(self, **kwargs: Any) -> None: + l1_loss = _loss.L1Loss(**kwargs) + super().__init__(module=(l1_loss)) + + +class ReImL2Loss(ReImLossWrapper): + def __init__(self, **kwargs: Any) -> None: + l2_loss = _loss.MSELoss(**kwargs) + super().__init__(module=(l2_loss)) diff --git a/models/bandit/core/loss/_multistem.py b/models/bandit/core/loss/_multistem.py new file mode 100644 index 0000000000000000000000000000000000000000..539f9b8145713abd1d574b312b3277c1e7bd3491 --- /dev/null +++ b/models/bandit/core/loss/_multistem.py @@ -0,0 +1,45 @@ +from typing import Any, Dict + +import torch +from asteroid import losses as asteroid_losses +from torch import nn +from torch.nn.modules.loss import _Loss + +from . import snr + + +def parse_loss(name: str, kwargs: Dict[str, Any]) -> _Loss: + + for module in [nn.modules.loss, snr, asteroid_losses, asteroid_losses.sdr]: + if name in module.__dict__: + return module.__dict__[name](**kwargs) + + raise NameError + + +class MultiStemWrapper(_Loss): + def __init__(self, module: _Loss, modality: str = "audio") -> None: + super().__init__() + self.loss = module + self.modality = modality + + def forward( + self, + preds: Dict[str, Dict[str, torch.Tensor]], + target: Dict[str, Dict[str, torch.Tensor]], + ) -> torch.Tensor: + loss = { + stem: self.loss( + preds[self.modality][stem], + target[self.modality][stem] + ) + for stem in preds[self.modality] if stem in target[self.modality] + } + + return sum(list(loss.values())) + + +class MultiStemWrapperFromConfig(MultiStemWrapper): + def __init__(self, name: str, kwargs: Any, modality: str = "audio") -> None: + loss = parse_loss(name, kwargs) + super().__init__(module=loss, modality=modality) diff --git a/models/bandit/core/loss/_timefreq.py b/models/bandit/core/loss/_timefreq.py new file mode 100644 index 0000000000000000000000000000000000000000..11b4c409beb1e0485828f9a9ea2d89589fd5af8c --- /dev/null +++ b/models/bandit/core/loss/_timefreq.py @@ -0,0 +1,113 @@ +from typing import Any, Dict, Optional + +import torch +from torch import nn +from torch.nn.modules.loss import _Loss + +from models.bandit.core.loss._multistem import MultiStemWrapper +from models.bandit.core.loss._complex import ReImL1Loss, ReImL2Loss, ReImLossWrapper +from models.bandit.core.loss.snr import SignalNoisePNormRatio + +class TimeFreqWrapper(_Loss): + def __init__( + self, + time_module: _Loss, + freq_module: Optional[_Loss] = None, + time_weight: float = 1.0, + freq_weight: float = 1.0, + multistem: bool = True, + ) -> None: + super().__init__() + + if freq_module is None: + freq_module = time_module + + if multistem: + time_module = MultiStemWrapper(time_module, modality="audio") + freq_module = MultiStemWrapper(freq_module, modality="spectrogram") + + 
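+        # Note (added for clarity): with multistem=True, both losses consume
+        # nested dicts keyed first by modality and then by stem name; the
+        # time-domain loss reads the "audio" modality and the frequency-domain
+        # loss reads the "spectrogram" modality, and each MultiStemWrapper
+        # sums the per-stem loss values before the time/freq weighting below.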
self.time_module = time_module + self.freq_module = freq_module + + self.time_weight = time_weight + self.freq_weight = freq_weight + + # TODO: add better type hints + def forward(self, preds: Any, target: Any) -> torch.Tensor: + + return self.time_weight * self.time_module( + preds, target + ) + self.freq_weight * self.freq_module(preds, target) + + +class TimeFreqL1Loss(TimeFreqWrapper): + def __init__( + self, + time_weight: float = 1.0, + freq_weight: float = 1.0, + tkwargs: Optional[Dict[str, Any]] = None, + fkwargs: Optional[Dict[str, Any]] = None, + multistem: bool = True, + ) -> None: + if tkwargs is None: + tkwargs = {} + if fkwargs is None: + fkwargs = {} + time_module = (nn.L1Loss(**tkwargs)) + freq_module = ReImL1Loss(**fkwargs) + super().__init__( + time_module, + freq_module, + time_weight, + freq_weight, + multistem + ) + + +class TimeFreqL2Loss(TimeFreqWrapper): + def __init__( + self, + time_weight: float = 1.0, + freq_weight: float = 1.0, + tkwargs: Optional[Dict[str, Any]] = None, + fkwargs: Optional[Dict[str, Any]] = None, + multistem: bool = True, + ) -> None: + if tkwargs is None: + tkwargs = {} + if fkwargs is None: + fkwargs = {} + time_module = nn.MSELoss(**tkwargs) + freq_module = ReImL2Loss(**fkwargs) + super().__init__( + time_module, + freq_module, + time_weight, + freq_weight, + multistem + ) + + + +class TimeFreqSignalNoisePNormRatioLoss(TimeFreqWrapper): + def __init__( + self, + time_weight: float = 1.0, + freq_weight: float = 1.0, + tkwargs: Optional[Dict[str, Any]] = None, + fkwargs: Optional[Dict[str, Any]] = None, + multistem: bool = True, + ) -> None: + if tkwargs is None: + tkwargs = {} + if fkwargs is None: + fkwargs = {} + time_module = SignalNoisePNormRatio(**tkwargs) + freq_module = SignalNoisePNormRatio(**fkwargs) + super().__init__( + time_module, + freq_module, + time_weight, + freq_weight, + multistem + ) diff --git a/models/bandit/core/loss/snr.py b/models/bandit/core/loss/snr.py new file mode 100644 index 0000000000000000000000000000000000000000..d83deb143dceabc1040bdb409044188b46e2dec3 --- /dev/null +++ b/models/bandit/core/loss/snr.py @@ -0,0 +1,146 @@ +import torch +from torch.nn.modules.loss import _Loss +from torch.nn import functional as F + +class SignalNoisePNormRatio(_Loss): + def __init__( + self, + p: float = 1.0, + scale_invariant: bool = False, + zero_mean: bool = False, + take_log: bool = True, + reduction: str = "mean", + EPS: float = 1e-3, + ) -> None: + assert reduction != "sum", NotImplementedError + super().__init__(reduction=reduction) + assert not zero_mean + + self.p = p + + self.EPS = EPS + self.take_log = take_log + + self.scale_invariant = scale_invariant + + def forward( + self, + est_target: torch.Tensor, + target: torch.Tensor + ) -> torch.Tensor: + + target_ = target + if self.scale_invariant: + ndim = target.ndim + dot = torch.sum(est_target * torch.conj(target), dim=-1, keepdim=True) + s_target_energy = ( + torch.sum(target * torch.conj(target), dim=-1, keepdim=True) + ) + + if ndim > 2: + dot = torch.sum(dot, dim=list(range(1, ndim)), keepdim=True) + s_target_energy = torch.sum(s_target_energy, dim=list(range(1, ndim)), keepdim=True) + + target_scaler = (dot + 1e-8) / (s_target_energy + 1e-8) + target = target_ * target_scaler + + if torch.is_complex(est_target): + est_target = torch.view_as_real(est_target) + target = torch.view_as_real(target) + + + batch_size = est_target.shape[0] + est_target = est_target.reshape(batch_size, -1) + target = target.reshape(batch_size, -1) + # target_ = 
target_.reshape(batch_size, -1) + + if self.p == 1: + e_error = torch.abs(est_target-target).mean(dim=-1) + e_target = torch.abs(target).mean(dim=-1) + elif self.p == 2: + e_error = torch.square(est_target-target).mean(dim=-1) + e_target = torch.square(target).mean(dim=-1) + else: + raise NotImplementedError + + if self.take_log: + loss = 10*(torch.log10(e_error + self.EPS) - torch.log10(e_target + self.EPS)) + else: + loss = (e_error + self.EPS)/(e_target + self.EPS) + + if self.reduction == "mean": + loss = loss.mean() + elif self.reduction == "sum": + loss = loss.sum() + + return loss + + + +class MultichannelSingleSrcNegSDR(_Loss): + def __init__( + self, + sdr_type: str, + p: float = 2.0, + zero_mean: bool = True, + take_log: bool = True, + reduction: str = "mean", + EPS: float = 1e-8, + ) -> None: + assert reduction != "sum", NotImplementedError + super().__init__(reduction=reduction) + + assert sdr_type in ["snr", "sisdr", "sdsdr"] + self.sdr_type = sdr_type + self.zero_mean = zero_mean + self.take_log = take_log + self.EPS = 1e-8 + + self.p = p + + def forward( + self, + est_target: torch.Tensor, + target: torch.Tensor + ) -> torch.Tensor: + if target.size() != est_target.size() or target.ndim != 3: + raise TypeError( + f"Inputs must be of shape [batch, time], got {target.size()} and {est_target.size()} instead" + ) + # Step 1. Zero-mean norm + if self.zero_mean: + mean_source = torch.mean(target, dim=[1, 2], keepdim=True) + mean_estimate = torch.mean(est_target, dim=[1, 2], keepdim=True) + target = target - mean_source + est_target = est_target - mean_estimate + # Step 2. Pair-wise SI-SDR. + if self.sdr_type in ["sisdr", "sdsdr"]: + # [batch, 1] + dot = torch.sum(est_target * target, dim=[1, 2], keepdim=True) + # [batch, 1] + s_target_energy = ( + torch.sum(target ** 2, dim=[1, 2], keepdim=True) + self.EPS + ) + # [batch, time] + scaled_target = dot * target / s_target_energy + else: + # [batch, time] + scaled_target = target + if self.sdr_type in ["sdsdr", "snr"]: + e_noise = est_target - target + else: + e_noise = est_target - scaled_target + # [batch] + + if self.p == 2.0: + losses = torch.sum(scaled_target ** 2, dim=[1, 2]) / ( + torch.sum(e_noise ** 2, dim=[1, 2]) + self.EPS + ) + else: + losses = torch.norm(scaled_target, p=self.p, dim=[1, 2]) / ( + torch.linalg.vector_norm(e_noise, p=self.p, dim=[1, 2]) + self.EPS + ) + if self.take_log: + losses = 10 * torch.log10(losses + self.EPS) + losses = losses.mean() if self.reduction == "mean" else losses + return -losses diff --git a/models/bandit/core/metrics/__init__.py b/models/bandit/core/metrics/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9784d2523077a9b31b1d329225b7b0125e5e92ed --- /dev/null +++ b/models/bandit/core/metrics/__init__.py @@ -0,0 +1,9 @@ +from .snr import ( + ChunkMedianScaleInvariantSignalDistortionRatio, + ChunkMedianScaleInvariantSignalNoiseRatio, + ChunkMedianSignalDistortionRatio, + ChunkMedianSignalNoiseRatio, + SafeSignalDistortionRatio, +) + +# from .mushra import EstimatedMushraScore diff --git a/models/bandit/core/metrics/_squim.py b/models/bandit/core/metrics/_squim.py new file mode 100644 index 0000000000000000000000000000000000000000..ccc32cf4861934c3993542ff8d0e94fbe8d51042 --- /dev/null +++ b/models/bandit/core/metrics/_squim.py @@ -0,0 +1,383 @@ +from dataclasses import dataclass + +from torchaudio._internal import load_state_dict_from_url + +import math +from typing import List, Optional, Tuple + +import torch +import torch.nn as nn +import 
torch.nn.functional as F + + +def transform_wb_pesq_range(x: float) -> float: + """The metric defined by ITU-T P.862 is often called 'PESQ score', which is defined + for narrow-band signals and has a value range of [-0.5, 4.5] exactly. Here, we use the metric + defined by ITU-T P.862.2, commonly known as 'wide-band PESQ' and will be referred to as "PESQ score". + + Args: + x (float): Narrow-band PESQ score. + + Returns: + (float): Wide-band PESQ score. + """ + return 0.999 + (4.999 - 0.999) / (1 + math.exp(-1.3669 * x + 3.8224)) + + +PESQRange: Tuple[float, float] = ( + 1.0, # P.862.2 uses a different input filter than P.862, and the lower bound of + # the raw score is not -0.5 anymore. It's hard to figure out the true lower bound. + # We are using 1.0 as a reasonable approximation. + transform_wb_pesq_range(4.5), +) + + +class RangeSigmoid(nn.Module): + def __init__(self, val_range: Tuple[float, float] = (0.0, 1.0)) -> None: + super(RangeSigmoid, self).__init__() + assert isinstance(val_range, tuple) and len(val_range) == 2 + self.val_range: Tuple[float, float] = val_range + self.sigmoid: nn.modules.Module = nn.Sigmoid() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + out = self.sigmoid(x) * (self.val_range[1] - self.val_range[0]) + self.val_range[0] + return out + + +class Encoder(nn.Module): + """Encoder module that transform 1D waveform to 2D representations. + + Args: + feat_dim (int, optional): The feature dimension after Encoder module. (Default: 512) + win_len (int, optional): kernel size in the Conv1D layer. (Default: 32) + """ + + def __init__(self, feat_dim: int = 512, win_len: int = 32) -> None: + super(Encoder, self).__init__() + + self.conv1d = nn.Conv1d(1, feat_dim, win_len, stride=win_len // 2, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply waveforms to convolutional layer and ReLU layer. + + Args: + x (torch.Tensor): Input waveforms. Tensor with dimensions `(batch, time)`. + + Returns: + (torch,Tensor): Feature Tensor with dimensions `(batch, channel, frame)`. + """ + out = x.unsqueeze(dim=1) + out = F.relu(self.conv1d(out)) + return out + + +class SingleRNN(nn.Module): + def __init__(self, rnn_type: str, input_size: int, hidden_size: int, dropout: float = 0.0) -> None: + super(SingleRNN, self).__init__() + + self.rnn_type = rnn_type + self.input_size = input_size + self.hidden_size = hidden_size + + self.rnn: nn.modules.Module = getattr(nn, rnn_type)( + input_size, + hidden_size, + 1, + dropout=dropout, + batch_first=True, + bidirectional=True, + ) + + self.proj = nn.Linear(hidden_size * 2, input_size) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # input shape: batch, seq, dim + out, _ = self.rnn(x) + out = self.proj(out) + return out + + +class DPRNN(nn.Module): + """*Dual-path recurrent neural networks (DPRNN)* :cite:`luo2020dual`. + + Args: + feat_dim (int, optional): The feature dimension after Encoder module. (Default: 64) + hidden_dim (int, optional): Hidden dimension in the RNN layer of DPRNN. (Default: 128) + num_blocks (int, optional): Number of DPRNN layers. (Default: 6) + rnn_type (str, optional): Type of RNN in DPRNN. Valid options are ["RNN", "LSTM", "GRU"]. (Default: "LSTM") + d_model (int, optional): The number of expected features in the input. (Default: 256) + chunk_size (int, optional): Chunk size of input for DPRNN. (Default: 100) + chunk_stride (int, optional): Stride of chunk input for DPRNN. 
(Default: 50) + """ + + def __init__( + self, + feat_dim: int = 64, + hidden_dim: int = 128, + num_blocks: int = 6, + rnn_type: str = "LSTM", + d_model: int = 256, + chunk_size: int = 100, + chunk_stride: int = 50, + ) -> None: + super(DPRNN, self).__init__() + + self.num_blocks = num_blocks + + self.row_rnn = nn.ModuleList([]) + self.col_rnn = nn.ModuleList([]) + self.row_norm = nn.ModuleList([]) + self.col_norm = nn.ModuleList([]) + for _ in range(num_blocks): + self.row_rnn.append(SingleRNN(rnn_type, feat_dim, hidden_dim)) + self.col_rnn.append(SingleRNN(rnn_type, feat_dim, hidden_dim)) + self.row_norm.append(nn.GroupNorm(1, feat_dim, eps=1e-8)) + self.col_norm.append(nn.GroupNorm(1, feat_dim, eps=1e-8)) + self.conv = nn.Sequential( + nn.Conv2d(feat_dim, d_model, 1), + nn.PReLU(), + ) + self.chunk_size = chunk_size + self.chunk_stride = chunk_stride + + def pad_chunk(self, x: torch.Tensor) -> Tuple[torch.Tensor, int]: + # input shape: (B, N, T) + seq_len = x.shape[-1] + + rest = self.chunk_size - (self.chunk_stride + seq_len % self.chunk_size) % self.chunk_size + out = F.pad(x, [self.chunk_stride, rest + self.chunk_stride]) + + return out, rest + + def chunking(self, x: torch.Tensor) -> Tuple[torch.Tensor, int]: + out, rest = self.pad_chunk(x) + batch_size, feat_dim, seq_len = out.shape + + segments1 = out[:, :, : -self.chunk_stride].contiguous().view(batch_size, feat_dim, -1, self.chunk_size) + segments2 = out[:, :, self.chunk_stride :].contiguous().view(batch_size, feat_dim, -1, self.chunk_size) + out = torch.cat([segments1, segments2], dim=3) + out = out.view(batch_size, feat_dim, -1, self.chunk_size).transpose(2, 3).contiguous() + + return out, rest + + def merging(self, x: torch.Tensor, rest: int) -> torch.Tensor: + batch_size, dim, _, _ = x.shape + out = x.transpose(2, 3).contiguous().view(batch_size, dim, -1, self.chunk_size * 2) + out1 = out[:, :, :, : self.chunk_size].contiguous().view(batch_size, dim, -1)[:, :, self.chunk_stride :] + out2 = out[:, :, :, self.chunk_size :].contiguous().view(batch_size, dim, -1)[:, :, : -self.chunk_stride] + out = out1 + out2 + if rest > 0: + out = out[:, :, :-rest] + out = out.contiguous() + return out + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x, rest = self.chunking(x) + batch_size, _, dim1, dim2 = x.shape + out = x + for row_rnn, row_norm, col_rnn, col_norm in zip(self.row_rnn, self.row_norm, self.col_rnn, self.col_norm): + row_in = out.permute(0, 3, 2, 1).contiguous().view(batch_size * dim2, dim1, -1).contiguous() + row_out = row_rnn(row_in) + row_out = row_out.view(batch_size, dim2, dim1, -1).permute(0, 3, 2, 1).contiguous() + row_out = row_norm(row_out) + out = out + row_out + + col_in = out.permute(0, 2, 3, 1).contiguous().view(batch_size * dim1, dim2, -1).contiguous() + col_out = col_rnn(col_in) + col_out = col_out.view(batch_size, dim1, dim2, -1).permute(0, 3, 1, 2).contiguous() + col_out = col_norm(col_out) + out = out + col_out + out = self.conv(out) + out = self.merging(out, rest) + out = out.transpose(1, 2).contiguous() + return out + + +class AutoPool(nn.Module): + def __init__(self, pool_dim: int = 1) -> None: + super(AutoPool, self).__init__() + self.pool_dim: int = pool_dim + self.softmax: nn.modules.Module = nn.Softmax(dim=pool_dim) + self.register_parameter("alpha", nn.Parameter(torch.ones(1))) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + weight = self.softmax(torch.mul(x, self.alpha)) + out = torch.sum(torch.mul(x, weight), dim=self.pool_dim) + return out + + +class SquimObjective(nn.Module): + 
"""Speech Quality and Intelligibility Measures (SQUIM) model that predicts **objective** metric scores + for speech enhancement (e.g., STOI, PESQ, and SI-SDR). + + Args: + encoder (torch.nn.Module): Encoder module to transform 1D waveform to 2D feature representation. + dprnn (torch.nn.Module): DPRNN module to model sequential feature. + branches (torch.nn.ModuleList): Transformer branches in which each branch estimate one objective metirc score. + """ + + def __init__( + self, + encoder: nn.Module, + dprnn: nn.Module, + branches: nn.ModuleList, + ): + super(SquimObjective, self).__init__() + self.encoder = encoder + self.dprnn = dprnn + self.branches = branches + + def forward(self, x: torch.Tensor) -> List[torch.Tensor]: + """ + Args: + x (torch.Tensor): Input waveforms. Tensor with dimensions `(batch, time)`. + + Returns: + List(torch.Tensor): List of score Tenosrs. Each Tensor is with dimension `(batch,)`. + """ + if x.ndim != 2: + raise ValueError(f"The input must be a 2D Tensor. Found dimension {x.ndim}.") + x = x / (torch.mean(x**2, dim=1, keepdim=True) ** 0.5 * 20) + out = self.encoder(x) + out = self.dprnn(out) + scores = [] + for branch in self.branches: + scores.append(branch(out).squeeze(dim=1)) + return scores + + +def _create_branch(d_model: int, nhead: int, metric: str) -> nn.modules.Module: + """Create branch module after DPRNN model for predicting metric score. + + Args: + d_model (int): The number of expected features in the input. + nhead (int): Number of heads in the multi-head attention model. + metric (str): The metric name to predict. + + Returns: + (nn.Module): Returned module to predict corresponding metric score. + """ + layer1 = nn.TransformerEncoderLayer(d_model, nhead, d_model * 4, dropout=0.0, batch_first=True) + layer2 = AutoPool() + if metric == "stoi": + layer3 = nn.Sequential( + nn.Linear(d_model, d_model), + nn.PReLU(), + nn.Linear(d_model, 1), + RangeSigmoid(), + ) + elif metric == "pesq": + layer3 = nn.Sequential( + nn.Linear(d_model, d_model), + nn.PReLU(), + nn.Linear(d_model, 1), + RangeSigmoid(val_range=PESQRange), + ) + else: + layer3: nn.modules.Module = nn.Sequential(nn.Linear(d_model, d_model), nn.PReLU(), nn.Linear(d_model, 1)) + return nn.Sequential(layer1, layer2, layer3) + + +def squim_objective_model( + feat_dim: int, + win_len: int, + d_model: int, + nhead: int, + hidden_dim: int, + num_blocks: int, + rnn_type: str, + chunk_size: int, + chunk_stride: Optional[int] = None, +) -> SquimObjective: + """Build a custome :class:`torchaudio.prototype.models.SquimObjective` model. + + Args: + feat_dim (int, optional): The feature dimension after Encoder module. + win_len (int): Kernel size in the Encoder module. + d_model (int): The number of expected features in the input. + nhead (int): Number of heads in the multi-head attention model. + hidden_dim (int): Hidden dimension in the RNN layer of DPRNN. + num_blocks (int): Number of DPRNN layers. + rnn_type (str): Type of RNN in DPRNN. Valid options are ["RNN", "LSTM", "GRU"]. + chunk_size (int): Chunk size of input for DPRNN. + chunk_stride (int or None, optional): Stride of chunk input for DPRNN. 
+ """ + if chunk_stride is None: + chunk_stride = chunk_size // 2 + encoder = Encoder(feat_dim, win_len) + dprnn = DPRNN(feat_dim, hidden_dim, num_blocks, rnn_type, d_model, chunk_size, chunk_stride) + branches = nn.ModuleList( + [ + _create_branch(d_model, nhead, "stoi"), + _create_branch(d_model, nhead, "pesq"), + _create_branch(d_model, nhead, "sisdr"), + ] + ) + return SquimObjective(encoder, dprnn, branches) + + +def squim_objective_base() -> SquimObjective: + """Build :class:`torchaudio.prototype.models.SquimObjective` model with default arguments.""" + return squim_objective_model( + feat_dim=256, + win_len=64, + d_model=256, + nhead=4, + hidden_dim=256, + num_blocks=2, + rnn_type="LSTM", + chunk_size=71, + ) + +@dataclass +class SquimObjectiveBundle: + + _path: str + _sample_rate: float + + def _get_state_dict(self, dl_kwargs): + url = f"https://download.pytorch.org/torchaudio/models/{self._path}" + dl_kwargs = {} if dl_kwargs is None else dl_kwargs + state_dict = load_state_dict_from_url(url, **dl_kwargs) + return state_dict + + def get_model(self, *, dl_kwargs=None) -> SquimObjective: + """Construct the SquimObjective model, and load the pretrained weight. + + The weight file is downloaded from the internet and cached with + :func:`torch.hub.load_state_dict_from_url` + + Args: + dl_kwargs (dictionary of keyword arguments): Passed to :func:`torch.hub.load_state_dict_from_url`. + + Returns: + Variation of :py:class:`~torchaudio.models.SquimObjective`. + """ + model = squim_objective_base() + model.load_state_dict(self._get_state_dict(dl_kwargs)) + model.eval() + return model + + @property + def sample_rate(self): + """Sample rate of the audio that the model is trained on. + + :type: float + """ + return self._sample_rate + + +SQUIM_OBJECTIVE = SquimObjectiveBundle( + "squim_objective_dns2020.pth", + _sample_rate=16000, +) +SQUIM_OBJECTIVE.__doc__ = """SquimObjective pipeline trained using approach described in + :cite:`kumar2023torchaudio` on the *DNS 2020 Dataset* :cite:`reddy2020interspeech`. + + The underlying model is constructed by :py:func:`torchaudio.models.squim_objective_base`. + The weights are under `Creative Commons Attribution 4.0 International License + `__. + + Please refer to :py:class:`SquimObjectiveBundle` for usage instructions. 
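+
+    Example (illustrative sketch, not part of the original pipeline;
+    ``speech.wav`` is a placeholder for a mono speech file):
+
+        >>> import torchaudio
+        >>> model = SQUIM_OBJECTIVE.get_model()
+        >>> waveform, sr = torchaudio.load("speech.wav")  # (channels, time); model expects (batch, time)
+        >>> waveform = torchaudio.functional.resample(waveform, sr, int(SQUIM_OBJECTIVE.sample_rate))
+        >>> stoi, pesq, si_sdr = model(waveform)  # three score tensors, each of shape (batch,)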
+ """ + diff --git a/models/bandit/core/metrics/snr.py b/models/bandit/core/metrics/snr.py new file mode 100644 index 0000000000000000000000000000000000000000..c8fa70a0d31834c6c405c72027c11edc9c8dda9f --- /dev/null +++ b/models/bandit/core/metrics/snr.py @@ -0,0 +1,150 @@ +from typing import Any, Callable + +import numpy as np +import torch +import torchmetrics as tm +from torch._C import _LinAlgError +from torchmetrics import functional as tmF + + +class SafeSignalDistortionRatio(tm.SignalDistortionRatio): + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + + def update(self, *args, **kwargs) -> Any: + try: + super().update(*args, **kwargs) + except: + pass + + def compute(self) -> Any: + if self.total == 0: + return torch.tensor(torch.nan) + return super().compute() + + +class BaseChunkMedianSignalRatio(tm.Metric): + def __init__( + self, + func: Callable, + window_size: int, + hop_size: int = None, + zero_mean: bool = False, + ) -> None: + super().__init__() + + # self.zero_mean = zero_mean + self.func = func + self.window_size = window_size + if hop_size is None: + hop_size = window_size + self.hop_size = hop_size + + self.add_state( + "sum_snr", + default=torch.tensor(0.0), + dist_reduce_fx="sum" + ) + self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") + + def update(self, preds: torch.Tensor, target: torch.Tensor) -> None: + + n_samples = target.shape[-1] + + n_chunks = int( + np.ceil((n_samples - self.window_size) / self.hop_size) + 1 + ) + + snr_chunk = [] + + for i in range(n_chunks): + start = i * self.hop_size + + if n_samples - start < self.window_size: + continue + + end = start + self.window_size + + try: + chunk_snr = self.func( + preds[..., start:end], + target[..., start:end] + ) + + # print(preds.shape, chunk_snr.shape) + + if torch.all(torch.isfinite(chunk_snr)): + snr_chunk.append(chunk_snr) + except _LinAlgError: + pass + + snr_chunk = torch.stack(snr_chunk, dim=-1) + snr_batch, _ = torch.nanmedian(snr_chunk, dim=-1) + + self.sum_snr += snr_batch.sum() + self.total += snr_batch.numel() + + def compute(self) -> Any: + return self.sum_snr / self.total + + +class ChunkMedianSignalNoiseRatio(BaseChunkMedianSignalRatio): + def __init__( + self, + window_size: int, + hop_size: int = None, + zero_mean: bool = False + ) -> None: + super().__init__( + func=tmF.signal_noise_ratio, + window_size=window_size, + hop_size=hop_size, + zero_mean=zero_mean, + ) + + +class ChunkMedianScaleInvariantSignalNoiseRatio(BaseChunkMedianSignalRatio): + def __init__( + self, + window_size: int, + hop_size: int = None, + zero_mean: bool = False + ) -> None: + super().__init__( + func=tmF.scale_invariant_signal_noise_ratio, + window_size=window_size, + hop_size=hop_size, + zero_mean=zero_mean, + ) + + +class ChunkMedianSignalDistortionRatio(BaseChunkMedianSignalRatio): + def __init__( + self, + window_size: int, + hop_size: int = None, + zero_mean: bool = False + ) -> None: + super().__init__( + func=tmF.signal_distortion_ratio, + window_size=window_size, + hop_size=hop_size, + zero_mean=zero_mean, + ) + + +class ChunkMedianScaleInvariantSignalDistortionRatio( + BaseChunkMedianSignalRatio + ): + def __init__( + self, + window_size: int, + hop_size: int = None, + zero_mean: bool = False + ) -> None: + super().__init__( + func=tmF.scale_invariant_signal_distortion_ratio, + window_size=window_size, + hop_size=hop_size, + zero_mean=zero_mean, + ) diff --git a/models/bandit/core/model/__init__.py b/models/bandit/core/model/__init__.py new file mode 100644 index 
0000000000000000000000000000000000000000..3f05f064e3ec7a5b1b2205347c88e482acd5f00a --- /dev/null +++ b/models/bandit/core/model/__init__.py @@ -0,0 +1,3 @@ +from .bsrnn.wrapper import ( + MultiMaskMultiSourceBandSplitRNNSimple, +) diff --git a/models/bandit/core/model/_spectral.py b/models/bandit/core/model/_spectral.py new file mode 100644 index 0000000000000000000000000000000000000000..5935984681b98ac02306a34359029a823bbd53b5 --- /dev/null +++ b/models/bandit/core/model/_spectral.py @@ -0,0 +1,58 @@ +from typing import Dict, Optional + +import torch +import torchaudio as ta +from torch import nn + + +class _SpectralComponent(nn.Module): + def __init__( + self, + n_fft: int = 2048, + win_length: Optional[int] = 2048, + hop_length: int = 512, + window_fn: str = "hann_window", + wkwargs: Optional[Dict] = None, + power: Optional[int] = None, + center: bool = True, + normalized: bool = True, + pad_mode: str = "constant", + onesided: bool = True, + **kwargs, + ) -> None: + super().__init__() + + assert power is None + + window_fn = torch.__dict__[window_fn] + + self.stft = ( + ta.transforms.Spectrogram( + n_fft=n_fft, + win_length=win_length, + hop_length=hop_length, + pad_mode=pad_mode, + pad=0, + window_fn=window_fn, + wkwargs=wkwargs, + power=power, + normalized=normalized, + center=center, + onesided=onesided, + ) + ) + + self.istft = ( + ta.transforms.InverseSpectrogram( + n_fft=n_fft, + win_length=win_length, + hop_length=hop_length, + pad_mode=pad_mode, + pad=0, + window_fn=window_fn, + wkwargs=wkwargs, + normalized=normalized, + center=center, + onesided=onesided, + ) + ) diff --git a/models/bandit/core/model/bsrnn/__init__.py b/models/bandit/core/model/bsrnn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f40b6bdab53844943de5e45e45e23a230a1967a9 --- /dev/null +++ b/models/bandit/core/model/bsrnn/__init__.py @@ -0,0 +1,23 @@ +from abc import ABC +from typing import Iterable, Mapping, Union + +from torch import nn + +from models.bandit.core.model.bsrnn.bandsplit import BandSplitModule +from models.bandit.core.model.bsrnn.tfmodel import ( + SeqBandModellingModule, + TransformerTimeFreqModule, +) + + +class BandsplitCoreBase(nn.Module, ABC): + band_split: nn.Module + tf_model: nn.Module + mask_estim: Union[nn.Module, Mapping[str, nn.Module], Iterable[nn.Module]] + + def __init__(self) -> None: + super().__init__() + + @staticmethod + def mask(x, m): + return x * m diff --git a/models/bandit/core/model/bsrnn/bandsplit.py b/models/bandit/core/model/bsrnn/bandsplit.py new file mode 100644 index 0000000000000000000000000000000000000000..d16a2328d0219420373c153e57b97dfd4082c63d --- /dev/null +++ b/models/bandit/core/model/bsrnn/bandsplit.py @@ -0,0 +1,139 @@ +from typing import List, Tuple + +import torch +from torch import nn + +from models.bandit.core.model.bsrnn.utils import ( + band_widths_from_specs, + check_no_gap, + check_no_overlap, + check_nonzero_bandwidth, +) + + +class NormFC(nn.Module): + def __init__( + self, + emb_dim: int, + bandwidth: int, + in_channel: int, + normalize_channel_independently: bool = False, + treat_channel_as_feature: bool = True, + ) -> None: + super().__init__() + + self.treat_channel_as_feature = treat_channel_as_feature + + if normalize_channel_independently: + raise NotImplementedError + + reim = 2 + + self.norm = nn.LayerNorm(in_channel * bandwidth * reim) + + fc_in = bandwidth * reim + + if treat_channel_as_feature: + fc_in *= in_channel + else: + assert emb_dim % in_channel == 0 + emb_dim = emb_dim // in_channel + + 
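+        # Note (added for clarity): the linear layer below projects each
+        # sub-band's flattened (real, imag) features (bandwidth * 2 values,
+        # folded together with the channels when treat_channel_as_feature is
+        # True) to the embedding dimension; otherwise each channel is projected
+        # to emb_dim // in_channel so that the per-channel embeddings
+        # concatenate back to the full emb_dim in forward().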
self.fc = nn.Linear(fc_in, emb_dim) + + def forward(self, xb): + # xb = (batch, n_time, in_chan, reim * band_width) + + batch, n_time, in_chan, ribw = xb.shape + xb = self.norm(xb.reshape(batch, n_time, in_chan * ribw)) + # (batch, n_time, in_chan * reim * band_width) + + if not self.treat_channel_as_feature: + xb = xb.reshape(batch, n_time, in_chan, ribw) + # (batch, n_time, in_chan, reim * band_width) + + zb = self.fc(xb) + # (batch, n_time, emb_dim) + # OR + # (batch, n_time, in_chan, emb_dim_per_chan) + + if not self.treat_channel_as_feature: + batch, n_time, in_chan, emb_dim_per_chan = zb.shape + # (batch, n_time, in_chan, emb_dim_per_chan) + zb = zb.reshape((batch, n_time, in_chan * emb_dim_per_chan)) + + return zb # (batch, n_time, emb_dim) + + +class BandSplitModule(nn.Module): + def __init__( + self, + band_specs: List[Tuple[float, float]], + emb_dim: int, + in_channel: int, + require_no_overlap: bool = False, + require_no_gap: bool = True, + normalize_channel_independently: bool = False, + treat_channel_as_feature: bool = True, + ) -> None: + super().__init__() + + check_nonzero_bandwidth(band_specs) + + if require_no_gap: + check_no_gap(band_specs) + + if require_no_overlap: + check_no_overlap(band_specs) + + self.band_specs = band_specs + # list of [fstart, fend) in index. + # Note that fend is exclusive. + self.band_widths = band_widths_from_specs(band_specs) + self.n_bands = len(band_specs) + self.emb_dim = emb_dim + + self.norm_fc_modules = nn.ModuleList( + [ # type: ignore + ( + NormFC( + emb_dim=emb_dim, + bandwidth=bw, + in_channel=in_channel, + normalize_channel_independently=normalize_channel_independently, + treat_channel_as_feature=treat_channel_as_feature, + ) + ) + for bw in self.band_widths + ] + ) + + def forward(self, x: torch.Tensor): + # x = complex spectrogram (batch, in_chan, n_freq, n_time) + + batch, in_chan, _, n_time = x.shape + + z = torch.zeros( + size=(batch, self.n_bands, n_time, self.emb_dim), + device=x.device + ) + + xr = torch.view_as_real(x) # batch, in_chan, n_freq, n_time, 2 + xr = torch.permute( + xr, + (0, 3, 1, 4, 2) + ) # batch, n_time, in_chan, 2, n_freq + batch, n_time, in_chan, reim, band_width = xr.shape + for i, nfm in enumerate(self.norm_fc_modules): + # print(f"bandsplit/band{i:02d}") + fstart, fend = self.band_specs[i] + xb = xr[..., fstart:fend] + # (batch, n_time, in_chan, reim, band_width) + xb = torch.reshape(xb, (batch, n_time, in_chan, -1)) + # (batch, n_time, in_chan, reim * band_width) + # z.append(nfm(xb)) # (batch, n_time, emb_dim) + z[:, i, :, :] = nfm(xb.contiguous()) + + # z = torch.stack(z, dim=1) + + return z diff --git a/models/bandit/core/model/bsrnn/core.py b/models/bandit/core/model/bsrnn/core.py new file mode 100644 index 0000000000000000000000000000000000000000..e90afd3fa39fb034128af0c3cc1c125839e8fa9f --- /dev/null +++ b/models/bandit/core/model/bsrnn/core.py @@ -0,0 +1,661 @@ +from typing import Dict, List, Optional, Tuple + +import torch +from torch import nn +from torch.nn import functional as F + +from models.bandit.core.model.bsrnn import BandsplitCoreBase +from models.bandit.core.model.bsrnn.bandsplit import BandSplitModule +from models.bandit.core.model.bsrnn.maskestim import ( + MaskEstimationModule, + OverlappingMaskEstimationModule +) +from models.bandit.core.model.bsrnn.tfmodel import ( + ConvolutionalTimeFreqModule, + SeqBandModellingModule, + TransformerTimeFreqModule +) + + +class MultiMaskBandSplitCoreBase(BandsplitCoreBase): + def __init__(self) -> None: + super().__init__() + + def 
forward(self, x, cond=None, compute_residual: bool = True): + # x = complex spectrogram (batch, in_chan, n_freq, n_time) + # print(x.shape) + batch, in_chan, n_freq, n_time = x.shape + x = torch.reshape(x, (-1, 1, n_freq, n_time)) + + z = self.band_split(x) # (batch, emb_dim, n_band, n_time) + + # if torch.any(torch.isnan(z)): + # raise ValueError("z nan") + + # print(z) + q = self.tf_model(z) # (batch, emb_dim, n_band, n_time) + # print(q) + + + # if torch.any(torch.isnan(q)): + # raise ValueError("q nan") + + out = {} + + for stem, mem in self.mask_estim.items(): + m = mem(q, cond=cond) + + # if torch.any(torch.isnan(m)): + # raise ValueError("m nan", stem) + + s = self.mask(x, m) + s = torch.reshape(s, (batch, in_chan, n_freq, n_time)) + out[stem] = s + + return {"spectrogram": out} + + + + def instantiate_mask_estim(self, + in_channel: int, + stems: List[str], + band_specs: List[Tuple[float, float]], + emb_dim: int, + mlp_dim: int, + cond_dim: int, + hidden_activation: str, + + hidden_activation_kwargs: Optional[Dict] = None, + complex_mask: bool = True, + overlapping_band: bool = False, + freq_weights: Optional[List[torch.Tensor]] = None, + n_freq: Optional[int] = None, + use_freq_weights: bool = True, + mult_add_mask: bool = False + ): + if hidden_activation_kwargs is None: + hidden_activation_kwargs = {} + + if "mne:+" in stems: + stems = [s for s in stems if s != "mne:+"] + + if overlapping_band: + assert freq_weights is not None + assert n_freq is not None + + if mult_add_mask: + + self.mask_estim = nn.ModuleDict( + { + stem: MultAddMaskEstimationModule( + band_specs=band_specs, + freq_weights=freq_weights, + n_freq=n_freq, + emb_dim=emb_dim, + mlp_dim=mlp_dim, + in_channel=in_channel, + hidden_activation=hidden_activation, + hidden_activation_kwargs=hidden_activation_kwargs, + complex_mask=complex_mask, + use_freq_weights=use_freq_weights, + ) + for stem in stems + } + ) + else: + self.mask_estim = nn.ModuleDict( + { + stem: OverlappingMaskEstimationModule( + band_specs=band_specs, + freq_weights=freq_weights, + n_freq=n_freq, + emb_dim=emb_dim, + mlp_dim=mlp_dim, + in_channel=in_channel, + hidden_activation=hidden_activation, + hidden_activation_kwargs=hidden_activation_kwargs, + complex_mask=complex_mask, + use_freq_weights=use_freq_weights, + ) + for stem in stems + } + ) + else: + self.mask_estim = nn.ModuleDict( + { + stem: MaskEstimationModule( + band_specs=band_specs, + emb_dim=emb_dim, + mlp_dim=mlp_dim, + cond_dim=cond_dim, + in_channel=in_channel, + hidden_activation=hidden_activation, + hidden_activation_kwargs=hidden_activation_kwargs, + complex_mask=complex_mask, + ) + for stem in stems + } + ) + + def instantiate_bandsplit(self, + in_channel: int, + band_specs: List[Tuple[float, float]], + require_no_overlap: bool = False, + require_no_gap: bool = True, + normalize_channel_independently: bool = False, + treat_channel_as_feature: bool = True, + emb_dim: int = 128 + ): + self.band_split = BandSplitModule( + in_channel=in_channel, + band_specs=band_specs, + require_no_overlap=require_no_overlap, + require_no_gap=require_no_gap, + normalize_channel_independently=normalize_channel_independently, + treat_channel_as_feature=treat_channel_as_feature, + emb_dim=emb_dim, + ) + +class SingleMaskBandsplitCoreBase(BandsplitCoreBase): + def __init__(self, **kwargs) -> None: + super().__init__() + + def forward(self, x): + # x = complex spectrogram (batch, in_chan, n_freq, n_time) + z = self.band_split(x) # (batch, emb_dim, n_band, n_time) + q = self.tf_model(z) # (batch, emb_dim, 
n_band, n_time) + m = self.mask_estim(q) # (batch, in_chan, n_freq, n_time) + + s = self.mask(x, m) + + return s + + +class SingleMaskBandsplitCoreRNN( + SingleMaskBandsplitCoreBase, +): + def __init__( + self, + in_channel: int, + band_specs: List[Tuple[float, float]], + require_no_overlap: bool = False, + require_no_gap: bool = True, + normalize_channel_independently: bool = False, + treat_channel_as_feature: bool = True, + n_sqm_modules: int = 12, + emb_dim: int = 128, + rnn_dim: int = 256, + bidirectional: bool = True, + rnn_type: str = "LSTM", + mlp_dim: int = 512, + hidden_activation: str = "Tanh", + hidden_activation_kwargs: Optional[Dict] = None, + complex_mask: bool = True, + ) -> None: + super().__init__() + self.band_split = (BandSplitModule( + in_channel=in_channel, + band_specs=band_specs, + require_no_overlap=require_no_overlap, + require_no_gap=require_no_gap, + normalize_channel_independently=normalize_channel_independently, + treat_channel_as_feature=treat_channel_as_feature, + emb_dim=emb_dim, + )) + self.tf_model = (SeqBandModellingModule( + n_modules=n_sqm_modules, + emb_dim=emb_dim, + rnn_dim=rnn_dim, + bidirectional=bidirectional, + rnn_type=rnn_type, + )) + self.mask_estim = (MaskEstimationModule( + in_channel=in_channel, + band_specs=band_specs, + emb_dim=emb_dim, + mlp_dim=mlp_dim, + hidden_activation=hidden_activation, + hidden_activation_kwargs=hidden_activation_kwargs, + complex_mask=complex_mask, + )) + + +class SingleMaskBandsplitCoreTransformer( + SingleMaskBandsplitCoreBase, +): + def __init__( + self, + in_channel: int, + band_specs: List[Tuple[float, float]], + require_no_overlap: bool = False, + require_no_gap: bool = True, + normalize_channel_independently: bool = False, + treat_channel_as_feature: bool = True, + n_sqm_modules: int = 12, + emb_dim: int = 128, + rnn_dim: int = 256, + bidirectional: bool = True, + tf_dropout: float = 0.0, + mlp_dim: int = 512, + hidden_activation: str = "Tanh", + hidden_activation_kwargs: Optional[Dict] = None, + complex_mask: bool = True, + ) -> None: + super().__init__() + self.band_split = BandSplitModule( + in_channel=in_channel, + band_specs=band_specs, + require_no_overlap=require_no_overlap, + require_no_gap=require_no_gap, + normalize_channel_independently=normalize_channel_independently, + treat_channel_as_feature=treat_channel_as_feature, + emb_dim=emb_dim, + ) + self.tf_model = TransformerTimeFreqModule( + n_modules=n_sqm_modules, + emb_dim=emb_dim, + rnn_dim=rnn_dim, + bidirectional=bidirectional, + dropout=tf_dropout, + ) + self.mask_estim = MaskEstimationModule( + in_channel=in_channel, + band_specs=band_specs, + emb_dim=emb_dim, + mlp_dim=mlp_dim, + hidden_activation=hidden_activation, + hidden_activation_kwargs=hidden_activation_kwargs, + complex_mask=complex_mask, + ) + + +class MultiSourceMultiMaskBandSplitCoreRNN(MultiMaskBandSplitCoreBase): + def __init__( + self, + in_channel: int, + stems: List[str], + band_specs: List[Tuple[float, float]], + require_no_overlap: bool = False, + require_no_gap: bool = True, + normalize_channel_independently: bool = False, + treat_channel_as_feature: bool = True, + n_sqm_modules: int = 12, + emb_dim: int = 128, + rnn_dim: int = 256, + bidirectional: bool = True, + rnn_type: str = "LSTM", + mlp_dim: int = 512, + cond_dim: int = 0, + hidden_activation: str = "Tanh", + hidden_activation_kwargs: Optional[Dict] = None, + complex_mask: bool = True, + overlapping_band: bool = False, + freq_weights: Optional[List[torch.Tensor]] = None, + n_freq: Optional[int] = None, + 
use_freq_weights: bool = True, + mult_add_mask: bool = False + ) -> None: + + super().__init__() + self.instantiate_bandsplit( + in_channel=in_channel, + band_specs=band_specs, + require_no_overlap=require_no_overlap, + require_no_gap=require_no_gap, + normalize_channel_independently=normalize_channel_independently, + treat_channel_as_feature=treat_channel_as_feature, + emb_dim=emb_dim + ) + + + self.tf_model = ( + SeqBandModellingModule( + n_modules=n_sqm_modules, + emb_dim=emb_dim, + rnn_dim=rnn_dim, + bidirectional=bidirectional, + rnn_type=rnn_type, + ) + ) + + self.mult_add_mask = mult_add_mask + + self.instantiate_mask_estim( + in_channel=in_channel, + stems=stems, + band_specs=band_specs, + emb_dim=emb_dim, + mlp_dim=mlp_dim, + cond_dim=cond_dim, + hidden_activation=hidden_activation, + hidden_activation_kwargs=hidden_activation_kwargs, + complex_mask=complex_mask, + overlapping_band=overlapping_band, + freq_weights=freq_weights, + n_freq=n_freq, + use_freq_weights=use_freq_weights, + mult_add_mask=mult_add_mask + ) + + @staticmethod + def _mult_add_mask(x, m): + + assert m.ndim == 5 + + mm = m[..., 0] + am = m[..., 1] + + # print(mm.shape, am.shape, x.shape, m.shape) + + return x * mm + am + + def mask(self, x, m): + if self.mult_add_mask: + + return self._mult_add_mask(x, m) + else: + return super().mask(x, m) + + +class MultiSourceMultiMaskBandSplitCoreTransformer( + MultiMaskBandSplitCoreBase, +): + def __init__( + self, + in_channel: int, + stems: List[str], + band_specs: List[Tuple[float, float]], + require_no_overlap: bool = False, + require_no_gap: bool = True, + normalize_channel_independently: bool = False, + treat_channel_as_feature: bool = True, + n_sqm_modules: int = 12, + emb_dim: int = 128, + rnn_dim: int = 256, + bidirectional: bool = True, + tf_dropout: float = 0.0, + mlp_dim: int = 512, + hidden_activation: str = "Tanh", + hidden_activation_kwargs: Optional[Dict] = None, + complex_mask: bool = True, + overlapping_band: bool = False, + freq_weights: Optional[List[torch.Tensor]] = None, + n_freq: Optional[int] = None, + use_freq_weights:bool=True, + rnn_type: str = "LSTM", + cond_dim: int = 0, + mult_add_mask: bool = False + ) -> None: + super().__init__() + self.instantiate_bandsplit( + in_channel=in_channel, + band_specs=band_specs, + require_no_overlap=require_no_overlap, + require_no_gap=require_no_gap, + normalize_channel_independently=normalize_channel_independently, + treat_channel_as_feature=treat_channel_as_feature, + emb_dim=emb_dim + ) + self.tf_model = TransformerTimeFreqModule( + n_modules=n_sqm_modules, + emb_dim=emb_dim, + rnn_dim=rnn_dim, + bidirectional=bidirectional, + dropout=tf_dropout, + ) + + self.instantiate_mask_estim( + in_channel=in_channel, + stems=stems, + band_specs=band_specs, + emb_dim=emb_dim, + mlp_dim=mlp_dim, + cond_dim=cond_dim, + hidden_activation=hidden_activation, + hidden_activation_kwargs=hidden_activation_kwargs, + complex_mask=complex_mask, + overlapping_band=overlapping_band, + freq_weights=freq_weights, + n_freq=n_freq, + use_freq_weights=use_freq_weights, + mult_add_mask=mult_add_mask + ) + + + +class MultiSourceMultiMaskBandSplitCoreConv( + MultiMaskBandSplitCoreBase, +): + def __init__( + self, + in_channel: int, + stems: List[str], + band_specs: List[Tuple[float, float]], + require_no_overlap: bool = False, + require_no_gap: bool = True, + normalize_channel_independently: bool = False, + treat_channel_as_feature: bool = True, + n_sqm_modules: int = 12, + emb_dim: int = 128, + rnn_dim: int = 256, + bidirectional: bool = 
True, + tf_dropout: float = 0.0, + mlp_dim: int = 512, + hidden_activation: str = "Tanh", + hidden_activation_kwargs: Optional[Dict] = None, + complex_mask: bool = True, + overlapping_band: bool = False, + freq_weights: Optional[List[torch.Tensor]] = None, + n_freq: Optional[int] = None, + use_freq_weights:bool=True, + rnn_type: str = "LSTM", + cond_dim: int = 0, + mult_add_mask: bool = False + ) -> None: + super().__init__() + self.instantiate_bandsplit( + in_channel=in_channel, + band_specs=band_specs, + require_no_overlap=require_no_overlap, + require_no_gap=require_no_gap, + normalize_channel_independently=normalize_channel_independently, + treat_channel_as_feature=treat_channel_as_feature, + emb_dim=emb_dim + ) + self.tf_model = ConvolutionalTimeFreqModule( + n_modules=n_sqm_modules, + emb_dim=emb_dim, + rnn_dim=rnn_dim, + bidirectional=bidirectional, + dropout=tf_dropout, + ) + + self.instantiate_mask_estim( + in_channel=in_channel, + stems=stems, + band_specs=band_specs, + emb_dim=emb_dim, + mlp_dim=mlp_dim, + cond_dim=cond_dim, + hidden_activation=hidden_activation, + hidden_activation_kwargs=hidden_activation_kwargs, + complex_mask=complex_mask, + overlapping_band=overlapping_band, + freq_weights=freq_weights, + n_freq=n_freq, + use_freq_weights=use_freq_weights, + mult_add_mask=mult_add_mask + ) + + +class PatchingMaskBandsplitCoreBase(MultiMaskBandSplitCoreBase): + def __init__(self) -> None: + super().__init__() + + def mask(self, x, m): + # x.shape = (batch, n_channel, n_freq, n_time) + # m.shape = (kernel_freq, kernel_time, batch, n_channel, n_freq, n_time) + + _, n_channel, kernel_freq, kernel_time, n_freq, n_time = m.shape + padding = ((kernel_freq - 1) // 2, (kernel_time - 1) // 2) + + xf = F.unfold( + x, + kernel_size=(kernel_freq, kernel_time), + padding=padding, + stride=(1, 1), + ) + + xf = xf.view( + -1, + n_channel, + kernel_freq, + kernel_time, + n_freq, + n_time, + ) + + sf = xf * m + + sf = sf.view( + -1, + n_channel * kernel_freq * kernel_time, + n_freq * n_time, + ) + + s = F.fold( + sf, + output_size=(n_freq, n_time), + kernel_size=(kernel_freq, kernel_time), + padding=padding, + stride=(1, 1), + ).view( + -1, + n_channel, + n_freq, + n_time, + ) + + return s + + def old_mask(self, x, m): + # x.shape = (batch, n_channel, n_freq, n_time) + # m.shape = (kernel_freq, kernel_time, batch, n_channel, n_freq, n_time) + + s = torch.zeros_like(x) + + _, n_channel, n_freq, n_time = x.shape + kernel_freq, kernel_time, _, _, _, _ = m.shape + + # print(x.shape, m.shape) + + kernel_freq_half = (kernel_freq - 1) // 2 + kernel_time_half = (kernel_time - 1) // 2 + + for ifreq in range(kernel_freq): + for itime in range(kernel_time): + df, dt = kernel_freq_half - ifreq, kernel_time_half - itime + x = x.roll(shifts=(df, dt), dims=(2, 3)) + + # if `df` > 0: + # x[:, :, :df, :] = 0 + # elif `df` < 0: + # x[:, :, df:, :] = 0 + + # if `dt` > 0: + # x[:, :, :, :dt] = 0 + # elif `dt` < 0: + # x[:, :, :, dt:] = 0 + + fslice = slice(max(0, df), min(n_freq, n_freq + df)) + tslice = slice(max(0, dt), min(n_time, n_time + dt)) + + s[:, :, fslice, tslice] += x[:, :, fslice, tslice] * m[ifreq, + itime, :, + :, fslice, + tslice] + + return s + + +class MultiSourceMultiPatchingMaskBandSplitCoreRNN( + PatchingMaskBandsplitCoreBase +): + def __init__( + self, + in_channel: int, + stems: List[str], + band_specs: List[Tuple[float, float]], + mask_kernel_freq: int, + mask_kernel_time: int, + conv_kernel_freq: int, + conv_kernel_time: int, + kernel_norm_mlp_version: int, + require_no_overlap: bool = 
False, + require_no_gap: bool = True, + normalize_channel_independently: bool = False, + treat_channel_as_feature: bool = True, + n_sqm_modules: int = 12, + emb_dim: int = 128, + rnn_dim: int = 256, + bidirectional: bool = True, + rnn_type: str = "LSTM", + mlp_dim: int = 512, + hidden_activation: str = "Tanh", + hidden_activation_kwargs: Optional[Dict] = None, + complex_mask: bool = True, + overlapping_band: bool = False, + freq_weights: Optional[List[torch.Tensor]] = None, + n_freq: Optional[int] = None, + ) -> None: + + super().__init__() + self.band_split = BandSplitModule( + in_channel=in_channel, + band_specs=band_specs, + require_no_overlap=require_no_overlap, + require_no_gap=require_no_gap, + normalize_channel_independently=normalize_channel_independently, + treat_channel_as_feature=treat_channel_as_feature, + emb_dim=emb_dim, + ) + + self.tf_model = ( + SeqBandModellingModule( + n_modules=n_sqm_modules, + emb_dim=emb_dim, + rnn_dim=rnn_dim, + bidirectional=bidirectional, + rnn_type=rnn_type, + ) + ) + + if hidden_activation_kwargs is None: + hidden_activation_kwargs = {} + + if overlapping_band: + assert freq_weights is not None + assert n_freq is not None + self.mask_estim = nn.ModuleDict( + { + stem: PatchingMaskEstimationModule( + band_specs=band_specs, + freq_weights=freq_weights, + n_freq=n_freq, + emb_dim=emb_dim, + mlp_dim=mlp_dim, + in_channel=in_channel, + hidden_activation=hidden_activation, + hidden_activation_kwargs=hidden_activation_kwargs, + complex_mask=complex_mask, + mask_kernel_freq=mask_kernel_freq, + mask_kernel_time=mask_kernel_time, + conv_kernel_freq=conv_kernel_freq, + conv_kernel_time=conv_kernel_time, + kernel_norm_mlp_version=kernel_norm_mlp_version + ) + for stem in stems + } + ) + else: + raise NotImplementedError diff --git a/models/bandit/core/model/bsrnn/maskestim.py b/models/bandit/core/model/bsrnn/maskestim.py new file mode 100644 index 0000000000000000000000000000000000000000..699c3ce631ace52edcbed78396d3058ca62cbef3 --- /dev/null +++ b/models/bandit/core/model/bsrnn/maskestim.py @@ -0,0 +1,347 @@ +import warnings +from typing import Dict, List, Optional, Tuple, Type + +import torch +from torch import nn +from torch.nn.modules import activation + +from models.bandit.core.model.bsrnn.utils import ( + band_widths_from_specs, + check_no_gap, + check_no_overlap, + check_nonzero_bandwidth, +) + + +class BaseNormMLP(nn.Module): + def __init__( + self, + emb_dim: int, + mlp_dim: int, + bandwidth: int, + in_channel: Optional[int], + hidden_activation: str = "Tanh", + hidden_activation_kwargs=None, + complex_mask: bool = True, ): + + super().__init__() + if hidden_activation_kwargs is None: + hidden_activation_kwargs = {} + self.hidden_activation_kwargs = hidden_activation_kwargs + self.norm = nn.LayerNorm(emb_dim) + self.hidden = torch.jit.script(nn.Sequential( + nn.Linear(in_features=emb_dim, out_features=mlp_dim), + activation.__dict__[hidden_activation]( + **self.hidden_activation_kwargs + ), + )) + + self.bandwidth = bandwidth + self.in_channel = in_channel + + self.complex_mask = complex_mask + self.reim = 2 if complex_mask else 1 + self.glu_mult = 2 + + +class NormMLP(BaseNormMLP): + def __init__( + self, + emb_dim: int, + mlp_dim: int, + bandwidth: int, + in_channel: Optional[int], + hidden_activation: str = "Tanh", + hidden_activation_kwargs=None, + complex_mask: bool = True, + ) -> None: + super().__init__( + emb_dim=emb_dim, + mlp_dim=mlp_dim, + bandwidth=bandwidth, + in_channel=in_channel, + hidden_activation=hidden_activation, + 
hidden_activation_kwargs=hidden_activation_kwargs, + complex_mask=complex_mask, + ) + + self.output = torch.jit.script( + nn.Sequential( + nn.Linear( + in_features=mlp_dim, + out_features=bandwidth * in_channel * self.reim * 2, + ), + nn.GLU(dim=-1), + ) + ) + + def reshape_output(self, mb): + # print(mb.shape) + batch, n_time, _ = mb.shape + if self.complex_mask: + mb = mb.reshape( + batch, + n_time, + self.in_channel, + self.bandwidth, + self.reim + ).contiguous() + # print(mb.shape) + mb = torch.view_as_complex( + mb + ) # (batch, n_time, in_channel, bandwidth) + else: + mb = mb.reshape(batch, n_time, self.in_channel, self.bandwidth) + + mb = torch.permute( + mb, + (0, 2, 3, 1) + ) # (batch, in_channel, bandwidth, n_time) + + return mb + + def forward(self, qb): + # qb = (batch, n_time, emb_dim) + + # if torch.any(torch.isnan(qb)): + # raise ValueError("qb0") + + + qb = self.norm(qb) # (batch, n_time, emb_dim) + + # if torch.any(torch.isnan(qb)): + # raise ValueError("qb1") + + qb = self.hidden(qb) # (batch, n_time, mlp_dim) + # if torch.any(torch.isnan(qb)): + # raise ValueError("qb2") + mb = self.output(qb) # (batch, n_time, bandwidth * in_channel * reim) + # if torch.any(torch.isnan(qb)): + # raise ValueError("mb") + mb = self.reshape_output(mb) # (batch, in_channel, bandwidth, n_time) + + return mb + + +class MultAddNormMLP(NormMLP): + def __init__(self, emb_dim: int, mlp_dim: int, bandwidth: int, in_channel: "int | None", hidden_activation: str = "Tanh", hidden_activation_kwargs=None, complex_mask: bool = True) -> None: + super().__init__(emb_dim, mlp_dim, bandwidth, in_channel, hidden_activation, hidden_activation_kwargs, complex_mask) + + self.output2 = torch.jit.script( + nn.Sequential( + nn.Linear( + in_features=mlp_dim, + out_features=bandwidth * in_channel * self.reim * 2, + ), + nn.GLU(dim=-1), + ) + ) + + def forward(self, qb): + + qb = self.norm(qb) # (batch, n_time, emb_dim) + qb = self.hidden(qb) # (batch, n_time, mlp_dim) + mmb = self.output(qb) # (batch, n_time, bandwidth * in_channel * reim) + mmb = self.reshape_output(mmb) # (batch, in_channel, bandwidth, n_time) + amb = self.output2(qb) # (batch, n_time, bandwidth * in_channel * reim) + amb = self.reshape_output(amb) # (batch, in_channel, bandwidth, n_time) + + return mmb, amb + + +class MaskEstimationModuleSuperBase(nn.Module): + pass + + +class MaskEstimationModuleBase(MaskEstimationModuleSuperBase): + def __init__( + self, + band_specs: List[Tuple[float, float]], + emb_dim: int, + mlp_dim: int, + in_channel: Optional[int], + hidden_activation: str = "Tanh", + hidden_activation_kwargs: Dict = None, + complex_mask: bool = True, + norm_mlp_cls: Type[nn.Module] = NormMLP, + norm_mlp_kwargs: Dict = None, + ) -> None: + super().__init__() + + self.band_widths = band_widths_from_specs(band_specs) + self.n_bands = len(band_specs) + + if hidden_activation_kwargs is None: + hidden_activation_kwargs = {} + + if norm_mlp_kwargs is None: + norm_mlp_kwargs = {} + + self.norm_mlp = nn.ModuleList( + [ + ( + norm_mlp_cls( + bandwidth=self.band_widths[b], + emb_dim=emb_dim, + mlp_dim=mlp_dim, + in_channel=in_channel, + hidden_activation=hidden_activation, + hidden_activation_kwargs=hidden_activation_kwargs, + complex_mask=complex_mask, + **norm_mlp_kwargs, + ) + ) + for b in range(self.n_bands) + ] + ) + + def compute_masks(self, q): + batch, n_bands, n_time, emb_dim = q.shape + + masks = [] + + for b, nmlp in enumerate(self.norm_mlp): + # print(f"maskestim/{b:02d}") + qb = q[:, b, :, :] + mb = nmlp(qb) + masks.append(mb) + + 
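+        # Note (added for clarity): each entry of `masks` has shape
+        # (batch, in_channel, bandwidth_b, n_time), one per band. The caller
+        # either concatenates them along the frequency axis
+        # (MaskEstimationModule) or overlap-adds them with per-band frequency
+        # weights (OverlappingMaskEstimationModule).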
return masks + + + +class OverlappingMaskEstimationModule(MaskEstimationModuleBase): + def __init__( + self, + in_channel: int, + band_specs: List[Tuple[float, float]], + freq_weights: List[torch.Tensor], + n_freq: int, + emb_dim: int, + mlp_dim: int, + cond_dim: int = 0, + hidden_activation: str = "Tanh", + hidden_activation_kwargs: Dict = None, + complex_mask: bool = True, + norm_mlp_cls: Type[nn.Module] = NormMLP, + norm_mlp_kwargs: Dict = None, + use_freq_weights: bool = True, + ) -> None: + check_nonzero_bandwidth(band_specs) + check_no_gap(band_specs) + + # if cond_dim > 0: + # raise NotImplementedError + + super().__init__( + band_specs=band_specs, + emb_dim=emb_dim + cond_dim, + mlp_dim=mlp_dim, + in_channel=in_channel, + hidden_activation=hidden_activation, + hidden_activation_kwargs=hidden_activation_kwargs, + complex_mask=complex_mask, + norm_mlp_cls=norm_mlp_cls, + norm_mlp_kwargs=norm_mlp_kwargs, + ) + + self.n_freq = n_freq + self.band_specs = band_specs + self.in_channel = in_channel + + if freq_weights is not None: + for i, fw in enumerate(freq_weights): + self.register_buffer(f"freq_weights/{i}", fw) + + self.use_freq_weights = use_freq_weights + else: + self.use_freq_weights = False + + self.cond_dim = cond_dim + + def forward(self, q, cond=None): + # q = (batch, n_bands, n_time, emb_dim) + + batch, n_bands, n_time, emb_dim = q.shape + + if cond is not None: + print(cond) + if cond.ndim == 2: + cond = cond[:, None, None, :].expand(-1, n_bands, n_time, -1) + elif cond.ndim == 3: + assert cond.shape[1] == n_time + else: + raise ValueError(f"Invalid cond shape: {cond.shape}") + + q = torch.cat([q, cond], dim=-1) + elif self.cond_dim > 0: + cond = torch.ones( + (batch, n_bands, n_time, self.cond_dim), + device=q.device, + dtype=q.dtype, + ) + q = torch.cat([q, cond], dim=-1) + else: + pass + + mask_list = self.compute_masks( + q + ) # [n_bands * (batch, in_channel, bandwidth, n_time)] + + masks = torch.zeros( + (batch, self.in_channel, self.n_freq, n_time), + device=q.device, + dtype=mask_list[0].dtype, + ) + + for im, mask in enumerate(mask_list): + fstart, fend = self.band_specs[im] + if self.use_freq_weights: + fw = self.get_buffer(f"freq_weights/{im}")[:, None] + mask = mask * fw + masks[:, :, fstart:fend, :] += mask + + return masks + + +class MaskEstimationModule(OverlappingMaskEstimationModule): + def __init__( + self, + band_specs: List[Tuple[float, float]], + emb_dim: int, + mlp_dim: int, + in_channel: Optional[int], + hidden_activation: str = "Tanh", + hidden_activation_kwargs: Dict = None, + complex_mask: bool = True, + **kwargs, + ) -> None: + check_nonzero_bandwidth(band_specs) + check_no_gap(band_specs) + check_no_overlap(band_specs) + super().__init__( + in_channel=in_channel, + band_specs=band_specs, + freq_weights=None, + n_freq=None, + emb_dim=emb_dim, + mlp_dim=mlp_dim, + hidden_activation=hidden_activation, + hidden_activation_kwargs=hidden_activation_kwargs, + complex_mask=complex_mask, + ) + + def forward(self, q, cond=None): + # q = (batch, n_bands, n_time, emb_dim) + + masks = self.compute_masks( + q + ) # [n_bands * (batch, in_channel, bandwidth, n_time)] + + # TODO: currently this requires band specs to have no gap and no overlap + masks = torch.concat( + masks, + dim=2 + ) # (batch, in_channel, n_freq, n_time) + + return masks diff --git a/models/bandit/core/model/bsrnn/tfmodel.py b/models/bandit/core/model/bsrnn/tfmodel.py new file mode 100644 index 0000000000000000000000000000000000000000..2cbbcb397db2fc86d6e3330db5e258a40aa0517b --- /dev/null +++ 
b/models/bandit/core/model/bsrnn/tfmodel.py @@ -0,0 +1,317 @@ +import warnings + +import torch +from torch import nn +from torch.nn import functional as F +from torch.nn.modules import rnn + +import torch.backends.cuda + + +class TimeFrequencyModellingModule(nn.Module): + def __init__(self) -> None: + super().__init__() + + +class ResidualRNN(nn.Module): + def __init__( + self, + emb_dim: int, + rnn_dim: int, + bidirectional: bool = True, + rnn_type: str = "LSTM", + use_batch_trick: bool = True, + use_layer_norm: bool = True, + ) -> None: + # n_group is the size of the 2nd dim + super().__init__() + + self.use_layer_norm = use_layer_norm + if use_layer_norm: + self.norm = nn.LayerNorm(emb_dim) + else: + self.norm = nn.GroupNorm(num_groups=emb_dim, num_channels=emb_dim) + + self.rnn = rnn.__dict__[rnn_type]( + input_size=emb_dim, + hidden_size=rnn_dim, + num_layers=1, + batch_first=True, + bidirectional=bidirectional, + ) + + self.fc = nn.Linear( + in_features=rnn_dim * (2 if bidirectional else 1), + out_features=emb_dim + ) + + self.use_batch_trick = use_batch_trick + if not self.use_batch_trick: + warnings.warn("NOT USING BATCH TRICK IS EXTREMELY SLOW!!") + + def forward(self, z): + # z = (batch, n_uncrossed, n_across, emb_dim) + + z0 = torch.clone(z) + + # print(z.device) + + if self.use_layer_norm: + z = self.norm(z) # (batch, n_uncrossed, n_across, emb_dim) + else: + z = torch.permute( + z, (0, 3, 1, 2) + ) # (batch, emb_dim, n_uncrossed, n_across) + + z = self.norm(z) # (batch, emb_dim, n_uncrossed, n_across) + + z = torch.permute( + z, (0, 2, 3, 1) + ) # (batch, n_uncrossed, n_across, emb_dim) + + batch, n_uncrossed, n_across, emb_dim = z.shape + + if self.use_batch_trick: + z = torch.reshape(z, (batch * n_uncrossed, n_across, emb_dim)) + + z = self.rnn(z.contiguous())[0] # (batch * n_uncrossed, n_across, dir_rnn_dim) + + z = torch.reshape(z, (batch, n_uncrossed, n_across, -1)) + # (batch, n_uncrossed, n_across, dir_rnn_dim) + else: + # Note: this is EXTREMELY SLOW + zlist = [] + for i in range(n_uncrossed): + zi = self.rnn(z[:, i, :, :])[0] # (batch, n_across, emb_dim) + zlist.append(zi) + + z = torch.stack( + zlist, + dim=1 + ) # (batch, n_uncrossed, n_across, dir_rnn_dim) + + z = self.fc(z) # (batch, n_uncrossed, n_across, emb_dim) + + z = z + z0 + + return z + + +class SeqBandModellingModule(TimeFrequencyModellingModule): + def __init__( + self, + n_modules: int = 12, + emb_dim: int = 128, + rnn_dim: int = 256, + bidirectional: bool = True, + rnn_type: str = "LSTM", + parallel_mode=False, + ) -> None: + super().__init__() + self.seqband = nn.ModuleList([]) + + if parallel_mode: + for _ in range(n_modules): + self.seqband.append( + nn.ModuleList( + [ResidualRNN( + emb_dim=emb_dim, + rnn_dim=rnn_dim, + bidirectional=bidirectional, + rnn_type=rnn_type, + ), + ResidualRNN( + emb_dim=emb_dim, + rnn_dim=rnn_dim, + bidirectional=bidirectional, + rnn_type=rnn_type, + )] + ) + ) + else: + + for _ in range(2 * n_modules): + self.seqband.append( + ResidualRNN( + emb_dim=emb_dim, + rnn_dim=rnn_dim, + bidirectional=bidirectional, + rnn_type=rnn_type, + ) + ) + + self.parallel_mode = parallel_mode + + def forward(self, z): + # z = (batch, n_bands, n_time, emb_dim) + + if self.parallel_mode: + for sbm_pair in self.seqband: + # z: (batch, n_bands, n_time, emb_dim) + sbm_t, sbm_f = sbm_pair[0], sbm_pair[1] + zt = sbm_t(z) # (batch, n_bands, n_time, emb_dim) + zf = sbm_f(z.transpose(1, 2)) # (batch, n_time, n_bands, emb_dim) + z = zt + zf.transpose(1, 2) + else: + for sbm in self.seqband: + z = 
sbm(z) + z = z.transpose(1, 2) + + # (batch, n_bands, n_time, emb_dim) + # --> (batch, n_time, n_bands, emb_dim) + # OR + # (batch, n_time, n_bands, emb_dim) + # --> (batch, n_bands, n_time, emb_dim) + + q = z + return q # (batch, n_bands, n_time, emb_dim) + + +class ResidualTransformer(nn.Module): + def __init__( + self, + emb_dim: int = 128, + rnn_dim: int = 256, + bidirectional: bool = True, + dropout: float = 0.0, + ) -> None: + # n_group is the size of the 2nd dim + super().__init__() + + self.tf = nn.TransformerEncoderLayer( + d_model=emb_dim, + nhead=4, + dim_feedforward=rnn_dim, + batch_first=True + ) + + self.is_causal = not bidirectional + self.dropout = dropout + + def forward(self, z): + batch, n_uncrossed, n_across, emb_dim = z.shape + z = torch.reshape(z, (batch * n_uncrossed, n_across, emb_dim)) + z = self.tf(z, is_causal=self.is_causal) # (batch, n_uncrossed, n_across, emb_dim) + z = torch.reshape(z, (batch, n_uncrossed, n_across, emb_dim)) + + return z + + +class TransformerTimeFreqModule(TimeFrequencyModellingModule): + def __init__( + self, + n_modules: int = 12, + emb_dim: int = 128, + rnn_dim: int = 256, + bidirectional: bool = True, + dropout: float = 0.0, + ) -> None: + super().__init__() + self.norm = nn.LayerNorm(emb_dim) + self.seqband = nn.ModuleList([]) + + for _ in range(2 * n_modules): + self.seqband.append( + ResidualTransformer( + emb_dim=emb_dim, + rnn_dim=rnn_dim, + bidirectional=bidirectional, + dropout=dropout, + ) + ) + + def forward(self, z): + # z = (batch, n_bands, n_time, emb_dim) + z = self.norm(z) # (batch, n_bands, n_time, emb_dim) + + for sbm in self.seqband: + z = sbm(z) + z = z.transpose(1, 2) + + # (batch, n_bands, n_time, emb_dim) + # --> (batch, n_time, n_bands, emb_dim) + # OR + # (batch, n_time, n_bands, emb_dim) + # --> (batch, n_bands, n_time, emb_dim) + + q = z + return q # (batch, n_bands, n_time, emb_dim) + + + +class ResidualConvolution(nn.Module): + def __init__( + self, + emb_dim: int = 128, + rnn_dim: int = 256, + bidirectional: bool = True, + dropout: float = 0.0, + ) -> None: + # n_group is the size of the 2nd dim + super().__init__() + self.norm = nn.InstanceNorm2d(emb_dim, affine=True) + + self.conv = nn.Sequential( + nn.Conv2d( + in_channels=emb_dim, + out_channels=rnn_dim, + kernel_size=(3, 3), + padding="same", + stride=(1, 1), + ), + nn.Tanhshrink() + ) + + self.is_causal = not bidirectional + self.dropout = dropout + + self.fc = nn.Conv2d( + in_channels=rnn_dim, + out_channels=emb_dim, + kernel_size=(1, 1), + padding="same", + stride=(1, 1), + ) + + + def forward(self, z): + # z = (batch, n_uncrossed, n_across, emb_dim) + + z0 = torch.clone(z) + + z = self.norm(z) # (batch, n_uncrossed, n_across, emb_dim) + z = self.conv(z) # (batch, n_uncrossed, n_across, emb_dim) + z = self.fc(z) # (batch, n_uncrossed, n_across, emb_dim) + z = z + z0 + + return z + + +class ConvolutionalTimeFreqModule(TimeFrequencyModellingModule): + def __init__( + self, + n_modules: int = 12, + emb_dim: int = 128, + rnn_dim: int = 256, + bidirectional: bool = True, + dropout: float = 0.0, + ) -> None: + super().__init__() + self.seqband = torch.jit.script(nn.Sequential( + *[ResidualConvolution( + emb_dim=emb_dim, + rnn_dim=rnn_dim, + bidirectional=bidirectional, + dropout=dropout, + ) for _ in range(2 * n_modules) ])) + + def forward(self, z): + # z = (batch, n_bands, n_time, emb_dim) + + z = torch.permute(z, (0, 3, 1, 2)) # (batch, emb_dim, n_bands, n_time) + + z = self.seqband(z) # (batch, emb_dim, n_bands, n_time) + + z = torch.permute(z, (0, 2, 
3, 1)) # (batch, n_bands, n_time, emb_dim) + + return z diff --git a/models/bandit/core/model/bsrnn/utils.py b/models/bandit/core/model/bsrnn/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..505ca987db10274f4ff1cae568cf551ecdf96a9f --- /dev/null +++ b/models/bandit/core/model/bsrnn/utils.py @@ -0,0 +1,583 @@ +import os +from abc import abstractmethod +from typing import Any, Callable + +import numpy as np +import torch +from librosa import hz_to_midi, midi_to_hz +from torch import Tensor +from torchaudio import functional as taF +from spafe.fbanks import bark_fbanks +from spafe.utils.converters import erb2hz, hz2bark, hz2erb +from torchaudio.functional.functional import _create_triangular_filterbank + + +def band_widths_from_specs(band_specs): + return [e - i for i, e in band_specs] + + +def check_nonzero_bandwidth(band_specs): + # pprint(band_specs) + for fstart, fend in band_specs: + if fend - fstart <= 0: + raise ValueError("Bands cannot be zero-width") + + +def check_no_overlap(band_specs): + fend_prev = -1 + for fstart_curr, fend_curr in band_specs: + if fstart_curr <= fend_prev: + raise ValueError("Bands cannot overlap") + + +def check_no_gap(band_specs): + fstart, _ = band_specs[0] + assert fstart == 0 + + fend_prev = -1 + for fstart_curr, fend_curr in band_specs: + if fstart_curr - fend_prev > 1: + raise ValueError("Bands cannot leave gap") + fend_prev = fend_curr + + +class BandsplitSpecification: + def __init__(self, nfft: int, fs: int) -> None: + self.fs = fs + self.nfft = nfft + self.nyquist = fs / 2 + self.max_index = nfft // 2 + 1 + + self.split500 = self.hertz_to_index(500) + self.split1k = self.hertz_to_index(1000) + self.split2k = self.hertz_to_index(2000) + self.split4k = self.hertz_to_index(4000) + self.split8k = self.hertz_to_index(8000) + self.split16k = self.hertz_to_index(16000) + self.split20k = self.hertz_to_index(20000) + + self.above20k = [(self.split20k, self.max_index)] + self.above16k = [(self.split16k, self.split20k)] + self.above20k + + def index_to_hertz(self, index: int): + return index * self.fs / self.nfft + + def hertz_to_index(self, hz: float, round: bool = True): + index = hz * self.nfft / self.fs + + if round: + index = int(np.round(index)) + + return index + + def get_band_specs_with_bandwidth( + self, + start_index, + end_index, + bandwidth_hz + ): + band_specs = [] + lower = start_index + + while lower < end_index: + upper = int(np.floor(lower + self.hertz_to_index(bandwidth_hz))) + upper = min(upper, end_index) + + band_specs.append((lower, upper)) + lower = upper + + return band_specs + + @abstractmethod + def get_band_specs(self): + raise NotImplementedError + + +class VocalBandsplitSpecification(BandsplitSpecification): + def __init__(self, nfft: int, fs: int, version: str = "7") -> None: + super().__init__(nfft=nfft, fs=fs) + + self.version = version + + def get_band_specs(self): + return getattr(self, f"version{self.version}")() + + @property + def version1(self): + return self.get_band_specs_with_bandwidth( + start_index=0, end_index=self.max_index, bandwidth_hz=1000 + ) + + def version2(self): + below16k = self.get_band_specs_with_bandwidth( + start_index=0, end_index=self.split16k, bandwidth_hz=1000 + ) + below20k = self.get_band_specs_with_bandwidth( + start_index=self.split16k, + end_index=self.split20k, + bandwidth_hz=2000 + ) + + return below16k + below20k + self.above20k + + def version3(self): + below8k = self.get_band_specs_with_bandwidth( + start_index=0, end_index=self.split8k, bandwidth_hz=1000 + ) + 
below16k = self.get_band_specs_with_bandwidth( + start_index=self.split8k, + end_index=self.split16k, + bandwidth_hz=2000 + ) + + return below8k + below16k + self.above16k + + def version4(self): + below1k = self.get_band_specs_with_bandwidth( + start_index=0, end_index=self.split1k, bandwidth_hz=100 + ) + below8k = self.get_band_specs_with_bandwidth( + start_index=self.split1k, + end_index=self.split8k, + bandwidth_hz=1000 + ) + below16k = self.get_band_specs_with_bandwidth( + start_index=self.split8k, + end_index=self.split16k, + bandwidth_hz=2000 + ) + + return below1k + below8k + below16k + self.above16k + + def version5(self): + below1k = self.get_band_specs_with_bandwidth( + start_index=0, end_index=self.split1k, bandwidth_hz=100 + ) + below16k = self.get_band_specs_with_bandwidth( + start_index=self.split1k, + end_index=self.split16k, + bandwidth_hz=1000 + ) + below20k = self.get_band_specs_with_bandwidth( + start_index=self.split16k, + end_index=self.split20k, + bandwidth_hz=2000 + ) + return below1k + below16k + below20k + self.above20k + + def version6(self): + below1k = self.get_band_specs_with_bandwidth( + start_index=0, end_index=self.split1k, bandwidth_hz=100 + ) + below4k = self.get_band_specs_with_bandwidth( + start_index=self.split1k, + end_index=self.split4k, + bandwidth_hz=500 + ) + below8k = self.get_band_specs_with_bandwidth( + start_index=self.split4k, + end_index=self.split8k, + bandwidth_hz=1000 + ) + below16k = self.get_band_specs_with_bandwidth( + start_index=self.split8k, + end_index=self.split16k, + bandwidth_hz=2000 + ) + return below1k + below4k + below8k + below16k + self.above16k + + def version7(self): + below1k = self.get_band_specs_with_bandwidth( + start_index=0, end_index=self.split1k, bandwidth_hz=100 + ) + below4k = self.get_band_specs_with_bandwidth( + start_index=self.split1k, + end_index=self.split4k, + bandwidth_hz=250 + ) + below8k = self.get_band_specs_with_bandwidth( + start_index=self.split4k, + end_index=self.split8k, + bandwidth_hz=500 + ) + below16k = self.get_band_specs_with_bandwidth( + start_index=self.split8k, + end_index=self.split16k, + bandwidth_hz=1000 + ) + below20k = self.get_band_specs_with_bandwidth( + start_index=self.split16k, + end_index=self.split20k, + bandwidth_hz=2000 + ) + return below1k + below4k + below8k + below16k + below20k + self.above20k + + +class OtherBandsplitSpecification(VocalBandsplitSpecification): + def __init__(self, nfft: int, fs: int) -> None: + super().__init__(nfft=nfft, fs=fs, version="7") + + +class BassBandsplitSpecification(BandsplitSpecification): + def __init__(self, nfft: int, fs: int, version: str = "7") -> None: + super().__init__(nfft=nfft, fs=fs) + + def get_band_specs(self): + below500 = self.get_band_specs_with_bandwidth( + start_index=0, end_index=self.split500, bandwidth_hz=50 + ) + below1k = self.get_band_specs_with_bandwidth( + start_index=self.split500, + end_index=self.split1k, + bandwidth_hz=100 + ) + below4k = self.get_band_specs_with_bandwidth( + start_index=self.split1k, + end_index=self.split4k, + bandwidth_hz=500 + ) + below8k = self.get_band_specs_with_bandwidth( + start_index=self.split4k, + end_index=self.split8k, + bandwidth_hz=1000 + ) + below16k = self.get_band_specs_with_bandwidth( + start_index=self.split8k, + end_index=self.split16k, + bandwidth_hz=2000 + ) + above16k = [(self.split16k, self.max_index)] + + return below500 + below1k + below4k + below8k + below16k + above16k + + +class DrumBandsplitSpecification(BandsplitSpecification): + def __init__(self, nfft: int, 
fs: int) -> None: + super().__init__(nfft=nfft, fs=fs) + + def get_band_specs(self): + below1k = self.get_band_specs_with_bandwidth( + start_index=0, end_index=self.split1k, bandwidth_hz=50 + ) + below2k = self.get_band_specs_with_bandwidth( + start_index=self.split1k, + end_index=self.split2k, + bandwidth_hz=100 + ) + below4k = self.get_band_specs_with_bandwidth( + start_index=self.split2k, + end_index=self.split4k, + bandwidth_hz=250 + ) + below8k = self.get_band_specs_with_bandwidth( + start_index=self.split4k, + end_index=self.split8k, + bandwidth_hz=500 + ) + below16k = self.get_band_specs_with_bandwidth( + start_index=self.split8k, + end_index=self.split16k, + bandwidth_hz=1000 + ) + above16k = [(self.split16k, self.max_index)] + + return below1k + below2k + below4k + below8k + below16k + above16k + + + + +class PerceptualBandsplitSpecification(BandsplitSpecification): + def __init__( + self, + nfft: int, + fs: int, + fbank_fn: Callable[[int, int, float, float, int], torch.Tensor], + n_bands: int, + f_min: float = 0.0, + f_max: float = None + ) -> None: + super().__init__(nfft=nfft, fs=fs) + self.n_bands = n_bands + if f_max is None: + f_max = fs / 2 + + self.filterbank = fbank_fn( + n_bands, fs, f_min, f_max, self.max_index + ) + + weight_per_bin = torch.sum( + self.filterbank, + dim=0, + keepdim=True + ) # (1, n_freqs) + normalized_mel_fb = self.filterbank / weight_per_bin # (n_mels, n_freqs) + + freq_weights = [] + band_specs = [] + for i in range(self.n_bands): + active_bins = torch.nonzero(self.filterbank[i, :]).squeeze().tolist() + if isinstance(active_bins, int): + active_bins = (active_bins, active_bins) + if len(active_bins) == 0: + continue + start_index = active_bins[0] + end_index = active_bins[-1] + 1 + band_specs.append((start_index, end_index)) + freq_weights.append(normalized_mel_fb[i, start_index:end_index]) + + self.freq_weights = freq_weights + self.band_specs = band_specs + + def get_band_specs(self): + return self.band_specs + + def get_freq_weights(self): + return self.freq_weights + + def save_to_file(self, dir_path: str) -> None: + + os.makedirs(dir_path, exist_ok=True) + + import pickle + + with open(os.path.join(dir_path, "mel_bandsplit_spec.pkl"), "wb") as f: + pickle.dump( + { + "band_specs": self.band_specs, + "freq_weights": self.freq_weights, + "filterbank": self.filterbank, + }, + f, + ) + +def mel_filterbank(n_bands, fs, f_min, f_max, n_freqs): + fb = taF.melscale_fbanks( + n_mels=n_bands, + sample_rate=fs, + f_min=f_min, + f_max=f_max, + n_freqs=n_freqs, + ).T + + fb[0, 0] = 1.0 + + return fb + + +class MelBandsplitSpecification(PerceptualBandsplitSpecification): + def __init__( + self, + nfft: int, + fs: int, + n_bands: int, + f_min: float = 0.0, + f_max: float = None + ) -> None: + super().__init__(fbank_fn=mel_filterbank, nfft=nfft, fs=fs, n_bands=n_bands, f_min=f_min, f_max=f_max) + +def musical_filterbank(n_bands, fs, f_min, f_max, n_freqs, + scale="constant"): + + nfft = 2 * (n_freqs - 1) + df = fs / nfft + # init freqs + f_max = f_max or fs / 2 + f_min = f_min or 0 + f_min = fs / nfft + + n_octaves = np.log2(f_max / f_min) + n_octaves_per_band = n_octaves / n_bands + bandwidth_mult = np.power(2.0, n_octaves_per_band) + + low_midi = max(0, hz_to_midi(f_min)) + high_midi = hz_to_midi(f_max) + midi_points = np.linspace(low_midi, high_midi, n_bands) + hz_pts = midi_to_hz(midi_points) + + low_pts = hz_pts / bandwidth_mult + high_pts = hz_pts * bandwidth_mult + + low_bins = np.floor(low_pts / df).astype(int) + high_bins = np.ceil(high_pts / 
df).astype(int) + + fb = np.zeros((n_bands, n_freqs)) + + for i in range(n_bands): + fb[i, low_bins[i]:high_bins[i]+1] = 1.0 + + fb[0, :low_bins[0]] = 1.0 + fb[-1, high_bins[-1]+1:] = 1.0 + + return torch.as_tensor(fb) + +class MusicalBandsplitSpecification(PerceptualBandsplitSpecification): + def __init__( + self, + nfft: int, + fs: int, + n_bands: int, + f_min: float = 0.0, + f_max: float = None + ) -> None: + super().__init__(fbank_fn=musical_filterbank, nfft=nfft, fs=fs, n_bands=n_bands, f_min=f_min, f_max=f_max) + + +def bark_filterbank( + n_bands, fs, f_min, f_max, n_freqs +): + nfft = 2 * (n_freqs -1) + fb, _ = bark_fbanks.bark_filter_banks( + nfilts=n_bands, + nfft=nfft, + fs=fs, + low_freq=f_min, + high_freq=f_max, + scale="constant" + ) + + return torch.as_tensor(fb) + +class BarkBandsplitSpecification(PerceptualBandsplitSpecification): + def __init__( + self, + nfft: int, + fs: int, + n_bands: int, + f_min: float = 0.0, + f_max: float = None + ) -> None: + super().__init__(fbank_fn=bark_filterbank, nfft=nfft, fs=fs, n_bands=n_bands, f_min=f_min, f_max=f_max) + + +def triangular_bark_filterbank( + n_bands, fs, f_min, f_max, n_freqs +): + + all_freqs = torch.linspace(0, fs // 2, n_freqs) + + # calculate mel freq bins + m_min = hz2bark(f_min) + m_max = hz2bark(f_max) + + m_pts = torch.linspace(m_min, m_max, n_bands + 2) + f_pts = 600 * torch.sinh(m_pts / 6) + + # create filterbank + fb = _create_triangular_filterbank(all_freqs, f_pts) + + fb = fb.T + + first_active_band = torch.nonzero(torch.sum(fb, dim=-1))[0, 0] + first_active_bin = torch.nonzero(fb[first_active_band, :])[0, 0] + + fb[first_active_band, :first_active_bin] = 1.0 + + return fb + +class TriangularBarkBandsplitSpecification(PerceptualBandsplitSpecification): + def __init__( + self, + nfft: int, + fs: int, + n_bands: int, + f_min: float = 0.0, + f_max: float = None + ) -> None: + super().__init__(fbank_fn=triangular_bark_filterbank, nfft=nfft, fs=fs, n_bands=n_bands, f_min=f_min, f_max=f_max) + + + +def minibark_filterbank( + n_bands, fs, f_min, f_max, n_freqs +): + fb = bark_filterbank( + n_bands, + fs, + f_min, + f_max, + n_freqs + ) + + fb[fb < np.sqrt(0.5)] = 0.0 + + return fb + +class MiniBarkBandsplitSpecification(PerceptualBandsplitSpecification): + def __init__( + self, + nfft: int, + fs: int, + n_bands: int, + f_min: float = 0.0, + f_max: float = None + ) -> None: + super().__init__(fbank_fn=minibark_filterbank, nfft=nfft, fs=fs, n_bands=n_bands, f_min=f_min, f_max=f_max) + + + + + +def erb_filterbank( + n_bands: int, + fs: int, + f_min: float, + f_max: float, + n_freqs: int, +) -> Tensor: + # freq bins + A = (1000 * np.log(10)) / (24.7 * 4.37) + all_freqs = torch.linspace(0, fs // 2, n_freqs) + + # calculate mel freq bins + m_min = hz2erb(f_min) + m_max = hz2erb(f_max) + + m_pts = torch.linspace(m_min, m_max, n_bands + 2) + f_pts = (torch.pow(10, (m_pts / A)) - 1)/ 0.00437 + + # create filterbank + fb = _create_triangular_filterbank(all_freqs, f_pts) + + fb = fb.T + + + first_active_band = torch.nonzero(torch.sum(fb, dim=-1))[0, 0] + first_active_bin = torch.nonzero(fb[first_active_band, :])[0, 0] + + fb[first_active_band, :first_active_bin] = 1.0 + + return fb + + + +class EquivalentRectangularBandsplitSpecification(PerceptualBandsplitSpecification): + def __init__( + self, + nfft: int, + fs: int, + n_bands: int, + f_min: float = 0.0, + f_max: float = None + ) -> None: + super().__init__(fbank_fn=erb_filterbank, nfft=nfft, fs=fs, n_bands=n_bands, f_min=f_min, f_max=f_max) + +if __name__ == "__main__": + 
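+    # Smoke test: write the vox7 vocal band layout to CSV for inspection. Note that the
+    # "f_min"/"f_max" columns hold STFT bin indices returned by get_band_specs(), not
+    # frequencies in Hz.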
import pandas as pd + + band_defs = [] + + for bands in [VocalBandsplitSpecification]: + band_name = bands.__name__.replace("BandsplitSpecification", "") + + mbs = bands(nfft=2048, fs=44100).get_band_specs() + + for i, (f_min, f_max) in enumerate(mbs): + band_defs.append({ + "band": band_name, + "band_index": i, + "f_min": f_min, + "f_max": f_max + }) + + df = pd.DataFrame(band_defs) + df.to_csv("vox7bands.csv", index=False) \ No newline at end of file diff --git a/models/bandit/core/model/bsrnn/wrapper.py b/models/bandit/core/model/bsrnn/wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..d79d573502ec0a36e6ee5001b6e3d0e9a2a83cf8 --- /dev/null +++ b/models/bandit/core/model/bsrnn/wrapper.py @@ -0,0 +1,882 @@ +from pprint import pprint +from typing import Dict, List, Optional, Tuple, Union + +import torch +from torch import nn + +from models.bandit.core.model._spectral import _SpectralComponent +from models.bandit.core.model.bsrnn.utils import ( + BarkBandsplitSpecification, BassBandsplitSpecification, + DrumBandsplitSpecification, + EquivalentRectangularBandsplitSpecification, MelBandsplitSpecification, + MusicalBandsplitSpecification, OtherBandsplitSpecification, + TriangularBarkBandsplitSpecification, VocalBandsplitSpecification, +) +from .core import ( + MultiSourceMultiMaskBandSplitCoreConv, + MultiSourceMultiMaskBandSplitCoreRNN, + MultiSourceMultiMaskBandSplitCoreTransformer, + MultiSourceMultiPatchingMaskBandSplitCoreRNN, SingleMaskBandsplitCoreRNN, + SingleMaskBandsplitCoreTransformer, +) + +import pytorch_lightning as pl + +def get_band_specs(band_specs, n_fft, fs, n_bands=None): + if band_specs in ["dnr:speech", "dnr:vox7", "musdb:vocals", "musdb:vox7"]: + bsm = VocalBandsplitSpecification( + nfft=n_fft, fs=fs + ).get_band_specs() + freq_weights = None + overlapping_band = False + elif "tribark" in band_specs: + assert n_bands is not None + specs = TriangularBarkBandsplitSpecification( + nfft=n_fft, + fs=fs, + n_bands=n_bands + ) + bsm = specs.get_band_specs() + freq_weights = specs.get_freq_weights() + overlapping_band = True + elif "bark" in band_specs: + assert n_bands is not None + specs = BarkBandsplitSpecification( + nfft=n_fft, + fs=fs, + n_bands=n_bands + ) + bsm = specs.get_band_specs() + freq_weights = specs.get_freq_weights() + overlapping_band = True + elif "erb" in band_specs: + assert n_bands is not None + specs = EquivalentRectangularBandsplitSpecification( + nfft=n_fft, + fs=fs, + n_bands=n_bands + ) + bsm = specs.get_band_specs() + freq_weights = specs.get_freq_weights() + overlapping_band = True + elif "musical" in band_specs: + assert n_bands is not None + specs = MusicalBandsplitSpecification( + nfft=n_fft, + fs=fs, + n_bands=n_bands + ) + bsm = specs.get_band_specs() + freq_weights = specs.get_freq_weights() + overlapping_band = True + elif band_specs == "dnr:mel" or "mel" in band_specs: + assert n_bands is not None + specs = MelBandsplitSpecification( + nfft=n_fft, + fs=fs, + n_bands=n_bands + ) + bsm = specs.get_band_specs() + freq_weights = specs.get_freq_weights() + overlapping_band = True + else: + raise NameError + + return bsm, freq_weights, overlapping_band + + +def get_band_specs_map(band_specs_map, n_fft, fs, n_bands=None): + if band_specs_map == "musdb:all": + bsm = { + "vocals": VocalBandsplitSpecification( + nfft=n_fft, fs=fs + ).get_band_specs(), + "drums": DrumBandsplitSpecification( + nfft=n_fft, fs=fs + ).get_band_specs(), + "bass": BassBandsplitSpecification( + nfft=n_fft, fs=fs + ).get_band_specs(), + "other": 
OtherBandsplitSpecification( + nfft=n_fft, fs=fs + ).get_band_specs(), + } + freq_weights = None + overlapping_band = False + elif band_specs_map == "dnr:vox7": + bsm_, freq_weights, overlapping_band = get_band_specs( + "dnr:speech", n_fft, fs, n_bands + ) + bsm = { + "speech": bsm_, + "music": bsm_, + "effects": bsm_ + } + elif "dnr:vox7:" in band_specs_map: + stem = band_specs_map.split(":")[-1] + bsm_, freq_weights, overlapping_band = get_band_specs( + "dnr:speech", n_fft, fs, n_bands + ) + bsm = { + stem: bsm_ + } + else: + raise NameError + + return bsm, freq_weights, overlapping_band + + +class BandSplitWrapperBase(pl.LightningModule): + bsrnn: nn.Module + + def __init__(self, **kwargs): + super().__init__() + + +class SingleMaskMultiSourceBandSplitBase( + BandSplitWrapperBase, + _SpectralComponent +): + def __init__( + self, + band_specs_map: Union[str, Dict[str, List[Tuple[float, float]]]], + fs: int = 44100, + n_fft: int = 2048, + win_length: Optional[int] = 2048, + hop_length: int = 512, + window_fn: str = "hann_window", + wkwargs: Optional[Dict] = None, + power: Optional[int] = None, + center: bool = True, + normalized: bool = True, + pad_mode: str = "constant", + onesided: bool = True, + n_bands: int = None, + ) -> None: + super().__init__( + n_fft=n_fft, + win_length=win_length, + hop_length=hop_length, + window_fn=window_fn, + wkwargs=wkwargs, + power=power, + center=center, + normalized=normalized, + pad_mode=pad_mode, + onesided=onesided, + ) + + if isinstance(band_specs_map, str): + self.band_specs_map, self.freq_weights, self.overlapping_band = get_band_specs_map( + band_specs_map, + n_fft, + fs, + n_bands=n_bands + ) + + self.stems = list(self.band_specs_map.keys()) + + def forward(self, batch): + audio = batch["audio"] + + with torch.no_grad(): + batch["spectrogram"] = {stem: self.stft(audio[stem]) for stem in + audio} + + X = batch["spectrogram"]["mixture"] + length = batch["audio"]["mixture"].shape[-1] + + output = {"spectrogram": {}, "audio": {}} + + for stem, bsrnn in self.bsrnn.items(): + S = bsrnn(X) + s = self.istft(S, length) + output["spectrogram"][stem] = S + output["audio"][stem] = s + + return batch, output + + +class MultiMaskMultiSourceBandSplitBase( + BandSplitWrapperBase, + _SpectralComponent +): + def __init__( + self, + stems: List[str], + band_specs: Union[str, List[Tuple[float, float]]], + fs: int = 44100, + n_fft: int = 2048, + win_length: Optional[int] = 2048, + hop_length: int = 512, + window_fn: str = "hann_window", + wkwargs: Optional[Dict] = None, + power: Optional[int] = None, + center: bool = True, + normalized: bool = True, + pad_mode: str = "constant", + onesided: bool = True, + n_bands: int = None, + ) -> None: + super().__init__( + n_fft=n_fft, + win_length=win_length, + hop_length=hop_length, + window_fn=window_fn, + wkwargs=wkwargs, + power=power, + center=center, + normalized=normalized, + pad_mode=pad_mode, + onesided=onesided, + ) + + if isinstance(band_specs, str): + self.band_specs, self.freq_weights, self.overlapping_band = get_band_specs( + band_specs, + n_fft, + fs, + n_bands + ) + + self.stems = stems + + def forward(self, batch): + # with torch.no_grad(): + audio = batch["audio"] + cond = batch.get("condition", None) + with torch.no_grad(): + batch["spectrogram"] = {stem: self.stft(audio[stem]) for stem in + audio} + + X = batch["spectrogram"]["mixture"] + length = batch["audio"]["mixture"].shape[-1] + + output = self.bsrnn(X, cond=cond) + output["audio"] = {} + + for stem, S in output["spectrogram"].items(): + s = 
self.istft(S, length) + output["audio"][stem] = s + + return batch, output + + +class MultiMaskMultiSourceBandSplitBaseSimple( + BandSplitWrapperBase, + _SpectralComponent +): + def __init__( + self, + stems: List[str], + band_specs: Union[str, List[Tuple[float, float]]], + fs: int = 44100, + n_fft: int = 2048, + win_length: Optional[int] = 2048, + hop_length: int = 512, + window_fn: str = "hann_window", + wkwargs: Optional[Dict] = None, + power: Optional[int] = None, + center: bool = True, + normalized: bool = True, + pad_mode: str = "constant", + onesided: bool = True, + n_bands: int = None, + ) -> None: + super().__init__( + n_fft=n_fft, + win_length=win_length, + hop_length=hop_length, + window_fn=window_fn, + wkwargs=wkwargs, + power=power, + center=center, + normalized=normalized, + pad_mode=pad_mode, + onesided=onesided, + ) + + if isinstance(band_specs, str): + self.band_specs, self.freq_weights, self.overlapping_band = get_band_specs( + band_specs, + n_fft, + fs, + n_bands + ) + + self.stems = stems + + def forward(self, batch): + with torch.no_grad(): + X = self.stft(batch) + length = batch.shape[-1] + output = self.bsrnn(X, cond=None) + res = [] + for stem, S in output["spectrogram"].items(): + s = self.istft(S, length) + res.append(s) + res = torch.stack(res, dim=1) + return res + + +class SingleMaskMultiSourceBandSplitRNN(SingleMaskMultiSourceBandSplitBase): + def __init__( + self, + in_channel: int, + band_specs_map: Union[str, Dict[str, List[Tuple[float, float]]]], + fs: int = 44100, + require_no_overlap: bool = False, + require_no_gap: bool = True, + normalize_channel_independently: bool = False, + treat_channel_as_feature: bool = True, + n_sqm_modules: int = 12, + emb_dim: int = 128, + rnn_dim: int = 256, + bidirectional: bool = True, + rnn_type: str = "LSTM", + mlp_dim: int = 512, + hidden_activation: str = "Tanh", + hidden_activation_kwargs: Optional[Dict] = None, + complex_mask: bool = True, + n_fft: int = 2048, + win_length: Optional[int] = 2048, + hop_length: int = 512, + window_fn: str = "hann_window", + wkwargs: Optional[Dict] = None, + power: Optional[int] = None, + center: bool = True, + normalized: bool = True, + pad_mode: str = "constant", + onesided: bool = True, + ) -> None: + super().__init__( + band_specs_map=band_specs_map, + fs=fs, + n_fft=n_fft, + win_length=win_length, + hop_length=hop_length, + window_fn=window_fn, + wkwargs=wkwargs, + power=power, + center=center, + normalized=normalized, + pad_mode=pad_mode, + onesided=onesided, + ) + + self.bsrnn = nn.ModuleDict( + { + src: SingleMaskBandsplitCoreRNN( + band_specs=specs, + in_channel=in_channel, + require_no_overlap=require_no_overlap, + require_no_gap=require_no_gap, + normalize_channel_independently=normalize_channel_independently, + treat_channel_as_feature=treat_channel_as_feature, + n_sqm_modules=n_sqm_modules, + emb_dim=emb_dim, + rnn_dim=rnn_dim, + bidirectional=bidirectional, + rnn_type=rnn_type, + mlp_dim=mlp_dim, + hidden_activation=hidden_activation, + hidden_activation_kwargs=hidden_activation_kwargs, + complex_mask=complex_mask, + ) + for src, specs in self.band_specs_map.items() + } + ) + + +class SingleMaskMultiSourceBandSplitTransformer( + SingleMaskMultiSourceBandSplitBase +): + def __init__( + self, + in_channel: int, + band_specs_map: Union[str, Dict[str, List[Tuple[float, float]]]], + fs: int = 44100, + require_no_overlap: bool = False, + require_no_gap: bool = True, + normalize_channel_independently: bool = False, + treat_channel_as_feature: bool = True, + n_sqm_modules: int = 
12, + emb_dim: int = 128, + rnn_dim: int = 256, + bidirectional: bool = True, + tf_dropout: float = 0.0, + mlp_dim: int = 512, + hidden_activation: str = "Tanh", + hidden_activation_kwargs: Optional[Dict] = None, + complex_mask: bool = True, + n_fft: int = 2048, + win_length: Optional[int] = 2048, + hop_length: int = 512, + window_fn: str = "hann_window", + wkwargs: Optional[Dict] = None, + power: Optional[int] = None, + center: bool = True, + normalized: bool = True, + pad_mode: str = "constant", + onesided: bool = True, + ) -> None: + super().__init__( + band_specs_map=band_specs_map, + fs=fs, + n_fft=n_fft, + win_length=win_length, + hop_length=hop_length, + window_fn=window_fn, + wkwargs=wkwargs, + power=power, + center=center, + normalized=normalized, + pad_mode=pad_mode, + onesided=onesided, + ) + + self.bsrnn = nn.ModuleDict( + { + src: SingleMaskBandsplitCoreTransformer( + band_specs=specs, + in_channel=in_channel, + require_no_overlap=require_no_overlap, + require_no_gap=require_no_gap, + normalize_channel_independently=normalize_channel_independently, + treat_channel_as_feature=treat_channel_as_feature, + n_sqm_modules=n_sqm_modules, + emb_dim=emb_dim, + rnn_dim=rnn_dim, + bidirectional=bidirectional, + tf_dropout=tf_dropout, + mlp_dim=mlp_dim, + hidden_activation=hidden_activation, + hidden_activation_kwargs=hidden_activation_kwargs, + complex_mask=complex_mask, + ) + for src, specs in self.band_specs_map.items() + } + ) + + +class MultiMaskMultiSourceBandSplitRNN(MultiMaskMultiSourceBandSplitBase): + def __init__( + self, + in_channel: int, + stems: List[str], + band_specs: Union[str, List[Tuple[float, float]]], + fs: int = 44100, + require_no_overlap: bool = False, + require_no_gap: bool = True, + normalize_channel_independently: bool = False, + treat_channel_as_feature: bool = True, + n_sqm_modules: int = 12, + emb_dim: int = 128, + rnn_dim: int = 256, + cond_dim: int = 0, + bidirectional: bool = True, + rnn_type: str = "LSTM", + mlp_dim: int = 512, + hidden_activation: str = "Tanh", + hidden_activation_kwargs: Optional[Dict] = None, + complex_mask: bool = True, + n_fft: int = 2048, + win_length: Optional[int] = 2048, + hop_length: int = 512, + window_fn: str = "hann_window", + wkwargs: Optional[Dict] = None, + power: Optional[int] = None, + center: bool = True, + normalized: bool = True, + pad_mode: str = "constant", + onesided: bool = True, + n_bands: int = None, + use_freq_weights: bool = True, + normalize_input: bool = False, + mult_add_mask: bool = False, + freeze_encoder: bool = False, + ) -> None: + super().__init__( + stems=stems, + band_specs=band_specs, + fs=fs, + n_fft=n_fft, + win_length=win_length, + hop_length=hop_length, + window_fn=window_fn, + wkwargs=wkwargs, + power=power, + center=center, + normalized=normalized, + pad_mode=pad_mode, + onesided=onesided, + n_bands=n_bands, + ) + + self.bsrnn = MultiSourceMultiMaskBandSplitCoreRNN( + stems=stems, + band_specs=self.band_specs, + in_channel=in_channel, + require_no_overlap=require_no_overlap, + require_no_gap=require_no_gap, + normalize_channel_independently=normalize_channel_independently, + treat_channel_as_feature=treat_channel_as_feature, + n_sqm_modules=n_sqm_modules, + emb_dim=emb_dim, + rnn_dim=rnn_dim, + bidirectional=bidirectional, + rnn_type=rnn_type, + mlp_dim=mlp_dim, + cond_dim=cond_dim, + hidden_activation=hidden_activation, + hidden_activation_kwargs=hidden_activation_kwargs, + complex_mask=complex_mask, + overlapping_band=self.overlapping_band, + freq_weights=self.freq_weights, + n_freq=n_fft 
// 2 + 1, + use_freq_weights=use_freq_weights, + mult_add_mask=mult_add_mask + ) + + self.normalize_input = normalize_input + self.cond_dim = cond_dim + + if freeze_encoder: + for param in self.bsrnn.band_split.parameters(): + param.requires_grad = False + + for param in self.bsrnn.tf_model.parameters(): + param.requires_grad = False + + +class MultiMaskMultiSourceBandSplitRNNSimple(MultiMaskMultiSourceBandSplitBaseSimple): + def __init__( + self, + in_channel: int, + stems: List[str], + band_specs: Union[str, List[Tuple[float, float]]], + fs: int = 44100, + require_no_overlap: bool = False, + require_no_gap: bool = True, + normalize_channel_independently: bool = False, + treat_channel_as_feature: bool = True, + n_sqm_modules: int = 12, + emb_dim: int = 128, + rnn_dim: int = 256, + cond_dim: int = 0, + bidirectional: bool = True, + rnn_type: str = "LSTM", + mlp_dim: int = 512, + hidden_activation: str = "Tanh", + hidden_activation_kwargs: Optional[Dict] = None, + complex_mask: bool = True, + n_fft: int = 2048, + win_length: Optional[int] = 2048, + hop_length: int = 512, + window_fn: str = "hann_window", + wkwargs: Optional[Dict] = None, + power: Optional[int] = None, + center: bool = True, + normalized: bool = True, + pad_mode: str = "constant", + onesided: bool = True, + n_bands: int = None, + use_freq_weights: bool = True, + normalize_input: bool = False, + mult_add_mask: bool = False, + freeze_encoder: bool = False, + ) -> None: + super().__init__( + stems=stems, + band_specs=band_specs, + fs=fs, + n_fft=n_fft, + win_length=win_length, + hop_length=hop_length, + window_fn=window_fn, + wkwargs=wkwargs, + power=power, + center=center, + normalized=normalized, + pad_mode=pad_mode, + onesided=onesided, + n_bands=n_bands, + ) + + self.bsrnn = MultiSourceMultiMaskBandSplitCoreRNN( + stems=stems, + band_specs=self.band_specs, + in_channel=in_channel, + require_no_overlap=require_no_overlap, + require_no_gap=require_no_gap, + normalize_channel_independently=normalize_channel_independently, + treat_channel_as_feature=treat_channel_as_feature, + n_sqm_modules=n_sqm_modules, + emb_dim=emb_dim, + rnn_dim=rnn_dim, + bidirectional=bidirectional, + rnn_type=rnn_type, + mlp_dim=mlp_dim, + cond_dim=cond_dim, + hidden_activation=hidden_activation, + hidden_activation_kwargs=hidden_activation_kwargs, + complex_mask=complex_mask, + overlapping_band=self.overlapping_band, + freq_weights=self.freq_weights, + n_freq=n_fft // 2 + 1, + use_freq_weights=use_freq_weights, + mult_add_mask=mult_add_mask + ) + + self.normalize_input = normalize_input + self.cond_dim = cond_dim + + if freeze_encoder: + for param in self.bsrnn.band_split.parameters(): + param.requires_grad = False + + for param in self.bsrnn.tf_model.parameters(): + param.requires_grad = False + + +class MultiMaskMultiSourceBandSplitTransformer( + MultiMaskMultiSourceBandSplitBase +): + def __init__( + self, + in_channel: int, + stems: List[str], + band_specs: Union[str, List[Tuple[float, float]]], + fs: int = 44100, + require_no_overlap: bool = False, + require_no_gap: bool = True, + normalize_channel_independently: bool = False, + treat_channel_as_feature: bool = True, + n_sqm_modules: int = 12, + emb_dim: int = 128, + rnn_dim: int = 256, + cond_dim: int = 0, + bidirectional: bool = True, + rnn_type: str = "LSTM", + mlp_dim: int = 512, + hidden_activation: str = "Tanh", + hidden_activation_kwargs: Optional[Dict] = None, + complex_mask: bool = True, + n_fft: int = 2048, + win_length: Optional[int] = 2048, + hop_length: int = 512, + window_fn: str = 
"hann_window", + wkwargs: Optional[Dict] = None, + power: Optional[int] = None, + center: bool = True, + normalized: bool = True, + pad_mode: str = "constant", + onesided: bool = True, + n_bands: int = None, + use_freq_weights: bool = True, + normalize_input: bool = False, + mult_add_mask: bool = False + ) -> None: + super().__init__( + stems=stems, + band_specs=band_specs, + fs=fs, + n_fft=n_fft, + win_length=win_length, + hop_length=hop_length, + window_fn=window_fn, + wkwargs=wkwargs, + power=power, + center=center, + normalized=normalized, + pad_mode=pad_mode, + onesided=onesided, + n_bands=n_bands, + ) + + self.bsrnn = MultiSourceMultiMaskBandSplitCoreTransformer( + stems=stems, + band_specs=self.band_specs, + in_channel=in_channel, + require_no_overlap=require_no_overlap, + require_no_gap=require_no_gap, + normalize_channel_independently=normalize_channel_independently, + treat_channel_as_feature=treat_channel_as_feature, + n_sqm_modules=n_sqm_modules, + emb_dim=emb_dim, + rnn_dim=rnn_dim, + bidirectional=bidirectional, + rnn_type=rnn_type, + mlp_dim=mlp_dim, + cond_dim=cond_dim, + hidden_activation=hidden_activation, + hidden_activation_kwargs=hidden_activation_kwargs, + complex_mask=complex_mask, + overlapping_band=self.overlapping_band, + freq_weights=self.freq_weights, + n_freq=n_fft // 2 + 1, + use_freq_weights=use_freq_weights, + mult_add_mask=mult_add_mask + ) + + + +class MultiMaskMultiSourceBandSplitConv( + MultiMaskMultiSourceBandSplitBase +): + def __init__( + self, + in_channel: int, + stems: List[str], + band_specs: Union[str, List[Tuple[float, float]]], + fs: int = 44100, + require_no_overlap: bool = False, + require_no_gap: bool = True, + normalize_channel_independently: bool = False, + treat_channel_as_feature: bool = True, + n_sqm_modules: int = 12, + emb_dim: int = 128, + rnn_dim: int = 256, + cond_dim: int = 0, + bidirectional: bool = True, + rnn_type: str = "LSTM", + mlp_dim: int = 512, + hidden_activation: str = "Tanh", + hidden_activation_kwargs: Optional[Dict] = None, + complex_mask: bool = True, + n_fft: int = 2048, + win_length: Optional[int] = 2048, + hop_length: int = 512, + window_fn: str = "hann_window", + wkwargs: Optional[Dict] = None, + power: Optional[int] = None, + center: bool = True, + normalized: bool = True, + pad_mode: str = "constant", + onesided: bool = True, + n_bands: int = None, + use_freq_weights: bool = True, + normalize_input: bool = False, + mult_add_mask: bool = False + ) -> None: + super().__init__( + stems=stems, + band_specs=band_specs, + fs=fs, + n_fft=n_fft, + win_length=win_length, + hop_length=hop_length, + window_fn=window_fn, + wkwargs=wkwargs, + power=power, + center=center, + normalized=normalized, + pad_mode=pad_mode, + onesided=onesided, + n_bands=n_bands, + ) + + self.bsrnn = MultiSourceMultiMaskBandSplitCoreConv( + stems=stems, + band_specs=self.band_specs, + in_channel=in_channel, + require_no_overlap=require_no_overlap, + require_no_gap=require_no_gap, + normalize_channel_independently=normalize_channel_independently, + treat_channel_as_feature=treat_channel_as_feature, + n_sqm_modules=n_sqm_modules, + emb_dim=emb_dim, + rnn_dim=rnn_dim, + bidirectional=bidirectional, + rnn_type=rnn_type, + mlp_dim=mlp_dim, + cond_dim=cond_dim, + hidden_activation=hidden_activation, + hidden_activation_kwargs=hidden_activation_kwargs, + complex_mask=complex_mask, + overlapping_band=self.overlapping_band, + freq_weights=self.freq_weights, + n_freq=n_fft // 2 + 1, + use_freq_weights=use_freq_weights, + mult_add_mask=mult_add_mask + ) 
+class PatchingMaskMultiSourceBandSplitRNN(MultiMaskMultiSourceBandSplitBase): + def __init__( + self, + in_channel: int, + stems: List[str], + band_specs: Union[str, List[Tuple[float, float]]], + kernel_norm_mlp_version: int = 1, + mask_kernel_freq: int = 3, + mask_kernel_time: int = 3, + conv_kernel_freq: int = 1, + conv_kernel_time: int = 1, + fs: int = 44100, + require_no_overlap: bool = False, + require_no_gap: bool = True, + normalize_channel_independently: bool = False, + treat_channel_as_feature: bool = True, + n_sqm_modules: int = 12, + emb_dim: int = 128, + rnn_dim: int = 256, + bidirectional: bool = True, + rnn_type: str = "LSTM", + mlp_dim: int = 512, + hidden_activation: str = "Tanh", + hidden_activation_kwargs: Optional[Dict] = None, + complex_mask: bool = True, + n_fft: int = 2048, + win_length: Optional[int] = 2048, + hop_length: int = 512, + window_fn: str = "hann_window", + wkwargs: Optional[Dict] = None, + power: Optional[int] = None, + center: bool = True, + normalized: bool = True, + pad_mode: str = "constant", + onesided: bool = True, + n_bands: int = None, + ) -> None: + super().__init__( + stems=stems, + band_specs=band_specs, + fs=fs, + n_fft=n_fft, + win_length=win_length, + hop_length=hop_length, + window_fn=window_fn, + wkwargs=wkwargs, + power=power, + center=center, + normalized=normalized, + pad_mode=pad_mode, + onesided=onesided, + n_bands=n_bands, + ) + + self.bsrnn = MultiSourceMultiPatchingMaskBandSplitCoreRNN( + stems=stems, + band_specs=self.band_specs, + in_channel=in_channel, + require_no_overlap=require_no_overlap, + require_no_gap=require_no_gap, + normalize_channel_independently=normalize_channel_independently, + treat_channel_as_feature=treat_channel_as_feature, + n_sqm_modules=n_sqm_modules, + emb_dim=emb_dim, + rnn_dim=rnn_dim, + bidirectional=bidirectional, + rnn_type=rnn_type, + mlp_dim=mlp_dim, + hidden_activation=hidden_activation, + hidden_activation_kwargs=hidden_activation_kwargs, + complex_mask=complex_mask, + overlapping_band=self.overlapping_band, + freq_weights=self.freq_weights, + n_freq=n_fft // 2 + 1, + mask_kernel_freq=mask_kernel_freq, + mask_kernel_time=mask_kernel_time, + conv_kernel_freq=conv_kernel_freq, + conv_kernel_time=conv_kernel_time, + kernel_norm_mlp_version=kernel_norm_mlp_version, + ) diff --git a/models/bandit/core/utils/__init__.py b/models/bandit/core/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/bandit/core/utils/audio.py b/models/bandit/core/utils/audio.py new file mode 100644 index 0000000000000000000000000000000000000000..60a29a2e8098f8d00f177863693898ca71dacd3c --- /dev/null +++ b/models/bandit/core/utils/audio.py @@ -0,0 +1,463 @@ +from collections import defaultdict + +from tqdm import tqdm +from typing import Callable, Dict, List, Optional, Tuple + +import numpy as np +import torch +from torch import nn +from torch.nn import functional as F + + +@torch.jit.script +def merge( + combined: torch.Tensor, + original_batch_size: int, + n_channel: int, + n_chunks: int, + chunk_size: int, ): + combined = torch.reshape( + combined, + (original_batch_size, n_chunks, n_channel, chunk_size) + ) + combined = torch.permute(combined, (0, 2, 3, 1)).reshape( + original_batch_size * n_channel, + chunk_size, + n_chunks + ) + + return combined + + +@torch.jit.script +def unfold( + padded_audio: torch.Tensor, + original_batch_size: int, + n_channel: int, + chunk_size: int, + hop_size: int + ) -> torch.Tensor: + + 
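+    # Slice the padded waveform into overlapping chunks of chunk_size samples taken every
+    # hop_size samples; the result has shape (original_batch_size * n_chunks, n_channel,
+    # chunk_size), so each chunk can be fed to the separation model as a batch element.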
unfolded_input = F.unfold( + padded_audio[:, :, None, :], + kernel_size=(1, chunk_size), + stride=(1, hop_size) + ) + + _, _, n_chunks = unfolded_input.shape + unfolded_input = unfolded_input.view( + original_batch_size, + n_channel, + chunk_size, + n_chunks + ) + unfolded_input = torch.permute( + unfolded_input, + (0, 3, 1, 2) + ).reshape( + original_batch_size * n_chunks, + n_channel, + chunk_size + ) + + return unfolded_input + + +@torch.jit.script +# @torch.compile +def merge_chunks_all( + combined: torch.Tensor, + original_batch_size: int, + n_channel: int, + n_samples: int, + n_padded_samples: int, + n_chunks: int, + chunk_size: int, + hop_size: int, + edge_frame_pad_sizes: Tuple[int, int], + standard_window: torch.Tensor, + first_window: torch.Tensor, + last_window: torch.Tensor +): + combined = merge( + combined, + original_batch_size, + n_channel, + n_chunks, + chunk_size + ) + + combined = combined * standard_window[:, None].to(combined.device) + + combined = F.fold( + combined.to(torch.float32), output_size=(1, n_padded_samples), + kernel_size=(1, chunk_size), + stride=(1, hop_size) + ) + + combined = combined.view( + original_batch_size, + n_channel, + n_padded_samples + ) + + pad_front, pad_back = edge_frame_pad_sizes + combined = combined[..., pad_front:-pad_back] + + combined = combined[..., :n_samples] + + return combined + + # @torch.jit.script + + +def merge_chunks_edge( + combined: torch.Tensor, + original_batch_size: int, + n_channel: int, + n_samples: int, + n_padded_samples: int, + n_chunks: int, + chunk_size: int, + hop_size: int, + edge_frame_pad_sizes: Tuple[int, int], + standard_window: torch.Tensor, + first_window: torch.Tensor, + last_window: torch.Tensor +): + combined = merge( + combined, + original_batch_size, + n_channel, + n_chunks, + chunk_size + ) + + combined[..., 0] = combined[..., 0] * first_window + combined[..., -1] = combined[..., -1] * last_window + combined[..., 1:-1] = combined[..., + 1:-1] * standard_window[:, None] + + combined = F.fold( + combined, output_size=(1, n_padded_samples), + kernel_size=(1, chunk_size), + stride=(1, hop_size) + ) + + combined = combined.view( + original_batch_size, + n_channel, + n_padded_samples + ) + + combined = combined[..., :n_samples] + + return combined + + +class BaseFader(nn.Module): + def __init__( + self, + chunk_size_second: float, + hop_size_second: float, + fs: int, + fade_edge_frames: bool, + batch_size: int, + ) -> None: + super().__init__() + + self.chunk_size = int(chunk_size_second * fs) + self.hop_size = int(hop_size_second * fs) + self.overlap_size = self.chunk_size - self.hop_size + self.fade_edge_frames = fade_edge_frames + self.batch_size = batch_size + + # @torch.jit.script + def prepare(self, audio): + + if self.fade_edge_frames: + audio = F.pad(audio, self.edge_frame_pad_sizes, mode="reflect") + + n_samples = audio.shape[-1] + n_chunks = int( + np.ceil((n_samples - self.chunk_size) / self.hop_size) + 1 + ) + + padded_size = (n_chunks - 1) * self.hop_size + self.chunk_size + pad_size = padded_size - n_samples + + padded_audio = F.pad(audio, (0, pad_size)) + + return padded_audio, n_chunks + + def forward( + self, + audio: torch.Tensor, + model_fn: Callable[[torch.Tensor], Dict[str, torch.Tensor]], + ): + + original_dtype = audio.dtype + original_device = audio.device + + audio = audio.to("cpu") + + original_batch_size, n_channel, n_samples = audio.shape + padded_audio, n_chunks = self.prepare(audio) + del audio + n_padded_samples = padded_audio.shape[-1] + + if n_channel > 1: + padded_audio 
= padded_audio.view( + original_batch_size * n_channel, 1, n_padded_samples + ) + + unfolded_input = unfold( + padded_audio, + original_batch_size, + n_channel, + self.chunk_size, self.hop_size + ) + + n_total_chunks, n_channel, chunk_size = unfolded_input.shape + + n_batch = np.ceil(n_total_chunks / self.batch_size).astype(int) + + chunks_in = [ + unfolded_input[ + b * self.batch_size:(b + 1) * self.batch_size, ...].clone() + for b in range(n_batch) + ] + + all_chunks_out = defaultdict( + lambda: torch.zeros_like( + unfolded_input, device="cpu" + ) + ) + + # for b, cin in enumerate(tqdm(chunks_in)): + for b, cin in enumerate(chunks_in): + if torch.allclose(cin, torch.tensor(0.0)): + del cin + continue + + chunks_out = model_fn(cin.to(original_device)) + del cin + for s, c in chunks_out.items(): + all_chunks_out[s][b * self.batch_size:(b + 1) * self.batch_size, + ...] = c.cpu() + del chunks_out + + del unfolded_input + del padded_audio + + if self.fade_edge_frames: + fn = merge_chunks_all + else: + fn = merge_chunks_edge + outputs = {} + + torch.cuda.empty_cache() + + for s, c in all_chunks_out.items(): + combined: torch.Tensor = fn( + c, + original_batch_size, + n_channel, + n_samples, + n_padded_samples, + n_chunks, + self.chunk_size, + self.hop_size, + self.edge_frame_pad_sizes, + self.standard_window, + self.__dict__.get("first_window", self.standard_window), + self.__dict__.get("last_window", self.standard_window) + ) + + outputs[s] = combined.to( + dtype=original_dtype, + device=original_device + ) + + return { + "audio": outputs + } + # + # def old_forward( + # self, + # audio: torch.Tensor, + # model_fn: Callable[[torch.Tensor], Dict[str, torch.Tensor]], + # ): + # + # n_samples = audio.shape[-1] + # original_batch_size = audio.shape[0] + # + # padded_audio, n_chunks = self.prepare(audio) + # + # ndim = padded_audio.ndim + # broadcaster = [1 for _ in range(ndim - 1)] + [self.chunk_size] + # + # outputs = defaultdict( + # lambda: torch.zeros_like( + # padded_audio, device=audio.device, dtype=torch.float64 + # ) + # ) + # + # all_chunks_out = [] + # len_chunks_in = [] + # + # batch_size_ = int(self.batch_size // original_batch_size) + # for b in range(int(np.ceil(n_chunks / batch_size_))): + # chunks_in = [] + # for j in range(batch_size_): + # i = b * batch_size_ + j + # if i == n_chunks: + # break + # + # start = i * hop_size + # end = start + self.chunk_size + # chunk_in = padded_audio[..., start:end] + # chunks_in.append(chunk_in) + # + # chunks_in = torch.concat(chunks_in, dim=0) + # chunks_out = model_fn(chunks_in) + # all_chunks_out.append(chunks_out) + # len_chunks_in.append(len(chunks_in)) + # + # for b, (chunks_out, lci) in enumerate( + # zip(all_chunks_out, len_chunks_in) + # ): + # for stem in chunks_out: + # for j in range(lci // original_batch_size): + # i = b * batch_size_ + j + # + # if self.fade_edge_frames: + # window = self.standard_window + # else: + # if i == 0: + # window = self.first_window + # elif i == n_chunks - 1: + # window = self.last_window + # else: + # window = self.standard_window + # + # start = i * hop_size + # end = start + self.chunk_size + # + # chunk_out = chunks_out[stem][j * original_batch_size: (j + 1) * original_batch_size, + # ...] 
+ # contrib = window.view(*broadcaster) * chunk_out + # outputs[stem][..., start:end] = ( + # outputs[stem][..., start:end] + contrib + # ) + # + # if self.fade_edge_frames: + # pad_front, pad_back = self.edge_frame_pad_sizes + # outputs = {k: v[..., pad_front:-pad_back] for k, v in + # outputs.items()} + # + # outputs = {k: v[..., :n_samples].to(audio.dtype) for k, v in + # outputs.items()} + # + # return { + # "audio": outputs + # } + + +class LinearFader(BaseFader): + def __init__( + self, + chunk_size_second: float, + hop_size_second: float, + fs: int, + fade_edge_frames: bool = False, + batch_size: int = 1, + ) -> None: + + assert hop_size_second >= chunk_size_second / 2 + + super().__init__( + chunk_size_second=chunk_size_second, + hop_size_second=hop_size_second, + fs=fs, + fade_edge_frames=fade_edge_frames, + batch_size=batch_size, + ) + + in_fade = torch.linspace(0.0, 1.0, self.overlap_size + 1)[:-1] + out_fade = torch.linspace(1.0, 0.0, self.overlap_size + 1)[1:] + center_ones = torch.ones(self.chunk_size - 2 * self.overlap_size) + inout_ones = torch.ones(self.overlap_size) + + # using nn.Parameters allows lightning to take care of devices for us + self.register_buffer( + "standard_window", + torch.concat([in_fade, center_ones, out_fade]) + ) + + self.fade_edge_frames = fade_edge_frames + self.edge_frame_pad_size = (self.overlap_size, self.overlap_size) + + if not self.fade_edge_frames: + self.first_window = nn.Parameter( + torch.concat([inout_ones, center_ones, out_fade]), + requires_grad=False + ) + self.last_window = nn.Parameter( + torch.concat([in_fade, center_ones, inout_ones]), + requires_grad=False + ) + + +class OverlapAddFader(BaseFader): + def __init__( + self, + window_type: str, + chunk_size_second: float, + hop_size_second: float, + fs: int, + batch_size: int = 1, + ) -> None: + assert (chunk_size_second / hop_size_second) % 2 == 0 + assert int(chunk_size_second * fs) % 2 == 0 + + super().__init__( + chunk_size_second=chunk_size_second, + hop_size_second=hop_size_second, + fs=fs, + fade_edge_frames=True, + batch_size=batch_size, + ) + + self.hop_multiplier = self.chunk_size / (2 * self.hop_size) + # print(f"hop multiplier: {self.hop_multiplier}") + + self.edge_frame_pad_sizes = ( + 2 * self.overlap_size, + 2 * self.overlap_size + ) + + self.register_buffer( + "standard_window", torch.windows.__dict__[window_type]( + self.chunk_size, sym=False, # dtype=torch.float64 + ) / self.hop_multiplier + ) + + +if __name__ == "__main__": + import torchaudio as ta + fs = 44100 + ola = OverlapAddFader( + "hann", + 6.0, + 1.0, + fs, + batch_size=16 + ) + audio_, _ = ta.load( + "$DATA_ROOT/MUSDB18/HQ/canonical/test/BKS - Too " + "Much/vocals.wav" + ) + audio_ = audio_[None, ...] 
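+    # With an identity model_fn, the overlap-add fader should reconstruct its input,
+    # so torch.allclose(out, audio_) below serves as a quick round-trip sanity check.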
+ out = ola(audio_, lambda x: {"stem": x})["audio"]["stem"] + print(torch.allclose(out, audio_)) diff --git a/models/bandit/model_from_config.py b/models/bandit/model_from_config.py new file mode 100644 index 0000000000000000000000000000000000000000..7cd99030b043b780bb4c38f2f868cfab6329aba9 --- /dev/null +++ b/models/bandit/model_from_config.py @@ -0,0 +1,31 @@ +import sys +import os.path +import torch + +code_path = os.path.dirname(os.path.abspath(__file__)) + '/' +sys.path.append(code_path) + +import yaml +from ml_collections import ConfigDict + +torch.set_float32_matmul_precision("medium") + + +def get_model( + config_path, + weights_path, + device, +): + from models.bandit.core.model import MultiMaskMultiSourceBandSplitRNNSimple + + f = open(config_path) + config = ConfigDict(yaml.load(f, Loader=yaml.FullLoader)) + f.close() + + model = MultiMaskMultiSourceBandSplitRNNSimple( + **config.model + ) + d = torch.load(code_path + 'model_bandit_plus_dnr_sdr_11.47.chpt') + model.load_state_dict(d) + model.to(device) + return model, config diff --git a/models/bs_roformer/__init__.py b/models/bs_roformer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d271909b8aef18fdabb89763f34a279952e7c06b --- /dev/null +++ b/models/bs_roformer/__init__.py @@ -0,0 +1,2 @@ +from models.bs_roformer.bs_roformer import BSRoformer +from models.bs_roformer.mel_band_roformer import MelBandRoformer diff --git a/models/bs_roformer/attend.py b/models/bs_roformer/attend.py new file mode 100644 index 0000000000000000000000000000000000000000..899d26f6fbed25cfab78181237bf14bc270608ae --- /dev/null +++ b/models/bs_roformer/attend.py @@ -0,0 +1,120 @@ +from functools import wraps +from packaging import version +from collections import namedtuple + +import torch +from torch import nn, einsum +import torch.nn.functional as F + +from einops import rearrange, reduce + +# constants + +FlashAttentionConfig = namedtuple('FlashAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient']) + +# helpers + +def exists(val): + return val is not None + +def default(v, d): + return v if exists(v) else d + +def once(fn): + called = False + @wraps(fn) + def inner(x): + nonlocal called + if called: + return + called = True + return fn(x) + return inner + +print_once = once(print) + +# main class + +class Attend(nn.Module): + def __init__( + self, + dropout = 0., + flash = False, + scale = None + ): + super().__init__() + self.scale = scale + self.dropout = dropout + self.attn_dropout = nn.Dropout(dropout) + + self.flash = flash + assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above' + + # determine efficient attention configs for cuda and cpu + + self.cpu_config = FlashAttentionConfig(True, True, True) + self.cuda_config = None + + if not torch.cuda.is_available() or not flash: + return + + device_properties = torch.cuda.get_device_properties(torch.device('cuda')) + + if device_properties.major == 8 and device_properties.minor == 0: + print_once('A100 GPU detected, using flash attention if input tensor is on cuda') + self.cuda_config = FlashAttentionConfig(True, False, False) + else: + print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda') + self.cuda_config = FlashAttentionConfig(False, True, True) + + def flash_attn(self, q, k, v): + _, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device + + if exists(self.scale): + 
default_scale = q.shape[-1] ** -0.5 + q = q * (self.scale / default_scale) + + # Check if there is a compatible device for flash attention + + config = self.cuda_config if is_cuda else self.cpu_config + + # pytorch 2.0 flash attn: q, k, v, mask, dropout, softmax_scale + + with torch.backends.cuda.sdp_kernel(**config._asdict()): + out = F.scaled_dot_product_attention( + q, k, v, + dropout_p = self.dropout if self.training else 0. + ) + + return out + + def forward(self, q, k, v): + """ + einstein notation + b - batch + h - heads + n, i, j - sequence length (base sequence length, source, target) + d - feature dimension + """ + + q_len, k_len, device = q.shape[-2], k.shape[-2], q.device + + scale = default(self.scale, q.shape[-1] ** -0.5) + + if self.flash: + return self.flash_attn(q, k, v) + + # similarity + + sim = einsum(f"b h i d, b h j d -> b h i j", q, k) * scale + + # attention + + attn = sim.softmax(dim=-1) + attn = self.attn_dropout(attn) + + # aggregate values + + out = einsum(f"b h i j, b h j d -> b h i d", attn, v) + + return out diff --git a/models/bs_roformer/bs_roformer.py b/models/bs_roformer/bs_roformer.py new file mode 100644 index 0000000000000000000000000000000000000000..65b661e1ad10c0abada560168036b14131768abc --- /dev/null +++ b/models/bs_roformer/bs_roformer.py @@ -0,0 +1,577 @@ +from functools import partial + +import torch +from torch import nn, einsum, Tensor +from torch.nn import Module, ModuleList +import torch.nn.functional as F + +from models.bs_roformer.attend import Attend + +from beartype.typing import Tuple, Optional, List, Callable +from beartype import beartype + +from rotary_embedding_torch import RotaryEmbedding + +from einops import rearrange, pack, unpack +from einops.layers.torch import Rearrange + +# helper functions + +def exists(val): + return val is not None + + +def default(v, d): + return v if exists(v) else d + + +def pack_one(t, pattern): + return pack([t], pattern) + + +def unpack_one(t, ps, pattern): + return unpack(t, ps, pattern)[0] + + +# norm + +def l2norm(t): + return F.normalize(t, dim = -1, p = 2) + + +class RMSNorm(Module): + def __init__(self, dim): + super().__init__() + self.scale = dim ** 0.5 + self.gamma = nn.Parameter(torch.ones(dim)) + + def forward(self, x): + return F.normalize(x, dim=-1) * self.scale * self.gamma + + +# attention + +class FeedForward(Module): + def __init__( + self, + dim, + mult=4, + dropout=0. 
+ ): + super().__init__() + dim_inner = int(dim * mult) + self.net = nn.Sequential( + RMSNorm(dim), + nn.Linear(dim, dim_inner), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(dim_inner, dim), + nn.Dropout(dropout) + ) + + def forward(self, x): + return self.net(x) + + +class Attention(Module): + def __init__( + self, + dim, + heads=8, + dim_head=64, + dropout=0., + rotary_embed=None, + flash=True + ): + super().__init__() + self.heads = heads + self.scale = dim_head ** -0.5 + dim_inner = heads * dim_head + + self.rotary_embed = rotary_embed + + self.attend = Attend(flash=flash, dropout=dropout) + + self.norm = RMSNorm(dim) + self.to_qkv = nn.Linear(dim, dim_inner * 3, bias=False) + + self.to_gates = nn.Linear(dim, heads) + + self.to_out = nn.Sequential( + nn.Linear(dim_inner, dim, bias=False), + nn.Dropout(dropout) + ) + + def forward(self, x): + x = self.norm(x) + + q, k, v = rearrange(self.to_qkv(x), 'b n (qkv h d) -> qkv b h n d', qkv=3, h=self.heads) + + if exists(self.rotary_embed): + q = self.rotary_embed.rotate_queries_or_keys(q) + k = self.rotary_embed.rotate_queries_or_keys(k) + + out = self.attend(q, k, v) + + gates = self.to_gates(x) + out = out * rearrange(gates, 'b n h -> b h n 1').sigmoid() + + out = rearrange(out, 'b h n d -> b n (h d)') + return self.to_out(out) + + +class LinearAttention(Module): + """ + this flavor of linear attention proposed in https://arxiv.org/abs/2106.09681 by El-Nouby et al. + """ + + @beartype + def __init__( + self, + *, + dim, + dim_head=32, + heads=8, + scale=8, + flash=False, + dropout=0. + ): + super().__init__() + dim_inner = dim_head * heads + self.norm = RMSNorm(dim) + + self.to_qkv = nn.Sequential( + nn.Linear(dim, dim_inner * 3, bias=False), + Rearrange('b n (qkv h d) -> qkv b h d n', qkv=3, h=heads) + ) + + self.temperature = nn.Parameter(torch.ones(heads, 1, 1)) + + self.attend = Attend( + scale=scale, + dropout=dropout, + flash=flash + ) + + self.to_out = nn.Sequential( + Rearrange('b h d n -> b n (h d)'), + nn.Linear(dim_inner, dim, bias=False) + ) + + def forward( + self, + x + ): + x = self.norm(x) + + q, k, v = self.to_qkv(x) + + q, k = map(l2norm, (q, k)) + q = q * self.temperature.exp() + + out = self.attend(q, k, v) + + return self.to_out(out) + + +class Transformer(Module): + def __init__( + self, + *, + dim, + depth, + dim_head=64, + heads=8, + attn_dropout=0., + ff_dropout=0., + ff_mult=4, + norm_output=True, + rotary_embed=None, + flash_attn=True, + linear_attn=False + ): + super().__init__() + self.layers = ModuleList([]) + + for _ in range(depth): + if linear_attn: + attn = LinearAttention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout, flash=flash_attn) + else: + attn = Attention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout, + rotary_embed=rotary_embed, flash=flash_attn) + + self.layers.append(ModuleList([ + attn, + FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout) + ])) + + self.norm = RMSNorm(dim) if norm_output else nn.Identity() + + def forward(self, x): + + for attn, ff in self.layers: + x = attn(x) + x + x = ff(x) + x + + return self.norm(x) + + +# bandsplit module + +class BandSplit(Module): + @beartype + def __init__( + self, + dim, + dim_inputs: Tuple[int, ...] 
+ ): + super().__init__() + self.dim_inputs = dim_inputs + self.to_features = ModuleList([]) + + for dim_in in dim_inputs: + net = nn.Sequential( + RMSNorm(dim_in), + nn.Linear(dim_in, dim) + ) + + self.to_features.append(net) + + def forward(self, x): + x = x.split(self.dim_inputs, dim=-1) + + outs = [] + for split_input, to_feature in zip(x, self.to_features): + split_output = to_feature(split_input) + outs.append(split_output) + + return torch.stack(outs, dim=-2) + + +def MLP( + dim_in, + dim_out, + dim_hidden=None, + depth=1, + activation=nn.Tanh +): + dim_hidden = default(dim_hidden, dim_in) + + net = [] + dims = (dim_in, *((dim_hidden,) * (depth - 1)), dim_out) + + for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])): + is_last = ind == (len(dims) - 2) + + net.append(nn.Linear(layer_dim_in, layer_dim_out)) + + if is_last: + continue + + net.append(activation()) + + return nn.Sequential(*net) + + +class MaskEstimator(Module): + @beartype + def __init__( + self, + dim, + dim_inputs: Tuple[int, ...], + depth, + mlp_expansion_factor=4 + ): + super().__init__() + self.dim_inputs = dim_inputs + self.to_freqs = ModuleList([]) + dim_hidden = dim * mlp_expansion_factor + + for dim_in in dim_inputs: + net = [] + + mlp = nn.Sequential( + MLP(dim, dim_in * 2, dim_hidden=dim_hidden, depth=depth), + nn.GLU(dim=-1) + ) + + self.to_freqs.append(mlp) + + def forward(self, x): + x = x.unbind(dim=-2) + + outs = [] + + for band_features, mlp in zip(x, self.to_freqs): + freq_out = mlp(band_features) + outs.append(freq_out) + + return torch.cat(outs, dim=-1) + + +# main class + +DEFAULT_FREQS_PER_BANDS = ( + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, + 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, + 12, 12, 12, 12, 12, 12, 12, 12, + 24, 24, 24, 24, 24, 24, 24, 24, + 48, 48, 48, 48, 48, 48, 48, 48, + 128, 129, +) + + +class BSRoformer(Module): + + @beartype + def __init__( + self, + dim, + *, + depth, + stereo=False, + num_stems=1, + time_transformer_depth=2, + freq_transformer_depth=2, + linear_transformer_depth=0, + freqs_per_bands: Tuple[int, ...] = DEFAULT_FREQS_PER_BANDS, + # in the paper, they divide into ~60 bands, test with 1 for starters + dim_head=64, + heads=8, + attn_dropout=0., + ff_dropout=0., + flash_attn=True, + dim_freqs_in=1025, + stft_n_fft=2048, + stft_hop_length=512, + # 10ms at 44100Hz, from sections 4.1, 4.4 in the paper - @faroit recommends // 2 or // 4 for better reconstruction + stft_win_length=2048, + stft_normalized=False, + stft_window_fn: Optional[Callable] = None, + mask_estimator_depth=2, + multi_stft_resolution_loss_weight=1., + multi_stft_resolutions_window_sizes: Tuple[int, ...] 
= (4096, 2048, 1024, 512, 256), + multi_stft_hop_size=147, + multi_stft_normalized=False, + multi_stft_window_fn: Callable = torch.hann_window + ): + super().__init__() + + self.stereo = stereo + self.audio_channels = 2 if stereo else 1 + self.num_stems = num_stems + + self.layers = ModuleList([]) + + transformer_kwargs = dict( + dim=dim, + heads=heads, + dim_head=dim_head, + attn_dropout=attn_dropout, + ff_dropout=ff_dropout, + flash_attn=flash_attn, + norm_output=False + ) + + time_rotary_embed = RotaryEmbedding(dim=dim_head) + freq_rotary_embed = RotaryEmbedding(dim=dim_head) + + for _ in range(depth): + tran_modules = [] + if linear_transformer_depth > 0: + tran_modules.append(Transformer(depth=linear_transformer_depth, linear_attn=True, **transformer_kwargs)) + tran_modules.append( + Transformer(depth=time_transformer_depth, rotary_embed=time_rotary_embed, **transformer_kwargs) + ) + tran_modules.append( + Transformer(depth=freq_transformer_depth, rotary_embed=freq_rotary_embed, **transformer_kwargs) + ) + self.layers.append(nn.ModuleList(tran_modules)) + + self.final_norm = RMSNorm(dim) + + self.stft_kwargs = dict( + n_fft=stft_n_fft, + hop_length=stft_hop_length, + win_length=stft_win_length, + normalized=stft_normalized + ) + + self.stft_window_fn = partial(default(stft_window_fn, torch.hann_window), stft_win_length) + + freqs = torch.stft(torch.randn(1, 4096), **self.stft_kwargs, return_complex=True).shape[1] + + assert len(freqs_per_bands) > 1 + assert sum( + freqs_per_bands) == freqs, f'the number of freqs in the bands must equal {freqs} based on the STFT settings, but got {sum(freqs_per_bands)}' + + freqs_per_bands_with_complex = tuple(2 * f * self.audio_channels for f in freqs_per_bands) + + self.band_split = BandSplit( + dim=dim, + dim_inputs=freqs_per_bands_with_complex + ) + + self.mask_estimators = nn.ModuleList([]) + + for _ in range(num_stems): + mask_estimator = MaskEstimator( + dim=dim, + dim_inputs=freqs_per_bands_with_complex, + depth=mask_estimator_depth + ) + + self.mask_estimators.append(mask_estimator) + + # for the multi-resolution stft loss + + self.multi_stft_resolution_loss_weight = multi_stft_resolution_loss_weight + self.multi_stft_resolutions_window_sizes = multi_stft_resolutions_window_sizes + self.multi_stft_n_fft = stft_n_fft + self.multi_stft_window_fn = multi_stft_window_fn + + self.multi_stft_kwargs = dict( + hop_length=multi_stft_hop_size, + normalized=multi_stft_normalized + ) + + def forward( + self, + raw_audio, + target=None, + return_loss_breakdown=False + ): + """ + einops + + b - batch + f - freq + t - time + s - audio channel (1 for mono, 2 for stereo) + n - number of 'stems' + c - complex (2) + d - feature dimension + """ + + device = raw_audio.device + + if raw_audio.ndim == 2: + raw_audio = rearrange(raw_audio, 'b t -> b 1 t') + + channels = raw_audio.shape[1] + assert (not self.stereo and channels == 1) or ( + self.stereo and channels == 2), 'stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). 
also need to be False if mono (channel dimension of 1)' + + # to stft + + raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, '* t') + + stft_window = self.stft_window_fn(device=device) + + stft_repr = torch.stft(raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True) + stft_repr = torch.view_as_real(stft_repr) + + stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, '* f t c') + stft_repr = rearrange(stft_repr, + 'b s f t c -> b (f s) t c') # merge stereo / mono into the frequency, with frequency leading dimension, for band splitting + + x = rearrange(stft_repr, 'b f t c -> b t (f c)') + + x = self.band_split(x) + + # axial / hierarchical attention + + for transformer_block in self.layers: + + if len(transformer_block) == 3: + linear_transformer, time_transformer, freq_transformer = transformer_block + + x, ft_ps = pack([x], 'b * d') + x = linear_transformer(x) + x, = unpack(x, ft_ps, 'b * d') + else: + time_transformer, freq_transformer = transformer_block + + x = rearrange(x, 'b t f d -> b f t d') + x, ps = pack([x], '* t d') + + x = time_transformer(x) + + x, = unpack(x, ps, '* t d') + x = rearrange(x, 'b f t d -> b t f d') + x, ps = pack([x], '* f d') + + x = freq_transformer(x) + + x, = unpack(x, ps, '* f d') + + x = self.final_norm(x) + + num_stems = len(self.mask_estimators) + + mask = torch.stack([fn(x) for fn in self.mask_estimators], dim=1) + mask = rearrange(mask, 'b n t (f c) -> b n f t c', c=2) + + # modulate frequency representation + + stft_repr = rearrange(stft_repr, 'b f t c -> b 1 f t c') + + # complex number multiplication + + stft_repr = torch.view_as_complex(stft_repr) + mask = torch.view_as_complex(mask) + + stft_repr = stft_repr * mask + + # istft + + stft_repr = rearrange(stft_repr, 'b n (f s) t -> (b n s) f t', s=self.audio_channels) + + recon_audio = torch.istft(stft_repr, **self.stft_kwargs, window=stft_window, return_complex=False) + + recon_audio = rearrange(recon_audio, '(b n s) t -> b n s t', s=self.audio_channels, n=num_stems) + + if num_stems == 1: + recon_audio = rearrange(recon_audio, 'b 1 s t -> b s t') + + # if a target is passed in, calculate loss for learning + + if not exists(target): + return recon_audio + + if self.num_stems > 1: + assert target.ndim == 4 and target.shape[1] == self.num_stems + + if target.ndim == 2: + target = rearrange(target, '... t -> ... 1 t') + + target = target[..., :recon_audio.shape[-1]] # protect against lost length on istft + + loss = F.l1_loss(recon_audio, target) + + multi_stft_resolution_loss = 0. + + for window_size in self.multi_stft_resolutions_window_sizes: + res_stft_kwargs = dict( + n_fft=max(window_size, self.multi_stft_n_fft), # not sure what n_fft is across multi resolution stft + win_length=window_size, + return_complex=True, + window=self.multi_stft_window_fn(window_size, device=device), + **self.multi_stft_kwargs, + ) + + recon_Y = torch.stft(rearrange(recon_audio, '... s t -> (... s) t'), **res_stft_kwargs) + target_Y = torch.stft(rearrange(target, '... s t -> (... 
s) t'), **res_stft_kwargs) + + multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(recon_Y, target_Y) + + weighted_multi_resolution_loss = multi_stft_resolution_loss * self.multi_stft_resolution_loss_weight + + total_loss = loss + weighted_multi_resolution_loss + + if not return_loss_breakdown: + return total_loss + + return total_loss, (loss, multi_stft_resolution_loss) \ No newline at end of file diff --git a/models/bs_roformer/mel_band_roformer.py b/models/bs_roformer/mel_band_roformer.py new file mode 100644 index 0000000000000000000000000000000000000000..8808063b14142c62d5a6117e845ce152968d5e1d --- /dev/null +++ b/models/bs_roformer/mel_band_roformer.py @@ -0,0 +1,637 @@ +from functools import partial + +import torch +from torch import nn, einsum, Tensor +from torch.nn import Module, ModuleList +import torch.nn.functional as F + +from models.bs_roformer.attend import Attend + +from beartype.typing import Tuple, Optional, List, Callable +from beartype import beartype + +from rotary_embedding_torch import RotaryEmbedding + +from einops import rearrange, pack, unpack, reduce, repeat +from einops.layers.torch import Rearrange + +from librosa import filters + + +# helper functions + +def exists(val): + return val is not None + + +def default(v, d): + return v if exists(v) else d + + +def pack_one(t, pattern): + return pack([t], pattern) + + +def unpack_one(t, ps, pattern): + return unpack(t, ps, pattern)[0] + + +def pad_at_dim(t, pad, dim=-1, value=0.): + dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1) + zeros = ((0, 0) * dims_from_right) + return F.pad(t, (*zeros, *pad), value=value) + + +def l2norm(t): + return F.normalize(t, dim=-1, p=2) + + +# norm + +class RMSNorm(Module): + def __init__(self, dim): + super().__init__() + self.scale = dim ** 0.5 + self.gamma = nn.Parameter(torch.ones(dim)) + + def forward(self, x): + return F.normalize(x, dim=-1) * self.scale * self.gamma + + +# attention + +class FeedForward(Module): + def __init__( + self, + dim, + mult=4, + dropout=0. + ): + super().__init__() + dim_inner = int(dim * mult) + self.net = nn.Sequential( + RMSNorm(dim), + nn.Linear(dim, dim_inner), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(dim_inner, dim), + nn.Dropout(dropout) + ) + + def forward(self, x): + return self.net(x) + + +class Attention(Module): + def __init__( + self, + dim, + heads=8, + dim_head=64, + dropout=0., + rotary_embed=None, + flash=True + ): + super().__init__() + self.heads = heads + self.scale = dim_head ** -0.5 + dim_inner = heads * dim_head + + self.rotary_embed = rotary_embed + + self.attend = Attend(flash=flash, dropout=dropout) + + self.norm = RMSNorm(dim) + self.to_qkv = nn.Linear(dim, dim_inner * 3, bias=False) + + self.to_gates = nn.Linear(dim, heads) + + self.to_out = nn.Sequential( + nn.Linear(dim_inner, dim, bias=False), + nn.Dropout(dropout) + ) + + def forward(self, x): + x = self.norm(x) + + q, k, v = rearrange(self.to_qkv(x), 'b n (qkv h d) -> qkv b h n d', qkv=3, h=self.heads) + + if exists(self.rotary_embed): + q = self.rotary_embed.rotate_queries_or_keys(q) + k = self.rotary_embed.rotate_queries_or_keys(k) + + out = self.attend(q, k, v) + + gates = self.to_gates(x) + out = out * rearrange(gates, 'b n h -> b h n 1').sigmoid() + + out = rearrange(out, 'b h n d -> b n (h d)') + return self.to_out(out) + + +class LinearAttention(Module): + """ + this flavor of linear attention proposed in https://arxiv.org/abs/2106.09681 by El-Nouby et al. 
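+ attention here is taken across the feature dimension after l2-normalising queries and keys,
+ so the cost grows linearly with sequence length rather than quadratically.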
+ """ + + @beartype + def __init__( + self, + *, + dim, + dim_head=32, + heads=8, + scale=8, + flash=False, + dropout=0. + ): + super().__init__() + dim_inner = dim_head * heads + self.norm = RMSNorm(dim) + + self.to_qkv = nn.Sequential( + nn.Linear(dim, dim_inner * 3, bias=False), + Rearrange('b n (qkv h d) -> qkv b h d n', qkv=3, h=heads) + ) + + self.temperature = nn.Parameter(torch.ones(heads, 1, 1)) + + self.attend = Attend( + scale=scale, + dropout=dropout, + flash=flash + ) + + self.to_out = nn.Sequential( + Rearrange('b h d n -> b n (h d)'), + nn.Linear(dim_inner, dim, bias=False) + ) + + def forward( + self, + x + ): + x = self.norm(x) + + q, k, v = self.to_qkv(x) + + q, k = map(l2norm, (q, k)) + q = q * self.temperature.exp() + + out = self.attend(q, k, v) + + return self.to_out(out) + + +class Transformer(Module): + def __init__( + self, + *, + dim, + depth, + dim_head=64, + heads=8, + attn_dropout=0., + ff_dropout=0., + ff_mult=4, + norm_output=True, + rotary_embed=None, + flash_attn=True, + linear_attn=False + ): + super().__init__() + self.layers = ModuleList([]) + + for _ in range(depth): + if linear_attn: + attn = LinearAttention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout, flash=flash_attn) + else: + attn = Attention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout, + rotary_embed=rotary_embed, flash=flash_attn) + + self.layers.append(ModuleList([ + attn, + FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout) + ])) + + self.norm = RMSNorm(dim) if norm_output else nn.Identity() + + def forward(self, x): + + for attn, ff in self.layers: + x = attn(x) + x + x = ff(x) + x + + return self.norm(x) + + +# bandsplit module + +class BandSplit(Module): + @beartype + def __init__( + self, + dim, + dim_inputs: Tuple[int, ...] 
+ ): + super().__init__() + self.dim_inputs = dim_inputs + self.to_features = ModuleList([]) + + for dim_in in dim_inputs: + net = nn.Sequential( + RMSNorm(dim_in), + nn.Linear(dim_in, dim) + ) + + self.to_features.append(net) + + def forward(self, x): + x = x.split(self.dim_inputs, dim=-1) + + outs = [] + for split_input, to_feature in zip(x, self.to_features): + split_output = to_feature(split_input) + outs.append(split_output) + + return torch.stack(outs, dim=-2) + + +def MLP( + dim_in, + dim_out, + dim_hidden=None, + depth=1, + activation=nn.Tanh +): + dim_hidden = default(dim_hidden, dim_in) + + net = [] + dims = (dim_in, *((dim_hidden,) * depth), dim_out) + + for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])): + is_last = ind == (len(dims) - 2) + + net.append(nn.Linear(layer_dim_in, layer_dim_out)) + + if is_last: + continue + + net.append(activation()) + + return nn.Sequential(*net) + + +class MaskEstimator(Module): + @beartype + def __init__( + self, + dim, + dim_inputs: Tuple[int, ...], + depth, + mlp_expansion_factor=4 + ): + super().__init__() + self.dim_inputs = dim_inputs + self.to_freqs = ModuleList([]) + dim_hidden = dim * mlp_expansion_factor + + for dim_in in dim_inputs: + net = [] + + mlp = nn.Sequential( + MLP(dim, dim_in * 2, dim_hidden=dim_hidden, depth=depth), + nn.GLU(dim=-1) + ) + + self.to_freqs.append(mlp) + + def forward(self, x): + x = x.unbind(dim=-2) + + outs = [] + + for band_features, mlp in zip(x, self.to_freqs): + freq_out = mlp(band_features) + outs.append(freq_out) + + return torch.cat(outs, dim=-1) + + +# main class + +class MelBandRoformer(Module): + + @beartype + def __init__( + self, + dim, + *, + depth, + stereo=False, + num_stems=1, + time_transformer_depth=2, + freq_transformer_depth=2, + linear_transformer_depth=0, + num_bands=60, + dim_head=64, + heads=8, + attn_dropout=0.1, + ff_dropout=0.1, + flash_attn=True, + dim_freqs_in=1025, + sample_rate=44100, # needed for mel filter bank from librosa + stft_n_fft=2048, + stft_hop_length=512, + # 10ms at 44100Hz, from sections 4.1, 4.4 in the paper - @faroit recommends // 2 or // 4 for better reconstruction + stft_win_length=2048, + stft_normalized=False, + stft_window_fn: Optional[Callable] = None, + mask_estimator_depth=1, + multi_stft_resolution_loss_weight=1., + multi_stft_resolutions_window_sizes: Tuple[int, ...] 
= (4096, 2048, 1024, 512, 256), + multi_stft_hop_size=147, + multi_stft_normalized=False, + multi_stft_window_fn: Callable = torch.hann_window, + match_input_audio_length=False, # if True, pad output tensor to match length of input tensor + ): + super().__init__() + + self.stereo = stereo + self.audio_channels = 2 if stereo else 1 + self.num_stems = num_stems + + self.layers = ModuleList([]) + + transformer_kwargs = dict( + dim=dim, + heads=heads, + dim_head=dim_head, + attn_dropout=attn_dropout, + ff_dropout=ff_dropout, + flash_attn=flash_attn + ) + + time_rotary_embed = RotaryEmbedding(dim=dim_head) + freq_rotary_embed = RotaryEmbedding(dim=dim_head) + + for _ in range(depth): + tran_modules = [] + if linear_transformer_depth > 0: + tran_modules.append(Transformer(depth=linear_transformer_depth, linear_attn=True, **transformer_kwargs)) + tran_modules.append( + Transformer(depth=time_transformer_depth, rotary_embed=time_rotary_embed, **transformer_kwargs) + ) + tran_modules.append( + Transformer(depth=freq_transformer_depth, rotary_embed=freq_rotary_embed, **transformer_kwargs) + ) + self.layers.append(nn.ModuleList(tran_modules)) + + self.stft_window_fn = partial(default(stft_window_fn, torch.hann_window), stft_win_length) + + self.stft_kwargs = dict( + n_fft=stft_n_fft, + hop_length=stft_hop_length, + win_length=stft_win_length, + normalized=stft_normalized + ) + + freqs = torch.stft(torch.randn(1, 4096), **self.stft_kwargs, return_complex=True).shape[1] + + # create mel filter bank + # with librosa.filters.mel as in section 2 of paper + + mel_filter_bank_numpy = filters.mel(sr=sample_rate, n_fft=stft_n_fft, n_mels=num_bands) + + mel_filter_bank = torch.from_numpy(mel_filter_bank_numpy) + + # for some reason, it doesn't include the first freq? just force a value for now + + mel_filter_bank[0][0] = 1. + + # In some systems/envs we get 0.0 instead of ~1.9e-18 in the last position, + # so let's force a positive value + + mel_filter_bank[-1, -1] = 1. 
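+ # rough, commented-out sanity check (a sketch, assuming librosa's default mel settings at 44.1 kHz);
+ # the binarisation and coverage assert below enforce the same property on the real filter bank:
+ #   fb = filters.mel(sr=44100, n_fft=2048, n_mels=60)
+ #   fb[0, 0] = fb[-1, -1] = 1.0
+ #   assert (fb > 0).any(axis=0).all()   # every STFT bin belongs to at least one band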
+ + # binary as in paper (then estimated masks are averaged for overlapping regions) + + freqs_per_band = mel_filter_bank > 0 + assert freqs_per_band.any(dim=0).all(), 'all frequencies need to be covered by all bands for now' + + repeated_freq_indices = repeat(torch.arange(freqs), 'f -> b f', b=num_bands) + freq_indices = repeated_freq_indices[freqs_per_band] + + if stereo: + freq_indices = repeat(freq_indices, 'f -> f s', s=2) + freq_indices = freq_indices * 2 + torch.arange(2) + freq_indices = rearrange(freq_indices, 'f s -> (f s)') + + self.register_buffer('freq_indices', freq_indices, persistent=False) + self.register_buffer('freqs_per_band', freqs_per_band, persistent=False) + + num_freqs_per_band = reduce(freqs_per_band, 'b f -> b', 'sum') + num_bands_per_freq = reduce(freqs_per_band, 'b f -> f', 'sum') + + self.register_buffer('num_freqs_per_band', num_freqs_per_band, persistent=False) + self.register_buffer('num_bands_per_freq', num_bands_per_freq, persistent=False) + + # band split and mask estimator + + freqs_per_bands_with_complex = tuple(2 * f * self.audio_channels for f in num_freqs_per_band.tolist()) + + self.band_split = BandSplit( + dim=dim, + dim_inputs=freqs_per_bands_with_complex + ) + + self.mask_estimators = nn.ModuleList([]) + + for _ in range(num_stems): + mask_estimator = MaskEstimator( + dim=dim, + dim_inputs=freqs_per_bands_with_complex, + depth=mask_estimator_depth + ) + + self.mask_estimators.append(mask_estimator) + + # for the multi-resolution stft loss + + self.multi_stft_resolution_loss_weight = multi_stft_resolution_loss_weight + self.multi_stft_resolutions_window_sizes = multi_stft_resolutions_window_sizes + self.multi_stft_n_fft = stft_n_fft + self.multi_stft_window_fn = multi_stft_window_fn + + self.multi_stft_kwargs = dict( + hop_length=multi_stft_hop_size, + normalized=multi_stft_normalized + ) + + self.match_input_audio_length = match_input_audio_length + + def forward( + self, + raw_audio, + target=None, + return_loss_breakdown=False + ): + """ + einops + + b - batch + f - freq + t - time + s - audio channel (1 for mono, 2 for stereo) + n - number of 'stems' + c - complex (2) + d - feature dimension + """ + + device = raw_audio.device + + if raw_audio.ndim == 2: + raw_audio = rearrange(raw_audio, 'b t -> b 1 t') + + batch, channels, raw_audio_length = raw_audio.shape + + istft_length = raw_audio_length if self.match_input_audio_length else None + + assert (not self.stereo and channels == 1) or ( + self.stereo and channels == 2), 'stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). 
also need to be False if mono (channel dimension of 1)' + + # to stft + + raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, '* t') + + stft_window = self.stft_window_fn(device=device) + + stft_repr = torch.stft(raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True) + stft_repr = torch.view_as_real(stft_repr) + + stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, '* f t c') + stft_repr = rearrange(stft_repr, + 'b s f t c -> b (f s) t c') # merge stereo / mono into the frequency, with frequency leading dimension, for band splitting + + # index out all frequencies for all frequency ranges across bands ascending in one go + + batch_arange = torch.arange(batch, device=device)[..., None] + + # account for stereo + + x = stft_repr[batch_arange, self.freq_indices] + + # fold the complex (real and imag) into the frequencies dimension + + x = rearrange(x, 'b f t c -> b t (f c)') + + x = self.band_split(x) + + # axial / hierarchical attention + + for transformer_block in self.layers: + + if len(transformer_block) == 3: + linear_transformer, time_transformer, freq_transformer = transformer_block + + x, ft_ps = pack([x], 'b * d') + x = linear_transformer(x) + x, = unpack(x, ft_ps, 'b * d') + else: + time_transformer, freq_transformer = transformer_block + + x = rearrange(x, 'b t f d -> b f t d') + x, ps = pack([x], '* t d') + + x = time_transformer(x) + + x, = unpack(x, ps, '* t d') + x = rearrange(x, 'b f t d -> b t f d') + x, ps = pack([x], '* f d') + + x = freq_transformer(x) + + x, = unpack(x, ps, '* f d') + + num_stems = len(self.mask_estimators) + + masks = torch.stack([fn(x) for fn in self.mask_estimators], dim=1) + masks = rearrange(masks, 'b n t (f c) -> b n f t c', c=2) + + # modulate frequency representation + + stft_repr = rearrange(stft_repr, 'b f t c -> b 1 f t c') + + # complex number multiplication + + stft_repr = torch.view_as_complex(stft_repr) + masks = torch.view_as_complex(masks) + + masks = masks.type(stft_repr.dtype) + + # need to average the estimated mask for the overlapped frequencies + + scatter_indices = repeat(self.freq_indices, 'f -> b n f t', b=batch, n=num_stems, t=stft_repr.shape[-1]) + + stft_repr_expanded_stems = repeat(stft_repr, 'b 1 ... -> b n ...', n=num_stems) + masks_summed = torch.zeros_like(stft_repr_expanded_stems).scatter_add_(2, scatter_indices, masks) + + denom = repeat(self.num_bands_per_freq, 'f -> (f r) 1', r=channels) + + masks_averaged = masks_summed / denom.clamp(min=1e-8) + + # modulate stft repr with estimated mask + + stft_repr = stft_repr * masks_averaged + + # istft + + stft_repr = rearrange(stft_repr, 'b n (f s) t -> (b n s) f t', s=self.audio_channels) + + recon_audio = torch.istft(stft_repr, **self.stft_kwargs, window=stft_window, return_complex=False, + length=istft_length) + + recon_audio = rearrange(recon_audio, '(b n s) t -> b n s t', b=batch, s=self.audio_channels, n=num_stems) + + if num_stems == 1: + recon_audio = rearrange(recon_audio, 'b 1 s t -> b s t') + + # if a target is passed in, calculate loss for learning + + if not exists(target): + return recon_audio + + if self.num_stems > 1: + assert target.ndim == 4 and target.shape[1] == self.num_stems + + if target.ndim == 2: + target = rearrange(target, '... t -> ... 1 t') + + target = target[..., :recon_audio.shape[-1]] # protect against lost length on istft + + loss = F.l1_loss(recon_audio, target) + + multi_stft_resolution_loss = 0. 
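+ # multi-resolution stft loss: compare reconstruction and target as complex spectrograms
+ # at several window sizes; the weighted sum is added to the time-domain l1 loss above.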
+ + for window_size in self.multi_stft_resolutions_window_sizes: + res_stft_kwargs = dict( + n_fft=max(window_size, self.multi_stft_n_fft), # not sure what n_fft is across multi resolution stft + win_length=window_size, + return_complex=True, + window=self.multi_stft_window_fn(window_size, device=device), + **self.multi_stft_kwargs, + ) + + recon_Y = torch.stft(rearrange(recon_audio, '... s t -> (... s) t'), **res_stft_kwargs) + target_Y = torch.stft(rearrange(target, '... s t -> (... s) t'), **res_stft_kwargs) + + multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(recon_Y, target_Y) + + weighted_multi_resolution_loss = multi_stft_resolution_loss * self.multi_stft_resolution_loss_weight + + total_loss = loss + weighted_multi_resolution_loss + + if not return_loss_breakdown: + return total_loss + + return total_loss, (loss, multi_stft_resolution_loss) diff --git a/models/demucs4ht.py b/models/demucs4ht.py new file mode 100644 index 0000000000000000000000000000000000000000..e80d34afc6847086bf62c6683679dc0971ce489f --- /dev/null +++ b/models/demucs4ht.py @@ -0,0 +1,713 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from functools import partial + +import numpy as np +import torch +import json +from omegaconf import OmegaConf +from demucs.demucs import Demucs +from demucs.hdemucs import HDemucs + +import math +from openunmix.filtering import wiener +from torch import nn +from torch.nn import functional as F +from fractions import Fraction +from einops import rearrange + +from demucs.transformer import CrossTransformerEncoder + +from demucs.demucs import rescale_module +from demucs.states import capture_init +from demucs.spec import spectro, ispectro +from demucs.hdemucs import pad1d, ScaledEmbedding, HEncLayer, MultiWrap, HDecLayer + + +class HTDemucs(nn.Module): + """ + Spectrogram and hybrid Demucs model. + The spectrogram model has the same structure as Demucs, except the first few layers are over the + frequency axis, until there is only 1 frequency, and then it moves to time convolutions. + Frequency layers can still access information across time steps thanks to the DConv residual. + + Hybrid model have a parallel time branch. At some layer, the time branch has the same stride + as the frequency branch and then the two are combined. The opposite happens in the decoder. + + Models can either use naive iSTFT from masking, Wiener filtering ([Ulhih et al. 2017]), + or complex as channels (CaC) [Choi et al. 2020]. Wiener filtering is based on + Open Unmix implementation [Stoter et al. 2019]. + + The loss is always on the temporal domain, by backpropagating through the above + output methods and iSTFT. This allows to define hybrid models nicely. However, this breaks + a bit Wiener filtering, as doing more iteration at test time will change the spectrogram + contribution, without changing the one from the waveform, which will lead to worse performance. + I tried using the residual option in OpenUnmix Wiener implementation, but it didn't improve. + CaC on the other hand provides similar performance for hybrid, and works naturally with + hybrid models. + + This model also uses frequency embeddings are used to improve efficiency on convolutions + over the freq. axis, following [Isik et al. 2020] (https://arxiv.org/pdf/2008.04470.pdf). + + Unlike classic Demucs, there is no resampling here, and normalization is always applied. 
+ """ + + @capture_init + def __init__( + self, + sources, + # Channels + audio_channels=2, + channels=48, + channels_time=None, + growth=2, + # STFT + nfft=4096, + num_subbands=1, + wiener_iters=0, + end_iters=0, + wiener_residual=False, + cac=True, + # Main structure + depth=4, + rewrite=True, + # Frequency branch + multi_freqs=None, + multi_freqs_depth=3, + freq_emb=0.2, + emb_scale=10, + emb_smooth=True, + # Convolutions + kernel_size=8, + time_stride=2, + stride=4, + context=1, + context_enc=0, + # Normalization + norm_starts=4, + norm_groups=4, + # DConv residual branch + dconv_mode=1, + dconv_depth=2, + dconv_comp=8, + dconv_init=1e-3, + # Before the Transformer + bottom_channels=0, + # Transformer + t_layers=5, + t_emb="sin", + t_hidden_scale=4.0, + t_heads=8, + t_dropout=0.0, + t_max_positions=10000, + t_norm_in=True, + t_norm_in_group=False, + t_group_norm=False, + t_norm_first=True, + t_norm_out=True, + t_max_period=10000.0, + t_weight_decay=0.0, + t_lr=None, + t_layer_scale=True, + t_gelu=True, + t_weight_pos_embed=1.0, + t_sin_random_shift=0, + t_cape_mean_normalize=True, + t_cape_augment=True, + t_cape_glob_loc_scale=[5000.0, 1.0, 1.4], + t_sparse_self_attn=False, + t_sparse_cross_attn=False, + t_mask_type="diag", + t_mask_random_seed=42, + t_sparse_attn_window=500, + t_global_window=100, + t_sparsity=0.95, + t_auto_sparsity=False, + # ------ Particuliar parameters + t_cross_first=False, + # Weight init + rescale=0.1, + # Metadata + samplerate=44100, + segment=10, + use_train_segment=False, + ): + """ + Args: + sources (list[str]): list of source names. + audio_channels (int): input/output audio channels. + channels (int): initial number of hidden channels. + channels_time: if not None, use a different `channels` value for the time branch. + growth: increase the number of hidden channels by this factor at each layer. + nfft: number of fft bins. Note that changing this require careful computation of + various shape parameters and will not work out of the box for hybrid models. + wiener_iters: when using Wiener filtering, number of iterations at test time. + end_iters: same but at train time. For a hybrid model, must be equal to `wiener_iters`. + wiener_residual: add residual source before wiener filtering. + cac: uses complex as channels, i.e. complex numbers are 2 channels each + in input and output. no further processing is done before ISTFT. + depth (int): number of layers in the encoder and in the decoder. + rewrite (bool): add 1x1 convolution to each layer. + multi_freqs: list of frequency ratios for splitting frequency bands with `MultiWrap`. + multi_freqs_depth: how many layers to wrap with `MultiWrap`. Only the outermost + layers will be wrapped. + freq_emb: add frequency embedding after the first frequency layer if > 0, + the actual value controls the weight of the embedding. + emb_scale: equivalent to scaling the embedding learning rate + emb_smooth: initialize the embedding with a smooth one (with respect to frequencies). + kernel_size: kernel_size for encoder and decoder layers. + stride: stride for encoder and decoder layers. + time_stride: stride for the final time layer, after the merge. + context: context for 1x1 conv in the decoder. + context_enc: context for 1x1 conv in the encoder. + norm_starts: layer at which group norm starts being used. + decoder layers are numbered in reverse order. + norm_groups: number of groups for group norm. + dconv_mode: if 1: dconv in encoder only, 2: decoder only, 3: both. + dconv_depth: depth of residual DConv branch. 
+ dconv_comp: compression of DConv branch. + dconv_attn: adds attention layers in DConv branch starting at this layer. + dconv_lstm: adds a LSTM layer in DConv branch starting at this layer. + dconv_init: initial scale for the DConv branch LayerScale. + bottom_channels: if >0 it adds a linear layer (1x1 Conv) before and after the + transformer in order to change the number of channels + t_layers: number of layers in each branch (waveform and spec) of the transformer + t_emb: "sin", "cape" or "scaled" + t_hidden_scale: the hidden scale of the Feedforward parts of the transformer + for instance if C = 384 (the number of channels in the transformer) and + t_hidden_scale = 4.0 then the intermediate layer of the FFN has dimension + 384 * 4 = 1536 + t_heads: number of heads for the transformer + t_dropout: dropout in the transformer + t_max_positions: max_positions for the "scaled" positional embedding, only + useful if t_emb="scaled" + t_norm_in: (bool) norm before addinf positional embedding and getting into the + transformer layers + t_norm_in_group: (bool) if True while t_norm_in=True, the norm is on all the + timesteps (GroupNorm with group=1) + t_group_norm: (bool) if True, the norms of the Encoder Layers are on all the + timesteps (GroupNorm with group=1) + t_norm_first: (bool) if True the norm is before the attention and before the FFN + t_norm_out: (bool) if True, there is a GroupNorm (group=1) at the end of each layer + t_max_period: (float) denominator in the sinusoidal embedding expression + t_weight_decay: (float) weight decay for the transformer + t_lr: (float) specific learning rate for the transformer + t_layer_scale: (bool) Layer Scale for the transformer + t_gelu: (bool) activations of the transformer are GeLU if True, ReLU else + t_weight_pos_embed: (float) weighting of the positional embedding + t_cape_mean_normalize: (bool) if t_emb="cape", normalisation of positional embeddings + see: https://arxiv.org/abs/2106.03143 + t_cape_augment: (bool) if t_emb="cape", must be True during training and False + during the inference, see: https://arxiv.org/abs/2106.03143 + t_cape_glob_loc_scale: (list of 3 floats) if t_emb="cape", CAPE parameters + see: https://arxiv.org/abs/2106.03143 + t_sparse_self_attn: (bool) if True, the self attentions are sparse + t_sparse_cross_attn: (bool) if True, the cross-attentions are sparse (don't use it + unless you designed really specific masks) + t_mask_type: (str) can be "diag", "jmask", "random", "global" or any combination + with '_' between: i.e. "diag_jmask_random" (note that this is permutation + invariant i.e. "diag_jmask_random" is equivalent to "jmask_random_diag") + t_mask_random_seed: (int) if "random" is in t_mask_type, controls the seed + that generated the random part of the mask + t_sparse_attn_window: (int) if "diag" is in t_mask_type, for a query (i), and + a key (j), the mask is True id |i-j|<=t_sparse_attn_window + t_global_window: (int) if "global" is in t_mask_type, mask[:t_global_window, :] + and mask[:, :t_global_window] will be True + t_sparsity: (float) if "random" is in t_mask_type, t_sparsity is the sparsity + level of the random part of the mask. + t_cross_first: (bool) if True cross attention is the first layer of the + transformer (False seems to be better) + rescale: weight rescaling trick + use_train_segment: (bool) if True, the actual size that is used during the + training is used during inference. 
+ """ + super().__init__() + self.num_subbands = num_subbands + self.cac = cac + self.wiener_residual = wiener_residual + self.audio_channels = audio_channels + self.sources = sources + self.kernel_size = kernel_size + self.context = context + self.stride = stride + self.depth = depth + self.bottom_channels = bottom_channels + self.channels = channels + self.samplerate = samplerate + self.segment = segment + self.use_train_segment = use_train_segment + self.nfft = nfft + self.hop_length = nfft // 4 + self.wiener_iters = wiener_iters + self.end_iters = end_iters + self.freq_emb = None + assert wiener_iters == end_iters + + self.encoder = nn.ModuleList() + self.decoder = nn.ModuleList() + + self.tencoder = nn.ModuleList() + self.tdecoder = nn.ModuleList() + + chin = audio_channels + chin_z = chin # number of channels for the freq branch + if self.cac: + chin_z *= 2 + if self.num_subbands > 1: + chin_z *= self.num_subbands + chout = channels_time or channels + chout_z = channels + freqs = nfft // 2 + + for index in range(depth): + norm = index >= norm_starts + freq = freqs > 1 + stri = stride + ker = kernel_size + if not freq: + assert freqs == 1 + ker = time_stride * 2 + stri = time_stride + + pad = True + last_freq = False + if freq and freqs <= kernel_size: + ker = freqs + pad = False + last_freq = True + + kw = { + "kernel_size": ker, + "stride": stri, + "freq": freq, + "pad": pad, + "norm": norm, + "rewrite": rewrite, + "norm_groups": norm_groups, + "dconv_kw": { + "depth": dconv_depth, + "compress": dconv_comp, + "init": dconv_init, + "gelu": True, + }, + } + kwt = dict(kw) + kwt["freq"] = 0 + kwt["kernel_size"] = kernel_size + kwt["stride"] = stride + kwt["pad"] = True + kw_dec = dict(kw) + multi = False + if multi_freqs and index < multi_freqs_depth: + multi = True + kw_dec["context_freq"] = False + + if last_freq: + chout_z = max(chout, chout_z) + chout = chout_z + + enc = HEncLayer( + chin_z, chout_z, dconv=dconv_mode & 1, context=context_enc, **kw + ) + if freq: + tenc = HEncLayer( + chin, + chout, + dconv=dconv_mode & 1, + context=context_enc, + empty=last_freq, + **kwt + ) + self.tencoder.append(tenc) + + if multi: + enc = MultiWrap(enc, multi_freqs) + self.encoder.append(enc) + if index == 0: + chin = self.audio_channels * len(self.sources) + chin_z = chin + if self.cac: + chin_z *= 2 + if self.num_subbands > 1: + chin_z *= self.num_subbands + dec = HDecLayer( + chout_z, + chin_z, + dconv=dconv_mode & 2, + last=index == 0, + context=context, + **kw_dec + ) + if multi: + dec = MultiWrap(dec, multi_freqs) + if freq: + tdec = HDecLayer( + chout, + chin, + dconv=dconv_mode & 2, + empty=last_freq, + last=index == 0, + context=context, + **kwt + ) + self.tdecoder.insert(0, tdec) + self.decoder.insert(0, dec) + + chin = chout + chin_z = chout_z + chout = int(growth * chout) + chout_z = int(growth * chout_z) + if freq: + if freqs <= kernel_size: + freqs = 1 + else: + freqs //= stride + if index == 0 and freq_emb: + self.freq_emb = ScaledEmbedding( + freqs, chin_z, smooth=emb_smooth, scale=emb_scale + ) + self.freq_emb_scale = freq_emb + + if rescale: + rescale_module(self, reference=rescale) + + transformer_channels = channels * growth ** (depth - 1) + if bottom_channels: + self.channel_upsampler = nn.Conv1d(transformer_channels, bottom_channels, 1) + self.channel_downsampler = nn.Conv1d( + bottom_channels, transformer_channels, 1 + ) + self.channel_upsampler_t = nn.Conv1d( + transformer_channels, bottom_channels, 1 + ) + self.channel_downsampler_t = nn.Conv1d( + bottom_channels, 
transformer_channels, 1 + ) + + transformer_channels = bottom_channels + + if t_layers > 0: + self.crosstransformer = CrossTransformerEncoder( + dim=transformer_channels, + emb=t_emb, + hidden_scale=t_hidden_scale, + num_heads=t_heads, + num_layers=t_layers, + cross_first=t_cross_first, + dropout=t_dropout, + max_positions=t_max_positions, + norm_in=t_norm_in, + norm_in_group=t_norm_in_group, + group_norm=t_group_norm, + norm_first=t_norm_first, + norm_out=t_norm_out, + max_period=t_max_period, + weight_decay=t_weight_decay, + lr=t_lr, + layer_scale=t_layer_scale, + gelu=t_gelu, + sin_random_shift=t_sin_random_shift, + weight_pos_embed=t_weight_pos_embed, + cape_mean_normalize=t_cape_mean_normalize, + cape_augment=t_cape_augment, + cape_glob_loc_scale=t_cape_glob_loc_scale, + sparse_self_attn=t_sparse_self_attn, + sparse_cross_attn=t_sparse_cross_attn, + mask_type=t_mask_type, + mask_random_seed=t_mask_random_seed, + sparse_attn_window=t_sparse_attn_window, + global_window=t_global_window, + sparsity=t_sparsity, + auto_sparsity=t_auto_sparsity, + ) + else: + self.crosstransformer = None + + def _spec(self, x): + hl = self.hop_length + nfft = self.nfft + x0 = x # noqa + + # We re-pad the signal in order to keep the property + # that the size of the output is exactly the size of the input + # divided by the stride (here hop_length), when divisible. + # This is achieved by padding by 1/4th of the kernel size (here nfft). + # which is not supported by torch.stft. + # Having all convolution operations follow this convention allow to easily + # align the time and frequency branches later on. + assert hl == nfft // 4 + le = int(math.ceil(x.shape[-1] / hl)) + pad = hl // 2 * 3 + x = pad1d(x, (pad, pad + le * hl - x.shape[-1]), mode="reflect") + + z = spectro(x, nfft, hl)[..., :-1, :] + assert z.shape[-1] == le + 4, (z.shape, x.shape, le) + z = z[..., 2: 2 + le] + return z + + def _ispec(self, z, length=None, scale=0): + hl = self.hop_length // (4**scale) + z = F.pad(z, (0, 0, 0, 1)) + z = F.pad(z, (2, 2)) + pad = hl // 2 * 3 + le = hl * int(math.ceil(length / hl)) + 2 * pad + x = ispectro(z, hl, length=le) + x = x[..., pad: pad + length] + return x + + def _magnitude(self, z): + # return the magnitude of the spectrogram, except when cac is True, + # in which case we just move the complex dimension to the channel one. + if self.cac: + B, C, Fr, T = z.shape + m = torch.view_as_real(z).permute(0, 1, 4, 2, 3) + m = m.reshape(B, C * 2, Fr, T) + else: + m = z.abs() + return m + + def _mask(self, z, m): + # Apply masking given the mixture spectrogram `z` and the estimated mask `m`. + # If `cac` is True, `m` is actually a full spectrogram and `z` is ignored. + niters = self.wiener_iters + if self.cac: + B, S, C, Fr, T = m.shape + out = m.view(B, S, -1, 2, Fr, T).permute(0, 1, 2, 4, 5, 3) + out = torch.view_as_complex(out.contiguous()) + return out + if self.training: + niters = self.end_iters + if niters < 0: + z = z[:, None] + return z / (1e-8 + z.abs()) * m + else: + return self._wiener(m, z, niters) + + def _wiener(self, mag_out, mix_stft, niters): + # apply wiener filtering from OpenUnmix. 
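+ # filtering is applied in chunks of `wiener_win_len` frames to keep memory bounded;
+ # the per-chunk outputs are concatenated back along the time axis.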
+ init = mix_stft.dtype + wiener_win_len = 300 + residual = self.wiener_residual + + B, S, C, Fq, T = mag_out.shape + mag_out = mag_out.permute(0, 4, 3, 2, 1) + mix_stft = torch.view_as_real(mix_stft.permute(0, 3, 2, 1)) + + outs = [] + for sample in range(B): + pos = 0 + out = [] + for pos in range(0, T, wiener_win_len): + frame = slice(pos, pos + wiener_win_len) + z_out = wiener( + mag_out[sample, frame], + mix_stft[sample, frame], + niters, + residual=residual, + ) + out.append(z_out.transpose(-1, -2)) + outs.append(torch.cat(out, dim=0)) + out = torch.view_as_complex(torch.stack(outs, 0)) + out = out.permute(0, 4, 3, 2, 1).contiguous() + if residual: + out = out[:, :-1] + assert list(out.shape) == [B, S, C, Fq, T] + return out.to(init) + + def valid_length(self, length: int): + """ + Return a length that is appropriate for evaluation. + In our case, always return the training length, unless + it is smaller than the given length, in which case this + raises an error. + """ + if not self.use_train_segment: + return length + training_length = int(self.segment * self.samplerate) + if training_length < length: + raise ValueError( + f"Given length {length} is longer than " + f"training length {training_length}") + return training_length + + def cac2cws(self, x): + k = self.num_subbands + b, c, f, t = x.shape + x = x.reshape(b, c, k, f // k, t) + x = x.reshape(b, c * k, f // k, t) + return x + + def cws2cac(self, x): + k = self.num_subbands + b, c, f, t = x.shape + x = x.reshape(b, c // k, k, f, t) + x = x.reshape(b, c // k, f * k, t) + return x + + def forward(self, mix): + length = mix.shape[-1] + length_pre_pad = None + if self.use_train_segment: + if self.training: + self.segment = Fraction(mix.shape[-1], self.samplerate) + else: + training_length = int(self.segment * self.samplerate) + # print('Training length: {} Segment: {} Sample rate: {}'.format(training_length, self.segment, self.samplerate)) + if mix.shape[-1] < training_length: + length_pre_pad = mix.shape[-1] + mix = F.pad(mix, (0, training_length - length_pre_pad)) + # print("Mix: {}".format(mix.shape)) + # print("Length: {}".format(length)) + z = self._spec(mix) + # print("Z: {} Type: {}".format(z.shape, z.dtype)) + mag = self._magnitude(z) + x = mag + # print("MAG: {} Type: {}".format(x.shape, x.dtype)) + + if self.num_subbands > 1: + x = self.cac2cws(x) + # print("After SUBBANDS: {} Type: {}".format(x.shape, x.dtype)) + + B, C, Fq, T = x.shape + + # unlike previous Demucs, we always normalize because it is easier. + mean = x.mean(dim=(1, 2, 3), keepdim=True) + std = x.std(dim=(1, 2, 3), keepdim=True) + x = (x - mean) / (1e-5 + std) + # x will be the freq. branch input. + + # Prepare the time branch input. + xt = mix + meant = xt.mean(dim=(1, 2), keepdim=True) + stdt = xt.std(dim=(1, 2), keepdim=True) + xt = (xt - meant) / (1e-5 + stdt) + + # print("XT: {}".format(xt.shape)) + + # okay, this is a giant mess I know... + saved = [] # skip connections, freq. + saved_t = [] # skip connections, time. + lengths = [] # saved lengths to properly remove padding, freq branch. + lengths_t = [] # saved lengths for time branch. + for idx, encode in enumerate(self.encoder): + lengths.append(x.shape[-1]) + inject = None + if idx < len(self.tencoder): + # we have not yet merged branches. 
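+ # the time branch is encoded in parallel here; once `tenc.empty` is True its output is
+ # injected into the frequency branch via `inject` instead of being saved for a skip connection.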
+ lengths_t.append(xt.shape[-1]) + tenc = self.tencoder[idx] + xt = tenc(xt) + # print("Encode XT {}: {}".format(idx, xt.shape)) + if not tenc.empty: + # save for skip connection + saved_t.append(xt) + else: + # tenc contains just the first conv., so that now time and freq. + # branches have the same shape and can be merged. + inject = xt + x = encode(x, inject) + # print("Encode X {}: {}".format(idx, x.shape)) + if idx == 0 and self.freq_emb is not None: + # add frequency embedding to allow for non equivariant convolutions + # over the frequency axis. + frs = torch.arange(x.shape[-2], device=x.device) + emb = self.freq_emb(frs).t()[None, :, :, None].expand_as(x) + x = x + self.freq_emb_scale * emb + + saved.append(x) + if self.crosstransformer: + if self.bottom_channels: + b, c, f, t = x.shape + x = rearrange(x, "b c f t-> b c (f t)") + x = self.channel_upsampler(x) + x = rearrange(x, "b c (f t)-> b c f t", f=f) + xt = self.channel_upsampler_t(xt) + + x, xt = self.crosstransformer(x, xt) + # print("Cross Tran X {}, XT: {}".format(x.shape, xt.shape)) + + if self.bottom_channels: + x = rearrange(x, "b c f t-> b c (f t)") + x = self.channel_downsampler(x) + x = rearrange(x, "b c (f t)-> b c f t", f=f) + xt = self.channel_downsampler_t(xt) + + for idx, decode in enumerate(self.decoder): + skip = saved.pop(-1) + x, pre = decode(x, skip, lengths.pop(-1)) + # print('Decode {} X: {}'.format(idx, x.shape)) + # `pre` contains the output just before final transposed convolution, + # which is used when the freq. and time branch separate. + + offset = self.depth - len(self.tdecoder) + if idx >= offset: + tdec = self.tdecoder[idx - offset] + length_t = lengths_t.pop(-1) + if tdec.empty: + assert pre.shape[2] == 1, pre.shape + pre = pre[:, :, 0] + xt, _ = tdec(pre, None, length_t) + else: + skip = saved_t.pop(-1) + xt, _ = tdec(xt, skip, length_t) + # print('Decode {} XT: {}'.format(idx, xt.shape)) + + # Let's make sure we used all stored skip connections. 
+ assert len(saved) == 0 + assert len(lengths_t) == 0 + assert len(saved_t) == 0 + + S = len(self.sources) + + if self.num_subbands > 1: + x = x.view(B, -1, Fq, T) + # print("X view 1: {}".format(x.shape)) + x = self.cws2cac(x) + # print("X view 2: {}".format(x.shape)) + + x = x.view(B, S, -1, Fq * self.num_subbands, T) + x = x * std[:, None] + mean[:, None] + # print("X returned: {}".format(x.shape)) + + zout = self._mask(z, x) + if self.use_train_segment: + if self.training: + x = self._ispec(zout, length) + else: + x = self._ispec(zout, training_length) + else: + x = self._ispec(zout, length) + + if self.use_train_segment: + if self.training: + xt = xt.view(B, S, -1, length) + else: + xt = xt.view(B, S, -1, training_length) + else: + xt = xt.view(B, S, -1, length) + xt = xt * stdt[:, None] + meant[:, None] + x = xt + x + if length_pre_pad: + x = x[..., :length_pre_pad] + return x + + +def get_model(args): + extra = { + 'sources': list(args.training.instruments), + 'audio_channels': args.training.channels, + 'samplerate': args.training.samplerate, + # 'segment': args.model_segment or 4 * args.dset.segment, + 'segment': args.training.segment, + } + klass = { + 'demucs': Demucs, + 'hdemucs': HDemucs, + 'htdemucs': HTDemucs, + }[args.model] + kw = OmegaConf.to_container(getattr(args, args.model), resolve=True) + model = klass(**extra, **kw) + return model + + diff --git a/models/mdx23c_tfc_tdf_v3.py b/models/mdx23c_tfc_tdf_v3.py new file mode 100644 index 0000000000000000000000000000000000000000..d0a5c83fa0ca4c0cc278d0954c8e08f73c6f07c7 --- /dev/null +++ b/models/mdx23c_tfc_tdf_v3.py @@ -0,0 +1,242 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from functools import partial + + +class STFT: + def __init__(self, config): + self.n_fft = config.n_fft + self.hop_length = config.hop_length + self.window = torch.hann_window(window_length=self.n_fft, periodic=True) + self.dim_f = config.dim_f + + def __call__(self, x): + window = self.window.to(x.device) + batch_dims = x.shape[:-2] + c, t = x.shape[-2:] + x = x.reshape([-1, t]) + x = torch.stft( + x, + n_fft=self.n_fft, + hop_length=self.hop_length, + window=window, + center=True, + return_complex=True + ) + x = torch.view_as_real(x) + x = x.permute([0, 3, 1, 2]) + x = x.reshape([*batch_dims, c, 2, -1, x.shape[-1]]).reshape([*batch_dims, c * 2, -1, x.shape[-1]]) + return x[..., :self.dim_f, :] + + def inverse(self, x): + window = self.window.to(x.device) + batch_dims = x.shape[:-3] + c, f, t = x.shape[-3:] + n = self.n_fft // 2 + 1 + f_pad = torch.zeros([*batch_dims, c, n - f, t]).to(x.device) + x = torch.cat([x, f_pad], -2) + x = x.reshape([*batch_dims, c // 2, 2, n, t]).reshape([-1, 2, n, t]) + x = x.permute([0, 2, 3, 1]) + x = x[..., 0] + x[..., 1] * 1.j + x = torch.istft(x, n_fft=self.n_fft, hop_length=self.hop_length, window=window, center=True) + x = x.reshape([*batch_dims, 2, -1]) + return x + + +def get_norm(norm_type): + def norm(c, norm_type): + if norm_type == 'BatchNorm': + return nn.BatchNorm2d(c) + elif norm_type == 'InstanceNorm': + return nn.InstanceNorm2d(c, affine=True) + elif 'GroupNorm' in norm_type: + g = int(norm_type.replace('GroupNorm', '')) + return nn.GroupNorm(num_groups=g, num_channels=c) + else: + return nn.Identity() + + return partial(norm, norm_type=norm_type) + + +def get_act(act_type): + if act_type == 'gelu': + return nn.GELU() + elif act_type == 'relu': + return nn.ReLU() + elif act_type[:3] == 'elu': + alpha = float(act_type.replace('elu', '')) + return nn.ELU(alpha) + else: + raise 
Exception + + +class Upscale(nn.Module): + def __init__(self, in_c, out_c, scale, norm, act): + super().__init__() + self.conv = nn.Sequential( + norm(in_c), + act, + nn.ConvTranspose2d(in_channels=in_c, out_channels=out_c, kernel_size=scale, stride=scale, bias=False) + ) + + def forward(self, x): + return self.conv(x) + + +class Downscale(nn.Module): + def __init__(self, in_c, out_c, scale, norm, act): + super().__init__() + self.conv = nn.Sequential( + norm(in_c), + act, + nn.Conv2d(in_channels=in_c, out_channels=out_c, kernel_size=scale, stride=scale, bias=False) + ) + + def forward(self, x): + return self.conv(x) + + +class TFC_TDF(nn.Module): + def __init__(self, in_c, c, l, f, bn, norm, act): + super().__init__() + + self.blocks = nn.ModuleList() + for i in range(l): + block = nn.Module() + + block.tfc1 = nn.Sequential( + norm(in_c), + act, + nn.Conv2d(in_c, c, 3, 1, 1, bias=False), + ) + block.tdf = nn.Sequential( + norm(c), + act, + nn.Linear(f, f // bn, bias=False), + norm(c), + act, + nn.Linear(f // bn, f, bias=False), + ) + block.tfc2 = nn.Sequential( + norm(c), + act, + nn.Conv2d(c, c, 3, 1, 1, bias=False), + ) + block.shortcut = nn.Conv2d(in_c, c, 1, 1, 0, bias=False) + + self.blocks.append(block) + in_c = c + + def forward(self, x): + for block in self.blocks: + s = block.shortcut(x) + x = block.tfc1(x) + x = x + block.tdf(x) + x = block.tfc2(x) + x = x + s + return x + + +class TFC_TDF_net(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + + norm = get_norm(norm_type=config.model.norm) + act = get_act(act_type=config.model.act) + + self.num_target_instruments = 1 if config.training.target_instrument else len(config.training.instruments) + self.num_subbands = config.model.num_subbands + + dim_c = self.num_subbands * config.audio.num_channels * 2 + n = config.model.num_scales + scale = config.model.scale + l = config.model.num_blocks_per_scale + c = config.model.num_channels + g = config.model.growth + bn = config.model.bottleneck_factor + f = config.audio.dim_f // self.num_subbands + + self.first_conv = nn.Conv2d(dim_c, c, 1, 1, 0, bias=False) + + self.encoder_blocks = nn.ModuleList() + for i in range(n): + block = nn.Module() + block.tfc_tdf = TFC_TDF(c, c, l, f, bn, norm, act) + block.downscale = Downscale(c, c + g, scale, norm, act) + f = f // scale[1] + c += g + self.encoder_blocks.append(block) + + self.bottleneck_block = TFC_TDF(c, c, l, f, bn, norm, act) + + self.decoder_blocks = nn.ModuleList() + for i in range(n): + block = nn.Module() + block.upscale = Upscale(c, c - g, scale, norm, act) + f = f * scale[1] + c -= g + block.tfc_tdf = TFC_TDF(2 * c, c, l, f, bn, norm, act) + self.decoder_blocks.append(block) + + self.final_conv = nn.Sequential( + nn.Conv2d(c + dim_c, c, 1, 1, 0, bias=False), + act, + nn.Conv2d(c, self.num_target_instruments * dim_c, 1, 1, 0, bias=False) + ) + + self.stft = STFT(config.audio) + + def cac2cws(self, x): + k = self.num_subbands + b, c, f, t = x.shape + x = x.reshape(b, c, k, f // k, t) + x = x.reshape(b, c * k, f // k, t) + return x + + def cws2cac(self, x): + k = self.num_subbands + b, c, f, t = x.shape + x = x.reshape(b, c // k, k, f, t) + x = x.reshape(b, c // k, f * k, t) + return x + + def forward(self, x): + + x = self.stft(x) + + mix = x = self.cac2cws(x) + + first_conv_out = x = self.first_conv(x) + + x = x.transpose(-1, -2) + + encoder_outputs = [] + for block in self.encoder_blocks: + x = block.tfc_tdf(x) + encoder_outputs.append(x) + x = block.downscale(x) + + x = self.bottleneck_block(x) + 
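+ # decoder: each block upscales and fuses with the matching encoder output via
+ # channel-wise concatenation (U-Net style skip connections).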
+ for block in self.decoder_blocks: + x = block.upscale(x) + x = torch.cat([x, encoder_outputs.pop()], 1) + x = block.tfc_tdf(x) + + x = x.transpose(-1, -2) + + x = x * first_conv_out # reduce artifacts + + x = self.final_conv(torch.cat([mix, x], 1)) + + x = self.cws2cac(x) + + if self.num_target_instruments > 1: + b, c, f, t = x.shape + x = x.reshape(b, self.num_target_instruments, -1, f, t) + + x = self.stft.inverse(x) + + return x diff --git a/models/scnet/__init__.py b/models/scnet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..70f5b2c48a420dd2dec41a778af3b8a211364d93 --- /dev/null +++ b/models/scnet/__init__.py @@ -0,0 +1 @@ +from .scnet import SCNet diff --git a/models/scnet/scnet.py b/models/scnet/scnet.py new file mode 100644 index 0000000000000000000000000000000000000000..32423cd8d1e10d792a55b477f159e4d7f5fc4960 --- /dev/null +++ b/models/scnet/scnet.py @@ -0,0 +1,373 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from collections import deque +from .separation import SeparationNet +import typing as tp +import math + +class Swish(nn.Module): + def forward(self, x): + return x * x.sigmoid() + + +class ConvolutionModule(nn.Module): + """ + Convolution Module in SD block. + + Args: + channels (int): input/output channels. + depth (int): number of layers in the residual branch. Each layer has its own + compress (float): amount of channel compression. + kernel (int): kernel size for the convolutions. + """ + def __init__(self, channels, depth=2, compress=4, kernel=3): + super().__init__() + assert kernel % 2 == 1 + self.depth = abs(depth) + hidden_size = int(channels / compress) + norm = lambda d: nn.GroupNorm(1, d) + self.layers = nn.ModuleList([]) + for _ in range(self.depth): + padding = (kernel // 2) + mods = [ + norm(channels), + nn.Conv1d(channels, hidden_size*2, kernel, padding = padding), + nn.GLU(1), + nn.Conv1d(hidden_size, hidden_size, kernel, padding = padding, groups = hidden_size), + norm(hidden_size), + Swish(), + nn.Conv1d(hidden_size, channels, 1), + ] + layer = nn.Sequential(*mods) + self.layers.append(layer) + + def forward(self, x): + for layer in self.layers: + x = x + layer(x) + return x + + +class FusionLayer(nn.Module): + """ + A FusionLayer within the decoder. + + Args: + - channels (int): Number of input channels. + - kernel_size (int, optional): Kernel size for the convolutional layer, defaults to 3. + - stride (int, optional): Stride for the convolutional layer, defaults to 1. + - padding (int, optional): Padding for the convolutional layer, defaults to 1. + """ + + def __init__(self, channels, kernel_size=3, stride=1, padding=1): + super(FusionLayer, self).__init__() + self.conv = nn.Conv2d(channels * 2, channels * 2, kernel_size, stride=stride, padding=padding) + + def forward(self, x, skip=None): + if skip is not None: + x += skip + x = x.repeat(1, 2, 1, 1) + x = self.conv(x) + x = F.glu(x, dim=1) + return x + + +class SDlayer(nn.Module): + """ + Implements a Sparse Down-sample Layer for processing different frequency bands separately. + + Args: + - channels_in (int): Input channel count. + - channels_out (int): Output channel count. + - band_configs (dict): A dictionary containing configuration for each frequency band. + Keys are 'low', 'mid', 'high' for each band, and values are + dictionaries with keys 'SR', 'stride', and 'kernel' for proportion, + stride, and kernel size, respectively. 
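+
+ Illustrative example (values assumed from the default SCNet config, not part
+ of the original docstring): with n_fft=4096 (2049 frequency bins) and
+ SR/stride settings of 0.175/1, 0.392/4 and 0.433/16, the spectrogram is
+ split into sub-bands of 359, 803 and 887 bins, which the strided
+ convolutions reduce to 359, 201 and 56 bins respectively.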
+ """ + def __init__(self, channels_in, channels_out, band_configs): + super(SDlayer, self).__init__() + + # Initializing convolutional layers for each band + self.convs = nn.ModuleList() + self.strides = [] + self.kernels = [] + for config in band_configs.values(): + self.convs.append(nn.Conv2d(channels_in, channels_out, (config['kernel'], 1), (config['stride'], 1), (0, 0))) + self.strides.append(config['stride']) + self.kernels.append(config['kernel']) + + # Saving rate proportions for determining splits + self.SR_low = band_configs['low']['SR'] + self.SR_mid = band_configs['mid']['SR'] + + def forward(self, x): + B, C, Fr, T = x.shape + # Define splitting points based on sampling rates + splits = [ + (0, math.ceil(Fr * self.SR_low)), + (math.ceil(Fr * self.SR_low), math.ceil(Fr * (self.SR_low + self.SR_mid))), + (math.ceil(Fr * (self.SR_low + self.SR_mid)), Fr) + ] + + # Processing each band with the corresponding convolution + outputs = [] + original_lengths=[] + for conv, stride, kernel, (start, end) in zip(self.convs, self.strides, self.kernels, splits): + extracted = x[:, :, start:end, :] + original_lengths.append(end-start) + current_length = extracted.shape[2] + + # padding + if stride == 1: + total_padding = kernel - stride + else: + total_padding = (stride - current_length % stride) % stride + pad_left = total_padding // 2 + pad_right = total_padding - pad_left + + padded = F.pad(extracted, (0, 0, pad_left, pad_right)) + + output = conv(padded) + outputs.append(output) + + return outputs, original_lengths + + +class SUlayer(nn.Module): + """ + Implements a Sparse Up-sample Layer in decoder. + + Args: + - channels_in: The number of input channels. + - channels_out: The number of output channels. + - convtr_configs: Dictionary containing the configurations for transposed convolutions. + """ + def __init__(self, channels_in, channels_out, band_configs): + super(SUlayer, self).__init__() + + # Initializing convolutional layers for each band + self.convtrs = nn.ModuleList([ + nn.ConvTranspose2d(channels_in, channels_out, [config['kernel'], 1], [config['stride'], 1]) + for _, config in band_configs.items() + ]) + + def forward(self, x, lengths, origin_lengths): + B, C, Fr, T = x.shape + # Define splitting points based on input lengths + splits = [ + (0, lengths[0]), + (lengths[0], lengths[0] + lengths[1]), + (lengths[0] + lengths[1], None) + ] + # Processing each band with the corresponding convolution + outputs = [] + for idx, (convtr, (start, end)) in enumerate(zip(self.convtrs, splits)): + out = convtr(x[:, :, start:end, :]) + # Calculate the distance to trim the output symmetrically to original length + current_Fr_length = out.shape[2] + dist = abs(origin_lengths[idx] - current_Fr_length) // 2 + + # Trim the output to the original length symmetrically + trimmed_out = out[:, :, dist:dist + origin_lengths[idx], :] + + outputs.append(trimmed_out) + + # Concatenate trimmed outputs along the frequency dimension to return the final tensor + x = torch.cat(outputs, dim=2) + + return x + + +class SDblock(nn.Module): + """ + Implements a simplified Sparse Down-sample block in encoder. + + Args: + - channels_in (int): Number of input channels. + - channels_out (int): Number of output channels. + - band_config (dict): Configuration for the SDlayer specifying band splits and convolutions. + - conv_config (dict): Configuration for convolution modules applied to each band. + - depths (list of int): List specifying the convolution depths for low, mid, and high frequency bands. 
+ """ + def __init__(self, channels_in, channels_out, band_configs={}, conv_config={}, depths=[3, 2, 1], kernel_size=3): + super(SDblock, self).__init__() + self.SDlayer = SDlayer(channels_in, channels_out, band_configs) + + # Dynamically create convolution modules for each band based on depths + self.conv_modules = nn.ModuleList([ + ConvolutionModule(channels_out, depth, **conv_config) for depth in depths + ]) + #Set the kernel_size to an odd number. + self.globalconv = nn.Conv2d(channels_out, channels_out, kernel_size, 1, (kernel_size - 1) // 2) + + def forward(self, x): + bands, original_lengths = self.SDlayer(x) + # B, C, f, T = band.shape + bands = [ + F.gelu( + conv(band.permute(0, 2, 1, 3).reshape(-1, band.shape[1], band.shape[3])) + .view(band.shape[0], band.shape[2], band.shape[1], band.shape[3]) + .permute(0, 2, 1, 3) + ) + for conv, band in zip(self.conv_modules, bands) + + ] + lengths = [band.size(-2) for band in bands] + full_band = torch.cat(bands, dim=2) + skip = full_band + + output = self.globalconv(full_band) + + return output, skip, lengths, original_lengths + + +class SCNet(nn.Module): + """ + The implementation of SCNet: Sparse Compression Network for Music Source Separation. Paper: https://arxiv.org/abs/2401.13276.pdf + + Args: + - sources (List[str]): List of sources to be separated. + - audio_channels (int): Number of audio channels. + - nfft (int): Number of FFTs to determine the frequency dimension of the input. + - hop_size (int): Hop size for the STFT. + - win_size (int): Window size for STFT. + - normalized (bool): Whether to normalize the STFT. + - dims (List[int]): List of channel dimensions for each block. + - band_configs (Dict[str, Dict[str, int]]): Configuration for each frequency band, including how to divide the frequency bands, + and the settings for the upsampling/downsampling convolutional layers. + - conv_depths (List[int]): List specifying the number of convolution modules in each SD block. + - compress (int): Compression factor for convolution module. + - conv_kernel (int): Kernel size for convolution layer in convolution module. + - num_dplayer (int): Number of dual-path layers. + - expand (int): Expansion factor in the dual-path RNN, default is 1. 
+ + """ + def __init__(self, + sources = ['drums', 'bass', 'other', 'vocals'], + audio_channels = 2, + # Main structure + dims = [4, 32, 64, 128], # dims = [4, 64, 128, 256] in SCNet-large + # STFT + nfft = 4096, + hop_size = 1024, + win_size = 4096, + normalized = True, + # SD/SU layer + band_configs = { + 'low': { 'SR': .175, 'stride': 1, 'kernel': 3 }, + 'mid': { 'SR': .392, 'stride': 4, 'kernel': 4 }, + 'high': {'SR': .433, 'stride': 16, 'kernel': 16 } + }, + # Convolution Module + conv_depths = [3,2,1], + compress = 4, + conv_kernel = 3, + # Dual-path RNN + num_dplayer = 6, + expand = 1, + # mamba + use_mamba = False, + mamba_config = { + 'd_stat': 16, + 'd_conv': 4, + 'd_expand': 2 + }): + super().__init__() + self.sources = sources + self.audio_channels = audio_channels + self.dims = dims + self.band_configs = band_configs + self.hop_length = hop_size + self.conv_config = { + 'compress': compress, + 'kernel': conv_kernel, + } + + self.stft_config = { + 'n_fft': nfft, + 'hop_length': hop_size, + 'win_length': win_size, + 'center': True, + 'normalized': normalized + } + + self.encoder = nn.ModuleList() + self.decoder = nn.ModuleList() + + for index in range(len(dims)-1): + enc = SDblock( + channels_in = dims[index], + channels_out = dims[index+1], + band_configs = self.band_configs, + conv_config = self.conv_config, + depths = conv_depths + ) + self.encoder.append(enc) + + dec = nn.Sequential( + FusionLayer(channels = dims[index+1]), + SUlayer( + channels_in = dims[index+1], + channels_out = dims[index] if index != 0 else dims[index] * len(sources), + band_configs = self.band_configs, + ) + ) + self.decoder.insert(0, dec) + + self.separation_net = SeparationNet( + channels = dims[-1], + expand = expand, + num_layers = num_dplayer, + use_mamba = use_mamba, + **mamba_config + ) + + + def forward(self, x): + # B, C, L = x.shape + B = x.shape[0] + # In the initial padding, ensure that the number of frames after the STFT (the length of the T dimension) is even, + # so that the RFFT operation can be used in the separation network. 
+ padding = self.hop_length - x.shape[-1] % self.hop_length + if (x.shape[-1] + padding) // self.hop_length % 2 == 0: + padding += self.hop_length + x = F.pad(x, (0, padding)) + + # STFT + L = x.shape[-1] + x = x.reshape(-1, L) + x = torch.stft(x, **self.stft_config, return_complex=True) + x = torch.view_as_real(x) + x = x.permute(0, 3, 1, 2).reshape(x.shape[0]//self.audio_channels, x.shape[3]*self.audio_channels, x.shape[1], x.shape[2]) + + B, C, Fr, T = x.shape + + save_skip = deque() + save_lengths = deque() + save_original_lengths = deque() + # encoder + for sd_layer in self.encoder: + x, skip, lengths, original_lengths = sd_layer(x) + save_skip.append(skip) + save_lengths.append(lengths) + save_original_lengths.append(original_lengths) + + #separation + x = self.separation_net(x) + + #decoder + for fusion_layer, su_layer in self.decoder: + x = fusion_layer(x, save_skip.pop()) + x = su_layer(x, save_lengths.pop(), save_original_lengths.pop()) + + #output + n = self.dims[0] + x = x.view(B, n, -1, Fr, T) + x = x.reshape(-1, 2, Fr, T).permute(0, 2, 3, 1) + x = torch.view_as_complex(x.contiguous()) + x = torch.istft(x, **self.stft_config) + x = x.reshape(B, len(self.sources), self.audio_channels, -1) + + x = x[:, :, :, :-padding] + + return x \ No newline at end of file diff --git a/models/scnet/separation.py b/models/scnet/separation.py new file mode 100644 index 0000000000000000000000000000000000000000..8ed5f408d10facbcb0673db3e3ee01fd881518ac --- /dev/null +++ b/models/scnet/separation.py @@ -0,0 +1,178 @@ +import torch +import torch.nn as nn +from torch.nn.modules.rnn import LSTM +import torch.nn.functional as Func +try: + from mamba_ssm.modules.mamba_simple import Mamba +except Exception as e: + print('No mamba found. Please install mamba_ssm') + +class RMSNorm(nn.Module): + def __init__(self, dim): + super().__init__() + self.scale = dim ** 0.5 + self.gamma = nn.Parameter(torch.ones(dim)) + + def forward(self, x): + return Func.normalize(x, dim=-1) * self.scale * self.gamma + + +class MambaModule(nn.Module): + def __init__(self, d_model, d_state, d_conv, d_expand): + super().__init__() + self.norm = RMSNorm(dim=d_model) + self.mamba = Mamba( + d_model=d_model, + d_state=d_state, + d_conv=d_conv, + expand=d_expand + ) + + def forward(self, x): + x = x + self.mamba(self.norm(x)) + return x + + +class FeatureConversion(nn.Module): + """ + Integrates into the adjacent Dual-Path layer. + + Args: + channels (int): Number of input channels. + inverse (bool): If True, uses ifft; otherwise, uses rfft. + """ + def __init__(self, channels, inverse): + super().__init__() + self.inverse = inverse + self.channels= channels + + def forward(self, x): + # B, C, F, T = x.shape + if self.inverse: + x = x.float() + x_r = x[:, :self.channels//2, :, :] + x_i = x[:, self.channels//2:, :, :] + x = torch.complex(x_r, x_i) + x = torch.fft.irfft(x, dim=3, norm="ortho") + else: + x = x.float() + x = torch.fft.rfft(x, dim=3, norm="ortho") + x_real = x.real + x_imag = x.imag + x = torch.cat([x_real, x_imag], dim=1) + return x + + +class DualPathRNN(nn.Module): + """ + Dual-Path RNN in Separation Network. + + Args: + d_model (int): The number of expected features in the input (input_size). + expand (int): Expansion factor used to calculate the hidden_size of LSTM. + bidirectional (bool): If True, becomes a bidirectional LSTM. 
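+
+ Note: each layer runs a BLSTM first along the frequency axis and then along
+ the time axis, with a residual connection around each pass; the linear
+ projections map 2 * hidden_size back to d_model and therefore assume
+ bidirectional=True.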
+ """ + def __init__(self, d_model, expand, bidirectional=True): + super(DualPathRNN, self).__init__() + + self.d_model = d_model + self.hidden_size = d_model * expand + self.bidirectional = bidirectional + # Initialize LSTM layers and normalization layers + self.lstm_layers = nn.ModuleList([self._init_lstm_layer(self.d_model, self.hidden_size) for _ in range(2)]) + self.linear_layers = nn.ModuleList([nn.Linear(self.hidden_size*2, self.d_model) for _ in range(2)]) + self.norm_layers = nn.ModuleList([nn.GroupNorm(1, d_model) for _ in range(2)]) + + def _init_lstm_layer(self, d_model, hidden_size): + return LSTM(d_model, hidden_size, num_layers=1, bidirectional=self.bidirectional, batch_first=True) + + def forward(self, x): + B, C, F, T = x.shape + + # Process dual-path rnn + + original_x = x + # Frequency-path + x = self.norm_layers[0](x) + x = x.transpose(1, 3).contiguous().view(B * T, F, C) + x, _ = self.lstm_layers[0](x) + x = self.linear_layers[0](x) + x = x.view(B, T, F, C).transpose(1, 3) + x = x + original_x + + original_x = x + # Time-path + x = self.norm_layers[1](x) + x = x.transpose(1, 2).contiguous().view(B * F, C, T).transpose(1, 2) + x, _ = self.lstm_layers[1](x) + x = self.linear_layers[1](x) + x = x.transpose(1, 2).contiguous().view(B, F, C, T).transpose(1, 2) + x = x + original_x + + return x + + +class DualPathMamba(nn.Module): + """ + Dual-Path Mamba. + + """ + def __init__(self, d_model, d_stat, d_conv, d_expand): + super(DualPathMamba, self).__init__() + # Initialize mamba layers + self.mamba_layers = nn.ModuleList([MambaModule(d_model, d_stat, d_conv, d_expand) for _ in range(2)]) + + def forward(self, x): + B, C, F, T = x.shape + + # Process dual-path mamba + + # Frequency-path + x = x.transpose(1, 3).contiguous().view(B * T, F, C) + x = self.mamba_layers[0](x) + x = x.view(B, T, F, C).transpose(1, 3) + + # Time-path + x = x.transpose(1, 2).contiguous().view(B * F, C, T).transpose(1, 2) + x = self.mamba_layers[1](x) + x = x.transpose(1, 2).contiguous().view(B, F, C, T).transpose(1, 2) + + return x + + +class SeparationNet(nn.Module): + """ + Implements a simplified Sparse Down-sample block in an encoder architecture. + + Args: + - channels (int): Number input channels. + - expand (int): Expansion factor used to calculate the hidden_size of LSTM. + - num_layers (int): Number of dual-path layers. + - use_mamba (bool): If true, use the Mamba module to replace the RNN. + - d_stat (int), d_conv (int), d_expand (int): These are built-in parameters of the Mamba model. 
+ """ + def __init__(self, channels, expand=1, num_layers=6, use_mamba=True, d_stat=16, d_conv=4, d_expand=2): + super(SeparationNet, self).__init__() + + self.num_layers = num_layers + if use_mamba: + self.dp_modules = nn.ModuleList([ + DualPathMamba(channels * (2 if i % 2 == 1 else 1), d_stat, d_conv, d_expand * (2 if i % 2 == 1 else 1)) for i in range(num_layers) + ]) + else: + self.dp_modules = nn.ModuleList([ + DualPathRNN(channels * (2 if i % 2 == 1 else 1), expand) for i in range(num_layers) + ]) + + self.feature_conversion = nn.ModuleList([ + FeatureConversion(channels * 2 , inverse = False if i % 2 == 0 else True) for i in range(num_layers) + ]) + def forward(self, x): + for i in range(self.num_layers): + x = self.dp_modules[i](x) + x = self.feature_conversion[i](x) + return x + + + + diff --git a/models/scnet_unofficial/__init__.py b/models/scnet_unofficial/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6d034d38a2ace2e81bd28d63dd8f25feb918f33d --- /dev/null +++ b/models/scnet_unofficial/__init__.py @@ -0,0 +1 @@ +from models.scnet_unofficial.scnet import SCNet \ No newline at end of file diff --git a/models/scnet_unofficial/modules/__init__.py b/models/scnet_unofficial/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..073826e128ffef8c9c374bc897397ad85bf149a7 --- /dev/null +++ b/models/scnet_unofficial/modules/__init__.py @@ -0,0 +1,3 @@ +from models.scnet_unofficial.modules.dualpath_rnn import DualPathRNN +from models.scnet_unofficial.modules.sd_encoder import SDBlock +from models.scnet_unofficial.modules.su_decoder import SUBlock diff --git a/models/scnet_unofficial/modules/dualpath_rnn.py b/models/scnet_unofficial/modules/dualpath_rnn.py new file mode 100644 index 0000000000000000000000000000000000000000..3534e172b014beaf8b4546ef090a8cda35b97437 --- /dev/null +++ b/models/scnet_unofficial/modules/dualpath_rnn.py @@ -0,0 +1,228 @@ +import torch +import torch.nn as nn +import torch.nn.functional as Func + +class RMSNorm(nn.Module): + def __init__(self, dim): + super().__init__() + self.scale = dim ** 0.5 + self.gamma = nn.Parameter(torch.ones(dim)) + + def forward(self, x): + return Func.normalize(x, dim=-1) * self.scale * self.gamma + + +class MambaModule(nn.Module): + def __init__(self, d_model, d_state, d_conv, d_expand): + super().__init__() + self.norm = RMSNorm(dim=d_model) + self.mamba = Mamba( + d_model=d_model, + d_state=d_state, + d_conv=d_conv, + d_expand=d_expand + ) + + def forward(self, x): + x = x + self.mamba(self.norm(x)) + return x + + +class RNNModule(nn.Module): + """ + RNNModule class implements a recurrent neural network module with LSTM cells. + + Args: + - input_dim (int): Dimensionality of the input features. + - hidden_dim (int): Dimensionality of the hidden state of the LSTM. + - bidirectional (bool, optional): If True, uses bidirectional LSTM. Defaults to True. + + Shapes: + - Input: (B, T, D) where + B is batch size, + T is sequence length, + D is input dimensionality. + - Output: (B, T, D) where + B is batch size, + T is sequence length, + D is input dimensionality. + """ + + def __init__(self, input_dim: int, hidden_dim: int, bidirectional: bool = True): + """ + Initializes RNNModule with input dimension, hidden dimension, and bidirectional flag. 
+ """ + super().__init__() + self.groupnorm = nn.GroupNorm(num_groups=1, num_channels=input_dim) + self.rnn = nn.LSTM( + input_dim, hidden_dim, batch_first=True, bidirectional=bidirectional + ) + self.fc = nn.Linear(hidden_dim * 2 if bidirectional else hidden_dim, input_dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Performs forward pass through the RNNModule. + + Args: + - x (torch.Tensor): Input tensor of shape (B, T, D). + + Returns: + - torch.Tensor: Output tensor of shape (B, T, D). + """ + x = x.transpose(1, 2) + x = self.groupnorm(x) + x = x.transpose(1, 2) + + x, (hidden, _) = self.rnn(x) + x = self.fc(x) + return x + + +class RFFTModule(nn.Module): + """ + RFFTModule class implements a module for performing real-valued Fast Fourier Transform (FFT) + or its inverse on input tensors. + + Args: + - inverse (bool, optional): If False, performs forward FFT. If True, performs inverse FFT. Defaults to False. + + Shapes: + - Input: (B, F, T, D) where + B is batch size, + F is the number of features, + T is sequence length, + D is input dimensionality. + - Output: (B, F, T // 2 + 1, D * 2) if performing forward FFT. + (B, F, T, D // 2, 2) if performing inverse FFT. + """ + + def __init__(self, inverse: bool = False): + """ + Initializes RFFTModule with inverse flag. + """ + super().__init__() + self.inverse = inverse + + def forward(self, x: torch.Tensor, time_dim: int) -> torch.Tensor: + """ + Performs forward or inverse FFT on the input tensor x. + + Args: + - x (torch.Tensor): Input tensor of shape (B, F, T, D). + - time_dim (int): Input size of time dimension. + + Returns: + - torch.Tensor: Output tensor after FFT or its inverse operation. + """ + dtype = x.dtype + B, F, T, D = x.shape + + # RuntimeError: cuFFT only supports dimensions whose sizes are powers of two when computing in half precision + x = x.float() + + if not self.inverse: + x = torch.fft.rfft(x, dim=2) + x = torch.view_as_real(x) + x = x.reshape(B, F, T // 2 + 1, D * 2) + else: + x = x.reshape(B, F, T, D // 2, 2) + x = torch.view_as_complex(x) + x = torch.fft.irfft(x, n=time_dim, dim=2) + + x = x.to(dtype) + return x + + def extra_repr(self) -> str: + """ + Returns extra representation string with module's configuration. + """ + return f"inverse={self.inverse}" + + +class DualPathRNN(nn.Module): + """ + DualPathRNN class implements a neural network with alternating layers of RNNModule and RFFTModule. + + Args: + - n_layers (int): Number of layers in the network. + - input_dim (int): Dimensionality of the input features. + - hidden_dim (int): Dimensionality of the hidden state of the RNNModule. + + Shapes: + - Input: (B, F, T, D) where + B is batch size, + F is the number of features (frequency dimension), + T is sequence length (time dimension), + D is input dimensionality (channel dimension). + - Output: (B, F, T, D) where + B is batch size, + F is the number of features (frequency dimension), + T is sequence length (time dimension), + D is input dimensionality (channel dimension). + """ + + def __init__( + self, + n_layers: int, + input_dim: int, + hidden_dim: int, + + use_mamba: bool = False, + d_state: int = 16, + d_conv: int = 4, + d_expand: int = 2 + ): + """ + Initializes DualPathRNN with the specified number of layers, input dimension, and hidden dimension. 
+ """ + super().__init__() + + if use_mamba: + from mamba_ssm.modules.mamba_simple import Mamba + net = MambaModule + dkwargs = {"d_model": input_dim, "d_state": d_state, "d_conv": d_conv, "d_expand": d_expand} + ukwargs = {"d_model": input_dim * 2, "d_state": d_state, "d_conv": d_conv, "d_expand": d_expand * 2} + else: + net = RNNModule + dkwargs = {"input_dim": input_dim, "hidden_dim": hidden_dim} + ukwargs = {"input_dim": input_dim * 2, "hidden_dim": hidden_dim * 2} + + self.layers = nn.ModuleList() + for i in range(1, n_layers + 1): + kwargs = dkwargs if i % 2 == 1 else ukwargs + layer = nn.ModuleList([ + net(**kwargs), + net(**kwargs), + RFFTModule(inverse=(i % 2 == 0)), + ]) + self.layers.append(layer) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Performs forward pass through the DualPathRNN. + + Args: + - x (torch.Tensor): Input tensor of shape (B, F, T, D). + + Returns: + - torch.Tensor: Output tensor of shape (B, F, T, D). + """ + + time_dim = x.shape[2] + + for time_layer, freq_layer, rfft_layer in self.layers: + B, F, T, D = x.shape + + x = x.reshape((B * F), T, D) + x = time_layer(x) + x = x.reshape(B, F, T, D) + x = x.permute(0, 2, 1, 3) + + x = x.reshape((B * T), F, D) + x = freq_layer(x) + x = x.reshape(B, T, F, D) + x = x.permute(0, 2, 1, 3) + + x = rfft_layer(x, time_dim) + + return x diff --git a/models/scnet_unofficial/modules/sd_encoder.py b/models/scnet_unofficial/modules/sd_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..58a586f0cfdaf1e39d3317f965144d0891fd51c9 --- /dev/null +++ b/models/scnet_unofficial/modules/sd_encoder.py @@ -0,0 +1,285 @@ +from typing import List, Tuple + +import torch +import torch.nn as nn + +from models.scnet_unofficial.utils import create_intervals + + +class Downsample(nn.Module): + """ + Downsample class implements a module for downsampling input tensors using 2D convolution. + + Args: + - input_dim (int): Dimensionality of the input channels. + - output_dim (int): Dimensionality of the output channels. + - stride (int): Stride value for the convolution operation. + + Shapes: + - Input: (B, C_in, F, T) where + B is batch size, + C_in is the number of input channels, + F is the frequency dimension, + T is the time dimension. + - Output: (B, C_out, F // stride, T) where + B is batch size, + C_out is the number of output channels, + F // stride is the downsampled frequency dimension. + + """ + + def __init__( + self, + input_dim: int, + output_dim: int, + stride: int, + ): + """ + Initializes Downsample with input dimension, output dimension, and stride. + """ + super().__init__() + self.conv = nn.Conv2d(input_dim, output_dim, 1, (stride, 1)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Performs forward pass through the Downsample module. + + Args: + - x (torch.Tensor): Input tensor of shape (B, C_in, F, T). + + Returns: + - torch.Tensor: Downsampled tensor of shape (B, C_out, F // stride, T). + """ + return self.conv(x) + + +class ConvolutionModule(nn.Module): + """ + ConvolutionModule class implements a module with a sequence of convolutional layers similar to Conformer. + + Args: + - input_dim (int): Dimensionality of the input features. + - hidden_dim (int): Dimensionality of the hidden features. + - kernel_sizes (List[int]): List of kernel sizes for the convolutional layers. + - bias (bool, optional): If True, adds a learnable bias to the output. Default is False. 
+ + Shapes: + - Input: (B, T, D) where + B is batch size, + T is sequence length, + D is input dimensionality. + - Output: (B, T, D) where + B is batch size, + T is sequence length, + D is input dimensionality. + """ + + def __init__( + self, + input_dim: int, + hidden_dim: int, + kernel_sizes: List[int], + bias: bool = False, + ) -> None: + """ + Initializes ConvolutionModule with input dimension, hidden dimension, kernel sizes, and bias. + """ + super().__init__() + self.sequential = nn.Sequential( + nn.GroupNorm(num_groups=1, num_channels=input_dim), + nn.Conv1d( + input_dim, + 2 * hidden_dim, + kernel_sizes[0], + stride=1, + padding=(kernel_sizes[0] - 1) // 2, + bias=bias, + ), + nn.GLU(dim=1), + nn.Conv1d( + hidden_dim, + hidden_dim, + kernel_sizes[1], + stride=1, + padding=(kernel_sizes[1] - 1) // 2, + groups=hidden_dim, + bias=bias, + ), + nn.GroupNorm(num_groups=1, num_channels=hidden_dim), + nn.SiLU(), + nn.Conv1d( + hidden_dim, + input_dim, + kernel_sizes[2], + stride=1, + padding=(kernel_sizes[2] - 1) // 2, + bias=bias, + ), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Performs forward pass through the ConvolutionModule. + + Args: + - x (torch.Tensor): Input tensor of shape (B, T, D). + + Returns: + - torch.Tensor: Output tensor of shape (B, T, D). + """ + x = x.transpose(1, 2) + x = x + self.sequential(x) + x = x.transpose(1, 2) + return x + + +class SDLayer(nn.Module): + """ + SDLayer class implements a subband decomposition layer with downsampling and convolutional modules. + + Args: + - subband_interval (Tuple[float, float]): Tuple representing the frequency interval for subband decomposition. + - input_dim (int): Dimensionality of the input channels. + - output_dim (int): Dimensionality of the output channels after downsampling. + - downsample_stride (int): Stride value for the downsampling operation. + - n_conv_modules (int): Number of convolutional modules. + - kernel_sizes (List[int]): List of kernel sizes for the convolutional layers. + - bias (bool, optional): If True, adds a learnable bias to the convolutional layers. Default is True. + + Shapes: + - Input: (B, Fi, T, Ci) where + B is batch size, + Fi is the number of input subbands, + T is sequence length, and + Ci is the number of input channels. + - Output: (B, Fi+1, T, Ci+1) where + B is batch size, + Fi+1 is the number of output subbands, + T is sequence length, + Ci+1 is the number of output channels. + """ + + def __init__( + self, + subband_interval: Tuple[float, float], + input_dim: int, + output_dim: int, + downsample_stride: int, + n_conv_modules: int, + kernel_sizes: List[int], + bias: bool = True, + ): + """ + Initializes SDLayer with subband interval, input dimension, + output dimension, downsample stride, number of convolutional modules, kernel sizes, and bias. + """ + super().__init__() + self.subband_interval = subband_interval + self.downsample = Downsample(input_dim, output_dim, downsample_stride) + self.activation = nn.GELU() + conv_modules = [ + ConvolutionModule( + input_dim=output_dim, + hidden_dim=output_dim // 4, + kernel_sizes=kernel_sizes, + bias=bias, + ) + for _ in range(n_conv_modules) + ] + self.conv_modules = nn.Sequential(*conv_modules) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Performs forward pass through the SDLayer. + + Args: + - x (torch.Tensor): Input tensor of shape (B, Fi, T, Ci). + + Returns: + - torch.Tensor: Output tensor of shape (B, Fi+1, T, Ci+1). 
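+
+ Note: the layer first crops the fractional sub-band interval from the
+ frequency axis, downsamples it with a strided convolution, and then applies
+ the ConvolutionModules to the (B * F, T, C)-flattened sequences.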
+ """ + B, F, T, C = x.shape + x = x[:, int(self.subband_interval[0] * F) : int(self.subband_interval[1] * F)] + x = x.permute(0, 3, 1, 2) + x = self.downsample(x) + x = self.activation(x) + x = x.permute(0, 2, 3, 1) + + B, F, T, C = x.shape + x = x.reshape((B * F), T, C) + x = self.conv_modules(x) + x = x.reshape(B, F, T, C) + + return x + + +class SDBlock(nn.Module): + """ + SDBlock class implements a block with subband decomposition layers and global convolution. + + Args: + - input_dim (int): Dimensionality of the input channels. + - output_dim (int): Dimensionality of the output channels. + - bandsplit_ratios (List[float]): List of ratios for splitting the frequency bands. + - downsample_strides (List[int]): List of stride values for downsampling in each subband layer. + - n_conv_modules (List[int]): List specifying the number of convolutional modules in each subband layer. + - kernel_sizes (List[int], optional): List of kernel sizes for the convolutional layers. Default is None. + + Shapes: + - Input: (B, Fi, T, Ci) where + B is batch size, + Fi is the number of input subbands, + T is sequence length, + Ci is the number of input channels. + - Output: (B, Fi+1, T, Ci+1) where + B is batch size, + Fi+1 is the number of output subbands, + T is sequence length, + Ci+1 is the number of output channels. + """ + + def __init__( + self, + input_dim: int, + output_dim: int, + bandsplit_ratios: List[float], + downsample_strides: List[int], + n_conv_modules: List[int], + kernel_sizes: List[int] = None, + ): + """ + Initializes SDBlock with input dimension, output dimension, band split ratios, downsample strides, number of convolutional modules, and kernel sizes. + """ + super().__init__() + if kernel_sizes is None: + kernel_sizes = [3, 3, 1] + assert sum(bandsplit_ratios) == 1, "The split ratios must sum up to 1." + subband_intervals = create_intervals(bandsplit_ratios) + self.sd_layers = nn.ModuleList( + SDLayer( + input_dim=input_dim, + output_dim=output_dim, + subband_interval=sbi, + downsample_stride=dss, + n_conv_modules=ncm, + kernel_sizes=kernel_sizes, + ) + for sbi, dss, ncm in zip( + subband_intervals, downsample_strides, n_conv_modules + ) + ) + self.global_conv2d = nn.Conv2d(output_dim, output_dim, 1, 1) + + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Performs forward pass through the SDBlock. + + Args: + - x (torch.Tensor): Input tensor of shape (B, Fi, T, Ci). + + Returns: + - Tuple[torch.Tensor, torch.Tensor]: Output tensor and skip connection tensor. + """ + x_skip = torch.concat([layer(x) for layer in self.sd_layers], dim=1) + x = self.global_conv2d(x_skip.permute(0, 3, 1, 2)).permute(0, 2, 3, 1) + return x, x_skip diff --git a/models/scnet_unofficial/modules/su_decoder.py b/models/scnet_unofficial/modules/su_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..6f0909a4573fab7336a1a6d1272acfec0f8ad375 --- /dev/null +++ b/models/scnet_unofficial/modules/su_decoder.py @@ -0,0 +1,241 @@ +from typing import List, Tuple + +import torch +import torch.nn as nn + +from models.scnet_unofficial.utils import get_convtranspose_output_padding + + +class FusionLayer(nn.Module): + """ + FusionLayer class implements a module for fusing two input tensors using convolutional operations. + + Args: + - input_dim (int): Dimensionality of the input channels. + - kernel_size (int, optional): Kernel size for the convolutional layer. Default is 3. + - stride (int, optional): Stride value for the convolutional layer. Default is 1. 
+ - padding (int, optional): Padding value for the convolutional layer. Default is 1. + + Shapes: + - Input: (B, F, T, C) and (B, F, T, C) where + B is batch size, + F is the number of features, + T is sequence length, + C is input dimensionality. + - Output: (B, F, T, C) where + B is batch size, + F is the number of features, + T is sequence length, + C is input dimensionality. + """ + + def __init__( + self, input_dim: int, kernel_size: int = 3, stride: int = 1, padding: int = 1 + ): + """ + Initializes FusionLayer with input dimension, kernel size, stride, and padding. + """ + super().__init__() + self.conv = nn.Conv2d( + input_dim * 2, + input_dim * 2, + kernel_size=(kernel_size, 1), + stride=(stride, 1), + padding=(padding, 0), + ) + self.activation = nn.GLU() + + def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: + """ + Performs forward pass through the FusionLayer. + + Args: + - x1 (torch.Tensor): First input tensor of shape (B, F, T, C). + - x2 (torch.Tensor): Second input tensor of shape (B, F, T, C). + + Returns: + - torch.Tensor: Output tensor of shape (B, F, T, C). + """ + x = x1 + x2 + x = x.repeat(1, 1, 1, 2) + x = self.conv(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1) + x = self.activation(x) + return x + + +class Upsample(nn.Module): + """ + Upsample class implements a module for upsampling input tensors using transposed 2D convolution. + + Args: + - input_dim (int): Dimensionality of the input channels. + - output_dim (int): Dimensionality of the output channels. + - stride (int): Stride value for the transposed convolution operation. + - output_padding (int): Output padding value for the transposed convolution operation. + + Shapes: + - Input: (B, C_in, F, T) where + B is batch size, + C_in is the number of input channels, + F is the frequency dimension, + T is the time dimension. + - Output: (B, C_out, F * stride + output_padding, T) where + B is batch size, + C_out is the number of output channels, + F * stride + output_padding is the upsampled frequency dimension. + """ + + def __init__( + self, input_dim: int, output_dim: int, stride: int, output_padding: int + ): + """ + Initializes Upsample with input dimension, output dimension, stride, and output padding. + """ + super().__init__() + self.conv = nn.ConvTranspose2d( + input_dim, output_dim, 1, (stride, 1), output_padding=(output_padding, 0) + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Performs forward pass through the Upsample module. + + Args: + - x (torch.Tensor): Input tensor of shape (B, C_in, F, T). + + Returns: + - torch.Tensor: Output tensor of shape (B, C_out, F * stride + output_padding, T). + """ + return self.conv(x) + + +class SULayer(nn.Module): + """ + SULayer class implements a subband upsampling layer using transposed convolution. + + Args: + - input_dim (int): Dimensionality of the input channels. + - output_dim (int): Dimensionality of the output channels. + - upsample_stride (int): Stride value for the upsampling operation. + - subband_shape (int): Shape of the subband. + - sd_interval (Tuple[int, int]): Start and end indices of the subband interval. + + Shapes: + - Input: (B, F, T, C) where + B is batch size, + F is the number of features, + T is sequence length, + C is input dimensionality. + - Output: (B, F, T, C) where + B is batch size, + F is the number of features, + T is sequence length, + C is input dimensionality. 
+ """ + + def __init__( + self, + input_dim: int, + output_dim: int, + upsample_stride: int, + subband_shape: int, + sd_interval: Tuple[int, int], + ): + """ + Initializes SULayer with input dimension, output dimension, upsample stride, subband shape, and subband interval. + """ + super().__init__() + sd_shape = sd_interval[1] - sd_interval[0] + upsample_output_padding = get_convtranspose_output_padding( + input_shape=sd_shape, output_shape=subband_shape, stride=upsample_stride + ) + self.upsample = Upsample( + input_dim=input_dim, + output_dim=output_dim, + stride=upsample_stride, + output_padding=upsample_output_padding, + ) + self.sd_interval = sd_interval + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Performs forward pass through the SULayer. + + Args: + - x (torch.Tensor): Input tensor of shape (B, F, T, C). + + Returns: + - torch.Tensor: Output tensor of shape (B, F, T, C). + """ + x = x[:, self.sd_interval[0] : self.sd_interval[1]] + x = x.permute(0, 3, 1, 2) + x = self.upsample(x) + x = x.permute(0, 2, 3, 1) + return x + + +class SUBlock(nn.Module): + """ + SUBlock class implements a block with fusion layer and subband upsampling layers. + + Args: + - input_dim (int): Dimensionality of the input channels. + - output_dim (int): Dimensionality of the output channels. + - upsample_strides (List[int]): List of stride values for the upsampling operations. + - subband_shapes (List[int]): List of shapes for the subbands. + - sd_intervals (List[Tuple[int, int]]): List of intervals for subband decomposition. + + Shapes: + - Input: (B, Fi-1, T, Ci-1) and (B, Fi-1, T, Ci-1) where + B is batch size, + Fi-1 is the number of input subbands, + T is sequence length, + Ci-1 is the number of input channels. + - Output: (B, Fi, T, Ci) where + B is batch size, + Fi is the number of output subbands, + T is sequence length, + Ci is the number of output channels. + """ + + def __init__( + self, + input_dim: int, + output_dim: int, + upsample_strides: List[int], + subband_shapes: List[int], + sd_intervals: List[Tuple[int, int]], + ): + """ + Initializes SUBlock with input dimension, output dimension, + upsample strides, subband shapes, and subband intervals. + """ + super().__init__() + self.fusion_layer = FusionLayer(input_dim=input_dim) + self.su_layers = nn.ModuleList( + SULayer( + input_dim=input_dim, + output_dim=output_dim, + upsample_stride=uss, + subband_shape=sbs, + sd_interval=sdi, + ) + for i, (uss, sbs, sdi) in enumerate( + zip(upsample_strides, subband_shapes, sd_intervals) + ) + ) + + def forward(self, x: torch.Tensor, x_skip: torch.Tensor) -> torch.Tensor: + """ + Performs forward pass through the SUBlock. + + Args: + - x (torch.Tensor): Input tensor of shape (B, Fi-1, T, Ci-1). + - x_skip (torch.Tensor): Input skip connection tensor of shape (B, Fi-1, T, Ci-1). + + Returns: + - torch.Tensor: Output tensor of shape (B, Fi, T, Ci). 
+ """ + x = self.fusion_layer(x, x_skip) + x = torch.concat([layer(x) for layer in self.su_layers], dim=1) + return x diff --git a/models/scnet_unofficial/scnet.py b/models/scnet_unofficial/scnet.py new file mode 100644 index 0000000000000000000000000000000000000000..771a807dfa4fd51349c0ed21477e0eaf3c58bae1 --- /dev/null +++ b/models/scnet_unofficial/scnet.py @@ -0,0 +1,249 @@ +''' +SCNet - great paper, great implementation +https://arxiv.org/pdf/2401.13276.pdf +https://github.com/amanteur/SCNet-PyTorch +''' + +from typing import List + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchaudio + +from models.scnet_unofficial.modules import DualPathRNN, SDBlock, SUBlock +from models.scnet_unofficial.utils import compute_sd_layer_shapes, compute_gcr + +from einops import rearrange, pack, unpack +from functools import partial + +from beartype.typing import Tuple, Optional, List, Callable +from beartype import beartype + +def exists(val): + return val is not None + + +def default(v, d): + return v if exists(v) else d + + +def pack_one(t, pattern): + return pack([t], pattern) + + +def unpack_one(t, ps, pattern): + return unpack(t, ps, pattern)[0] + + +class RMSNorm(nn.Module): + def __init__(self, dim): + super().__init__() + self.scale = dim ** 0.5 + self.gamma = nn.Parameter(torch.ones(dim)) + + def forward(self, x): + return F.normalize(x, dim=-1) * self.scale * self.gamma + + +class BandSplit(nn.Module): + @beartype + def __init__( + self, + dim, + dim_inputs: Tuple[int, ...] + ): + super().__init__() + self.dim_inputs = dim_inputs + self.to_features = ModuleList([]) + + for dim_in in dim_inputs: + net = nn.Sequential( + RMSNorm(dim_in), + nn.Linear(dim_in, dim) + ) + + self.to_features.append(net) + + def forward(self, x): + x = x.split(self.dim_inputs, dim=-1) + + outs = [] + for split_input, to_feature in zip(x, self.to_features): + split_output = to_feature(split_input) + outs.append(split_output) + + return torch.stack(outs, dim=-2) + + +class SCNet(nn.Module): + """ + SCNet class implements a source separation network, + which explicitly split the spectrogram of the mixture into several subbands + and introduce a sparsity-based encoder to model different frequency bands. + + Paper: "SCNET: SPARSE COMPRESSION NETWORK FOR MUSIC SOURCE SEPARATION" + Authors: Weinan Tong, Jiaxu Zhu et al. + Link: https://arxiv.org/abs/2401.13276.pdf + + Args: + - n_fft (int): Number of FFTs to determine the frequency dimension of the input. + - dims (List[int]): List of channel dimensions for each block. + - bandsplit_ratios (List[float]): List of ratios for splitting the frequency bands. + - downsample_strides (List[int]): List of stride values for downsampling in each block. + - n_conv_modules (List[int]): List specifying the number of convolutional modules in each block. + - n_rnn_layers (int): Number of recurrent layers in the dual path RNN. + - rnn_hidden_dim (int): Dimensionality of the hidden state in the dual path RNN. + - n_sources (int, optional): Number of sources to be separated. Default is 4. + + Shapes: + - Input: (B, C, T) where + B is batch size, + C is channel dim (mono / stereo), + T is time dim + - Output: (B, N, C, T) where + B is batch size, + N is the number of sources. 
+ C is channel dim (mono / stereo), + T is sequence length, + """ + @beartype + def __init__( + self, + n_fft: int, + dims: List[int], + bandsplit_ratios: List[float], + downsample_strides: List[int], + n_conv_modules: List[int], + n_rnn_layers: int, + rnn_hidden_dim: int, + n_sources: int = 4, + hop_length: int = 1024, + win_length: int = 4096, + stft_window_fn: Optional[Callable] = None, + stft_normalized: bool = False, + **kwargs + ): + """ + Initializes SCNet with input parameters. + """ + super().__init__() + self.assert_input_data( + bandsplit_ratios, + downsample_strides, + n_conv_modules, + ) + + n_blocks = len(dims) - 1 + n_freq_bins = n_fft // 2 + 1 + subband_shapes, sd_intervals = compute_sd_layer_shapes( + input_shape=n_freq_bins, + bandsplit_ratios=bandsplit_ratios, + downsample_strides=downsample_strides, + n_layers=n_blocks, + ) + self.sd_blocks = nn.ModuleList( + SDBlock( + input_dim=dims[i], + output_dim=dims[i + 1], + bandsplit_ratios=bandsplit_ratios, + downsample_strides=downsample_strides, + n_conv_modules=n_conv_modules, + ) + for i in range(n_blocks) + ) + self.dualpath_blocks = DualPathRNN( + n_layers=n_rnn_layers, + input_dim=dims[-1], + hidden_dim=rnn_hidden_dim, + **kwargs + ) + self.su_blocks = nn.ModuleList( + SUBlock( + input_dim=dims[i + 1], + output_dim=dims[i] if i != 0 else dims[i] * n_sources, + subband_shapes=subband_shapes[i], + sd_intervals=sd_intervals[i], + upsample_strides=downsample_strides, + ) + for i in reversed(range(n_blocks)) + ) + self.gcr = compute_gcr(subband_shapes) + + self.stft_kwargs = dict( + n_fft=n_fft, + hop_length=hop_length, + win_length=win_length, + normalized=stft_normalized + ) + + self.stft_window_fn = partial(default(stft_window_fn, torch.hann_window), win_length) + self.n_sources = n_sources + self.hop_length = hop_length + + @staticmethod + def assert_input_data(*args): + """ + Asserts that the shapes of input features are equal. + """ + for arg1 in args: + for arg2 in args: + if len(arg1) != len(arg2): + raise ValueError( + f"Shapes of input features {arg1} and {arg2} are not equal." + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Performs forward pass through the SCNet. + + Args: + - x (torch.Tensor): Input tensor of shape (B, C, T). + + Returns: + - torch.Tensor: Output tensor of shape (B, N, C, T). 
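+
+ Example (illustrative sketch only; the argument values are assumptions that
+ mirror the SCNet defaults used elsewhere in this repository):
+
+ model = SCNet(n_fft=4096, dims=[4, 32, 64, 128],
+ bandsplit_ratios=[0.175, 0.392, 0.433], downsample_strides=[1, 4, 16],
+ n_conv_modules=[3, 2, 1], n_rnn_layers=6, rnn_hidden_dim=128)
+ mix = torch.randn(1, 2, 44100)   # (B, C, T)
+ stems = model(mix)               # (B, n_sources, C, T)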
+ """ + + device = x.device + stft_window = self.stft_window_fn(device=device) + + if x.ndim == 2: + x = rearrange(x, 'b t -> b 1 t') + + c = x.shape[1] + + stft_pad = self.hop_length - x.shape[-1] % self.hop_length + x = F.pad(x, (0, stft_pad)) + + # stft + x, ps = pack_one(x, '* t') + x = torch.stft(x, **self.stft_kwargs, window=stft_window, return_complex=True) + x = torch.view_as_real(x) + x = unpack_one(x, ps, '* c f t') + x = rearrange(x, 'b c f t r -> b f t (c r)') + + # encoder part + x_skips = [] + for sd_block in self.sd_blocks: + x, x_skip = sd_block(x) + x_skips.append(x_skip) + + # separation part + x = self.dualpath_blocks(x) + + # decoder part + for su_block, x_skip in zip(self.su_blocks, reversed(x_skips)): + x = su_block(x, x_skip) + + # istft + x = rearrange(x, 'b f t (c r n) -> b n c f t r', c=c, n=self.n_sources, r=2) + x = x.contiguous() + + x = torch.view_as_complex(x) + x = rearrange(x, 'b n c f t -> (b n c) f t') + x = torch.istft(x, **self.stft_kwargs, window=stft_window, return_complex=False) + x = rearrange(x, '(b n c) t -> b n c t', c=c, n=self.n_sources) + + x = x[..., :-stft_pad] + + return x diff --git a/models/scnet_unofficial/utils.py b/models/scnet_unofficial/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e8704caae380bc4cccd89e81d4fa025202adc886 --- /dev/null +++ b/models/scnet_unofficial/utils.py @@ -0,0 +1,135 @@ +''' +SCNet - great paper, great implementation +https://arxiv.org/pdf/2401.13276.pdf +https://github.com/amanteur/SCNet-PyTorch +''' + +from typing import List, Tuple, Union + +import torch + + +def create_intervals( + splits: List[Union[float, int]] +) -> List[Union[Tuple[float, float], Tuple[int, int]]]: + """ + Create intervals based on splits provided. + + Args: + - splits (List[Union[float, int]]): List of floats or integers representing splits. + + Returns: + - List[Union[Tuple[float, float], Tuple[int, int]]]: List of tuples representing intervals. + """ + start = 0 + return [(start, start := start + split) for split in splits] + + +def get_conv_output_shape( + input_shape: int, + kernel_size: int = 1, + padding: int = 0, + dilation: int = 1, + stride: int = 1, +) -> int: + """ + Compute the output shape of a convolutional layer. + + Args: + - input_shape (int): Input shape. + - kernel_size (int, optional): Kernel size of the convolution. Default is 1. + - padding (int, optional): Padding size. Default is 0. + - dilation (int, optional): Dilation factor. Default is 1. + - stride (int, optional): Stride value. Default is 1. + + Returns: + - int: Output shape. + """ + return int( + (input_shape + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1 + ) + + +def get_convtranspose_output_padding( + input_shape: int, + output_shape: int, + kernel_size: int = 1, + padding: int = 0, + dilation: int = 1, + stride: int = 1, +) -> int: + """ + Compute the output padding for a convolution transpose operation. + + Args: + - input_shape (int): Input shape. + - output_shape (int): Desired output shape. + - kernel_size (int, optional): Kernel size of the convolution. Default is 1. + - padding (int, optional): Padding size. Default is 0. + - dilation (int, optional): Dilation factor. Default is 1. + - stride (int, optional): Stride value. Default is 1. + + Returns: + - int: Output padding. 
+ """ + return ( + output_shape + - (input_shape - 1) * stride + + 2 * padding + - dilation * (kernel_size - 1) + - 1 + ) + + +def compute_sd_layer_shapes( + input_shape: int, + bandsplit_ratios: List[float], + downsample_strides: List[int], + n_layers: int, +) -> Tuple[List[List[int]], List[List[Tuple[int, int]]]]: + """ + Compute the shapes for the subband layers. + + Args: + - input_shape (int): Input shape. + - bandsplit_ratios (List[float]): Ratios for splitting the frequency bands. + - downsample_strides (List[int]): Strides for downsampling in each layer. + - n_layers (int): Number of layers. + + Returns: + - Tuple[List[List[int]], List[List[Tuple[int, int]]]]: Tuple containing subband shapes and convolution shapes. + """ + bandsplit_shapes_list = [] + conv2d_shapes_list = [] + for _ in range(n_layers): + bandsplit_intervals = create_intervals(bandsplit_ratios) + bandsplit_shapes = [ + int(right * input_shape) - int(left * input_shape) + for left, right in bandsplit_intervals + ] + conv2d_shapes = [ + get_conv_output_shape(bs, stride=ds) + for bs, ds in zip(bandsplit_shapes, downsample_strides) + ] + input_shape = sum(conv2d_shapes) + bandsplit_shapes_list.append(bandsplit_shapes) + conv2d_shapes_list.append(create_intervals(conv2d_shapes)) + + return bandsplit_shapes_list, conv2d_shapes_list + + +def compute_gcr(subband_shapes: List[List[int]]) -> float: + """ + Compute the global compression ratio. + + Args: + - subband_shapes (List[List[int]]): List of subband shapes. + + Returns: + - float: Global compression ratio. + """ + t = torch.Tensor(subband_shapes) + gcr = torch.stack( + [(1 - t[i + 1] / t[i]).mean() for i in range(0, len(t) - 1)] + ).mean() + return float(gcr) \ No newline at end of file diff --git a/models/segm_models.py b/models/segm_models.py new file mode 100644 index 0000000000000000000000000000000000000000..7cf99255af40955c6eb4b977031e252fe01a1986 --- /dev/null +++ b/models/segm_models.py @@ -0,0 +1,255 @@ +import torch +import torch.nn as nn +import segmentation_models_pytorch as smp + + +class STFT: + def __init__(self, config): + self.n_fft = config.n_fft + self.hop_length = config.hop_length + self.window = torch.hann_window(window_length=self.n_fft, periodic=True) + self.dim_f = config.dim_f + + def __call__(self, x): + window = self.window.to(x.device) + batch_dims = x.shape[:-2] + c, t = x.shape[-2:] + x = x.reshape([-1, t]) + x = torch.stft( + x, + n_fft=self.n_fft, + hop_length=self.hop_length, + window=window, + center=True, + return_complex=True + ) + x = torch.view_as_real(x) + x = x.permute([0, 3, 1, 2]) + x = x.reshape([*batch_dims, c, 2, -1, x.shape[-1]]).reshape([*batch_dims, c * 2, -1, x.shape[-1]]) + return x[..., :self.dim_f, :] + + def inverse(self, x): + window = self.window.to(x.device) + batch_dims = x.shape[:-3] + c, f, t = x.shape[-3:] + n = self.n_fft // 2 + 1 + f_pad = torch.zeros([*batch_dims, c, n - f, t]).to(x.device) + x = torch.cat([x, f_pad], -2) + x = x.reshape([*batch_dims, c // 2, 2, n, t]).reshape([-1, 2, n, t]) + x = x.permute([0, 2, 3, 1]) + x = x[..., 0] + x[..., 1] * 1.j + x = torch.istft( + x, + n_fft=self.n_fft, + hop_length=self.hop_length, + window=window, + center=True + ) + x = x.reshape([*batch_dims, 2, -1]) + return x + + +def get_act(act_type): + if act_type == 'gelu': + return nn.GELU() + elif act_type == 'relu': + return nn.ReLU() + elif act_type[:3] == 'elu': + alpha = float(act_type.replace('elu', '')) + return nn.ELU(alpha) + else: + raise Exception + + +def get_decoder(config, c): + decoder = None + 
decoder_options = dict() + if config.model.decoder_type == 'unet': + try: + decoder_options = dict(config.decoder_unet) + except: + pass + decoder = smp.Unet( + encoder_name=config.model.encoder_name, + encoder_weights="imagenet", + in_channels=c, + classes=c, + **decoder_options, + ) + elif config.model.decoder_type == 'fpn': + try: + decoder_options = dict(config.decoder_fpn) + except: + pass + decoder = smp.FPN( + encoder_name=config.model.encoder_name, + encoder_weights="imagenet", + in_channels=c, + classes=c, + **decoder_options, + ) + elif config.model.decoder_type == 'unet++': + try: + decoder_options = dict(config.decoder_unet_plus_plus) + except: + pass + decoder = smp.UnetPlusPlus( + encoder_name=config.model.encoder_name, + encoder_weights="imagenet", + in_channels=c, + classes=c, + **decoder_options, + ) + elif config.model.decoder_type == 'manet': + try: + decoder_options = dict(config.decoder_manet) + except: + pass + decoder = smp.MAnet( + encoder_name=config.model.encoder_name, + encoder_weights="imagenet", + in_channels=c, + classes=c, + **decoder_options, + ) + elif config.model.decoder_type == 'linknet': + try: + decoder_options = dict(config.decoder_linknet) + except: + pass + decoder = smp.Linknet( + encoder_name=config.model.encoder_name, + encoder_weights="imagenet", + in_channels=c, + classes=c, + **decoder_options, + ) + elif config.model.decoder_type == 'pspnet': + try: + decoder_options = dict(config.decoder_pspnet) + except: + pass + decoder = smp.PSPNet( + encoder_name=config.model.encoder_name, + encoder_weights="imagenet", + in_channels=c, + classes=c, + **decoder_options, + ) + elif config.model.decoder_type == 'pspnet': + try: + decoder_options = dict(config.decoder_pspnet) + except: + pass + decoder = smp.PSPNet( + encoder_name=config.model.encoder_name, + encoder_weights="imagenet", + in_channels=c, + classes=c, + **decoder_options, + ) + elif config.model.decoder_type == 'pan': + try: + decoder_options = dict(config.decoder_pan) + except: + pass + decoder = smp.PAN( + encoder_name=config.model.encoder_name, + encoder_weights="imagenet", + in_channels=c, + classes=c, + **decoder_options, + ) + elif config.model.decoder_type == 'deeplabv3': + try: + decoder_options = dict(config.decoder_deeplabv3) + except: + pass + decoder = smp.DeepLabV3( + encoder_name=config.model.encoder_name, + encoder_weights="imagenet", + in_channels=c, + classes=c, + **decoder_options, + ) + elif config.model.decoder_type == 'deeplabv3plus': + try: + decoder_options = dict(config.decoder_deeplabv3plus) + except: + pass + decoder = smp.DeepLabV3Plus( + encoder_name=config.model.encoder_name, + encoder_weights="imagenet", + in_channels=c, + classes=c, + **decoder_options, + ) + return decoder + + +class Segm_Models_Net(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + + act = get_act(act_type=config.model.act) + + self.num_target_instruments = 1 if config.training.target_instrument else len(config.training.instruments) + self.num_subbands = config.model.num_subbands + + dim_c = self.num_subbands * config.audio.num_channels * 2 + c = config.model.num_channels + f = config.audio.dim_f // self.num_subbands + + self.first_conv = nn.Conv2d(dim_c, c, 1, 1, 0, bias=False) + + self.unet_model = get_decoder(config, c) + + self.final_conv = nn.Sequential( + nn.Conv2d(c + dim_c, c, 1, 1, 0, bias=False), + act, + nn.Conv2d(c, self.num_target_instruments * dim_c, 1, 1, 0, bias=False) + ) + + self.stft = STFT(config.audio) + + def cac2cws(self, x): + k = 
self.num_subbands + b, c, f, t = x.shape + x = x.reshape(b, c, k, f // k, t) + x = x.reshape(b, c * k, f // k, t) + return x + + def cws2cac(self, x): + k = self.num_subbands + b, c, f, t = x.shape + x = x.reshape(b, c // k, k, f, t) + x = x.reshape(b, c // k, f * k, t) + return x + + def forward(self, x): + + x = self.stft(x) + + mix = x = self.cac2cws(x) + + first_conv_out = x = self.first_conv(x) + + x = x.transpose(-1, -2) + + x = self.unet_model(x) + + x = x.transpose(-1, -2) + + x = x * first_conv_out # reduce artifacts + + x = self.final_conv(torch.cat([mix, x], 1)) + + x = self.cws2cac(x) + + if self.num_target_instruments > 1: + b, c, f, t = x.shape + x = x.reshape(b, self.num_target_instruments, -1, f, t) + + x = self.stft.inverse(x) + return x diff --git a/models/torchseg_models.py b/models/torchseg_models.py new file mode 100644 index 0000000000000000000000000000000000000000..2d81cea639f25dc4c1fa60c25100519ab7586bd5 --- /dev/null +++ b/models/torchseg_models.py @@ -0,0 +1,255 @@ +import torch +import torch.nn as nn +import torchseg as smp + + +class STFT: + def __init__(self, config): + self.n_fft = config.n_fft + self.hop_length = config.hop_length + self.window = torch.hann_window(window_length=self.n_fft, periodic=True) + self.dim_f = config.dim_f + + def __call__(self, x): + window = self.window.to(x.device) + batch_dims = x.shape[:-2] + c, t = x.shape[-2:] + x = x.reshape([-1, t]) + x = torch.stft( + x, + n_fft=self.n_fft, + hop_length=self.hop_length, + window=window, + center=True, + return_complex=True + ) + x = torch.view_as_real(x) + x = x.permute([0, 3, 1, 2]) + x = x.reshape([*batch_dims, c, 2, -1, x.shape[-1]]).reshape([*batch_dims, c * 2, -1, x.shape[-1]]) + return x[..., :self.dim_f, :] + + def inverse(self, x): + window = self.window.to(x.device) + batch_dims = x.shape[:-3] + c, f, t = x.shape[-3:] + n = self.n_fft // 2 + 1 + f_pad = torch.zeros([*batch_dims, c, n - f, t]).to(x.device) + x = torch.cat([x, f_pad], -2) + x = x.reshape([*batch_dims, c // 2, 2, n, t]).reshape([-1, 2, n, t]) + x = x.permute([0, 2, 3, 1]) + x = x[..., 0] + x[..., 1] * 1.j + x = torch.istft( + x, + n_fft=self.n_fft, + hop_length=self.hop_length, + window=window, + center=True + ) + x = x.reshape([*batch_dims, 2, -1]) + return x + + +def get_act(act_type): + if act_type == 'gelu': + return nn.GELU() + elif act_type == 'relu': + return nn.ReLU() + elif act_type[:3] == 'elu': + alpha = float(act_type.replace('elu', '')) + return nn.ELU(alpha) + else: + raise Exception + + +def get_decoder(config, c): + decoder = None + decoder_options = dict() + if config.model.decoder_type == 'unet': + try: + decoder_options = dict(config.decoder_unet) + except: + pass + decoder = smp.Unet( + encoder_name=config.model.encoder_name, + encoder_weights="imagenet", + in_channels=c, + classes=c, + **decoder_options, + ) + elif config.model.decoder_type == 'fpn': + try: + decoder_options = dict(config.decoder_fpn) + except: + pass + decoder = smp.FPN( + encoder_name=config.model.encoder_name, + encoder_weights="imagenet", + in_channels=c, + classes=c, + **decoder_options, + ) + elif config.model.decoder_type == 'unet++': + try: + decoder_options = dict(config.decoder_unet_plus_plus) + except: + pass + decoder = smp.UnetPlusPlus( + encoder_name=config.model.encoder_name, + encoder_weights="imagenet", + in_channels=c, + classes=c, + **decoder_options, + ) + elif config.model.decoder_type == 'manet': + try: + decoder_options = dict(config.decoder_manet) + except: + pass + decoder = smp.MAnet( + 
encoder_name=config.model.encoder_name, + encoder_weights="imagenet", + in_channels=c, + classes=c, + **decoder_options, + ) + elif config.model.decoder_type == 'linknet': + try: + decoder_options = dict(config.decoder_linknet) + except: + pass + decoder = smp.Linknet( + encoder_name=config.model.encoder_name, + encoder_weights="imagenet", + in_channels=c, + classes=c, + **decoder_options, + ) + elif config.model.decoder_type == 'pspnet': + try: + decoder_options = dict(config.decoder_pspnet) + except: + pass + decoder = smp.PSPNet( + encoder_name=config.model.encoder_name, + encoder_weights="imagenet", + in_channels=c, + classes=c, + **decoder_options, + ) + elif config.model.decoder_type == 'pspnet': + try: + decoder_options = dict(config.decoder_pspnet) + except: + pass + decoder = smp.PSPNet( + encoder_name=config.model.encoder_name, + encoder_weights="imagenet", + in_channels=c, + classes=c, + **decoder_options, + ) + elif config.model.decoder_type == 'pan': + try: + decoder_options = dict(config.decoder_pan) + except: + pass + decoder = smp.PAN( + encoder_name=config.model.encoder_name, + encoder_weights="imagenet", + in_channels=c, + classes=c, + **decoder_options, + ) + elif config.model.decoder_type == 'deeplabv3': + try: + decoder_options = dict(config.decoder_deeplabv3) + except: + pass + decoder = smp.DeepLabV3( + encoder_name=config.model.encoder_name, + encoder_weights="imagenet", + in_channels=c, + classes=c, + **decoder_options, + ) + elif config.model.decoder_type == 'deeplabv3plus': + try: + decoder_options = dict(config.decoder_deeplabv3plus) + except: + pass + decoder = smp.DeepLabV3Plus( + encoder_name=config.model.encoder_name, + encoder_weights="imagenet", + in_channels=c, + classes=c, + **decoder_options, + ) + return decoder + + +class Torchseg_Net(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + + act = get_act(act_type=config.model.act) + + self.num_target_instruments = 1 if config.training.target_instrument else len(config.training.instruments) + self.num_subbands = config.model.num_subbands + + dim_c = self.num_subbands * config.audio.num_channels * 2 + c = config.model.num_channels + f = config.audio.dim_f // self.num_subbands + + self.first_conv = nn.Conv2d(dim_c, c, 1, 1, 0, bias=False) + + self.unet_model = get_decoder(config, c) + + self.final_conv = nn.Sequential( + nn.Conv2d(c + dim_c, c, 1, 1, 0, bias=False), + act, + nn.Conv2d(c, self.num_target_instruments * dim_c, 1, 1, 0, bias=False) + ) + + self.stft = STFT(config.audio) + + def cac2cws(self, x): + k = self.num_subbands + b, c, f, t = x.shape + x = x.reshape(b, c, k, f // k, t) + x = x.reshape(b, c * k, f // k, t) + return x + + def cws2cac(self, x): + k = self.num_subbands + b, c, f, t = x.shape + x = x.reshape(b, c // k, k, f, t) + x = x.reshape(b, c // k, f * k, t) + return x + + def forward(self, x): + + x = self.stft(x) + + mix = x = self.cac2cws(x) + + first_conv_out = x = self.first_conv(x) + + x = x.transpose(-1, -2) + + x = self.unet_model(x) + + x = x.transpose(-1, -2) + + x = x * first_conv_out # reduce artifacts + + x = self.final_conv(torch.cat([mix, x], 1)) + + x = self.cws2cac(x) + + if self.num_target_instruments > 1: + b, c, f, t = x.shape + x = x.reshape(b, self.num_target_instruments, -1, f, t) + + x = self.stft.inverse(x) + return x diff --git a/models/upernet_swin_transformers.py b/models/upernet_swin_transformers.py new file mode 100644 index 0000000000000000000000000000000000000000..afb49773ac56ff95c8b91147dfcc0449490aafc1 --- /dev/null 
+++ b/models/upernet_swin_transformers.py @@ -0,0 +1,228 @@ +from functools import partial +import torch +import torch.nn as nn +from transformers import UperNetForSemanticSegmentation + + +class STFT: + def __init__(self, config): + self.n_fft = config.n_fft + self.hop_length = config.hop_length + self.window = torch.hann_window(window_length=self.n_fft, periodic=True) + self.dim_f = config.dim_f + + def __call__(self, x): + window = self.window.to(x.device) + batch_dims = x.shape[:-2] + c, t = x.shape[-2:] + x = x.reshape([-1, t]) + x = torch.stft( + x, + n_fft=self.n_fft, + hop_length=self.hop_length, + window=window, + center=True, + return_complex=True + ) + x = torch.view_as_real(x) + x = x.permute([0, 3, 1, 2]) + x = x.reshape([*batch_dims, c, 2, -1, x.shape[-1]]).reshape([*batch_dims, c * 2, -1, x.shape[-1]]) + return x[..., :self.dim_f, :] + + def inverse(self, x): + window = self.window.to(x.device) + batch_dims = x.shape[:-3] + c, f, t = x.shape[-3:] + n = self.n_fft // 2 + 1 + f_pad = torch.zeros([*batch_dims, c, n - f, t]).to(x.device) + x = torch.cat([x, f_pad], -2) + x = x.reshape([*batch_dims, c // 2, 2, n, t]).reshape([-1, 2, n, t]) + x = x.permute([0, 2, 3, 1]) + x = x[..., 0] + x[..., 1] * 1.j + x = torch.istft( + x, + n_fft=self.n_fft, + hop_length=self.hop_length, + window=window, + center=True + ) + x = x.reshape([*batch_dims, 2, -1]) + return x + + +def get_norm(norm_type): + def norm(c, norm_type): + if norm_type == 'BatchNorm': + return nn.BatchNorm2d(c) + elif norm_type == 'InstanceNorm': + return nn.InstanceNorm2d(c, affine=True) + elif 'GroupNorm' in norm_type: + g = int(norm_type.replace('GroupNorm', '')) + return nn.GroupNorm(num_groups=g, num_channels=c) + else: + return nn.Identity() + + return partial(norm, norm_type=norm_type) + + +def get_act(act_type): + if act_type == 'gelu': + return nn.GELU() + elif act_type == 'relu': + return nn.ReLU() + elif act_type[:3] == 'elu': + alpha = float(act_type.replace('elu', '')) + return nn.ELU(alpha) + else: + raise Exception + + +class Upscale(nn.Module): + def __init__(self, in_c, out_c, scale, norm, act): + super().__init__() + self.conv = nn.Sequential( + norm(in_c), + act, + nn.ConvTranspose2d(in_channels=in_c, out_channels=out_c, kernel_size=scale, stride=scale, bias=False) + ) + + def forward(self, x): + return self.conv(x) + + +class Downscale(nn.Module): + def __init__(self, in_c, out_c, scale, norm, act): + super().__init__() + self.conv = nn.Sequential( + norm(in_c), + act, + nn.Conv2d(in_channels=in_c, out_channels=out_c, kernel_size=scale, stride=scale, bias=False) + ) + + def forward(self, x): + return self.conv(x) + + +class TFC_TDF(nn.Module): + def __init__(self, in_c, c, l, f, bn, norm, act): + super().__init__() + + self.blocks = nn.ModuleList() + for i in range(l): + block = nn.Module() + + block.tfc1 = nn.Sequential( + norm(in_c), + act, + nn.Conv2d(in_c, c, 3, 1, 1, bias=False), + ) + block.tdf = nn.Sequential( + norm(c), + act, + nn.Linear(f, f // bn, bias=False), + norm(c), + act, + nn.Linear(f // bn, f, bias=False), + ) + block.tfc2 = nn.Sequential( + norm(c), + act, + nn.Conv2d(c, c, 3, 1, 1, bias=False), + ) + block.shortcut = nn.Conv2d(in_c, c, 1, 1, 0, bias=False) + + self.blocks.append(block) + in_c = c + + def forward(self, x): + for block in self.blocks: + s = block.shortcut(x) + x = block.tfc1(x) + x = x + block.tdf(x) + x = block.tfc2(x) + x = x + s + return x + + +class Swin_UperNet_Model(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + + act = 
get_act(act_type=config.model.act) + + self.num_target_instruments = 1 if config.training.target_instrument else len(config.training.instruments) + self.num_subbands = config.model.num_subbands + + dim_c = self.num_subbands * config.audio.num_channels * 2 + c = config.model.num_channels + f = config.audio.dim_f // self.num_subbands + + self.first_conv = nn.Conv2d(dim_c, c, 1, 1, 0, bias=False) + + self.swin_upernet_model = UperNetForSemanticSegmentation.from_pretrained("openmmlab/upernet-swin-large") + + self.swin_upernet_model.auxiliary_head.classifier = nn.Conv2d(256, c, kernel_size=(1, 1), stride=(1, 1)) + self.swin_upernet_model.decode_head.classifier = nn.Conv2d(512, c, kernel_size=(1, 1), stride=(1, 1)) + self.swin_upernet_model.backbone.embeddings.patch_embeddings.projection = nn.Conv2d(c, 192, kernel_size=(4, 4), stride=(4, 4)) + + self.final_conv = nn.Sequential( + nn.Conv2d(c + dim_c, c, 1, 1, 0, bias=False), + act, + nn.Conv2d(c, self.num_target_instruments * dim_c, 1, 1, 0, bias=False) + ) + + self.stft = STFT(config.audio) + + def cac2cws(self, x): + k = self.num_subbands + b, c, f, t = x.shape + x = x.reshape(b, c, k, f // k, t) + x = x.reshape(b, c * k, f // k, t) + return x + + def cws2cac(self, x): + k = self.num_subbands + b, c, f, t = x.shape + x = x.reshape(b, c // k, k, f, t) + x = x.reshape(b, c // k, f * k, t) + return x + + def forward(self, x): + + x = self.stft(x) + + mix = x = self.cac2cws(x) + + first_conv_out = x = self.first_conv(x) + + x = x.transpose(-1, -2) + + x = self.swin_upernet_model(x).logits + + x = x.transpose(-1, -2) + + x = x * first_conv_out # reduce artifacts + + x = self.final_conv(torch.cat([mix, x], 1)) + + x = self.cws2cac(x) + + if self.num_target_instruments > 1: + b, c, f, t = x.shape + x = x.reshape(b, self.num_target_instruments, -1, f, t) + + x = self.stft.inverse(x) + return x + + +if __name__ == "__main__": + model = UperNetForSemanticSegmentation.from_pretrained("./results/", ignore_mismatched_sizes=True) + print(model) + print(model.auxiliary_head.classifier) + print(model.decode_head.classifier) + + x = torch.zeros((2, 16, 512, 512), dtype=torch.float32) + res = model(x) + print(res.logits.shape) + model.save_pretrained('./results/') \ No newline at end of file diff --git a/utils.py b/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3e14756d70460750b0df816225c7c8a226d8e93c --- /dev/null +++ b/utils.py @@ -0,0 +1,203 @@ +# coding: utf-8 +__author__ = 'Roman Solovyev (ZFTurbo): https://github.com/ZFTurbo/' + +import time +import numpy as np +import torch +import torch.nn as nn +import yaml +from ml_collections import ConfigDict +from omegaconf import OmegaConf + + +def get_model_from_config(model_type, config_path): + with open(config_path) as f: + if model_type == 'htdemucs': + config = OmegaConf.load(config_path) + else: + config = ConfigDict(yaml.load(f, Loader=yaml.FullLoader)) + + if model_type == 'mdx23c': + from models.mdx23c_tfc_tdf_v3 import TFC_TDF_net + model = TFC_TDF_net(config) + elif model_type == 'htdemucs': + from models.demucs4ht import get_model + model = get_model(config) + elif model_type == 'segm_models': + from models.segm_models import Segm_Models_Net + model = Segm_Models_Net(config) + elif model_type == 'torchseg': + from models.torchseg_models import Torchseg_Net + model = Torchseg_Net(config) + elif model_type == 'mel_band_roformer': + from models.bs_roformer import MelBandRoformer + model = MelBandRoformer( + **dict(config.model) + ) + elif model_type == 'bs_roformer': + from 
models.bs_roformer import BSRoformer + model = BSRoformer( + **dict(config.model) + ) + elif model_type == 'swin_upernet': + from models.upernet_swin_transformers import Swin_UperNet_Model + model = Swin_UperNet_Model(config) + elif model_type == 'bandit': + from models.bandit.core.model import MultiMaskMultiSourceBandSplitRNNSimple + model = MultiMaskMultiSourceBandSplitRNNSimple( + **config.model + ) + elif model_type == 'scnet_unofficial': + from models.scnet_unofficial import SCNet + model = SCNet( + **config.model + ) + elif model_type == 'scnet': + from models.scnet import SCNet + model = SCNet( + **config.model + ) + else: + print('Unknown model: {}'.format(model_type)) + model = None + + return model, config + + +def demix_track(config, model, mix, device): + C = config.audio.chunk_size + N = config.inference.num_overlap + fade_size = C // 10 + step = int(C // N) + border = C - step + batch_size = config.inference.batch_size + + length_init = mix.shape[-1] + + # Do pad from the beginning and end to account floating window results better + if length_init > 2 * border and (border > 0): + mix = nn.functional.pad(mix, (border, border), mode='reflect') + + # Prepare windows arrays (do 1 time for speed up). This trick repairs click problems on the edges of segment + window_size = C + fadein = torch.linspace(0, 1, fade_size) + fadeout = torch.linspace(1, 0, fade_size) + window_start = torch.ones(window_size) + window_middle = torch.ones(window_size) + window_finish = torch.ones(window_size) + window_start[-fade_size:] *= fadeout # First audio chunk, no fadein + window_finish[:fade_size] *= fadein # Last audio chunk, no fadeout + window_middle[-fade_size:] *= fadeout + window_middle[:fade_size] *= fadein + + with torch.cuda.amp.autocast(): + with torch.inference_mode(): + if config.training.target_instrument is not None: + req_shape = (1, ) + tuple(mix.shape) + else: + req_shape = (len(config.training.instruments),) + tuple(mix.shape) + + result = torch.zeros(req_shape, dtype=torch.float32) + counter = torch.zeros(req_shape, dtype=torch.float32) + i = 0 + batch_data = [] + batch_locations = [] + while i < mix.shape[1]: + # print(i, i + C, mix.shape[1]) + part = mix[:, i:i + C].to(device) + length = part.shape[-1] + if length < C: + if length > C // 2 + 1: + part = nn.functional.pad(input=part, pad=(0, C - length), mode='reflect') + else: + part = nn.functional.pad(input=part, pad=(0, C - length, 0, 0), mode='constant', value=0) + batch_data.append(part) + batch_locations.append((i, length)) + i += step + + if len(batch_data) >= batch_size or (i >= mix.shape[1]): + arr = torch.stack(batch_data, dim=0) + x = model(arr) + + window = window_middle + if i - step == 0: # First audio chunk, no fadein + window = window_start + elif i >= mix.shape[1]: # Last audio chunk, no fadeout + window = window_finish + + for j in range(len(batch_locations)): + start, l = batch_locations[j] + result[..., start:start+l] += x[j][..., :l].cpu() * window[..., :l] + counter[..., start:start+l] += window[..., :l] + + batch_data = [] + batch_locations = [] + + estimated_sources = result / counter + estimated_sources = estimated_sources.cpu().numpy() + np.nan_to_num(estimated_sources, copy=False, nan=0.0) + + if length_init > 2 * border and (border > 0): + # Remove pad + estimated_sources = estimated_sources[..., border:-border] + + if config.training.target_instrument is None: + return {k: v for k, v in zip(config.training.instruments, estimated_sources)} + else: + return {k: v for k, v in 
zip([config.training.target_instrument], estimated_sources)} + + +def demix_track_demucs(config, model, mix, device): + S = len(config.training.instruments) + C = config.training.samplerate * config.training.segment + N = config.inference.num_overlap + batch_size = config.inference.batch_size + step = C // N + # print(S, C, N, step, mix.shape, mix.device) + + with torch.cuda.amp.autocast(enabled=config.training.use_amp): + with torch.inference_mode(): + req_shape = (S, ) + tuple(mix.shape) + result = torch.zeros(req_shape, dtype=torch.float32) + counter = torch.zeros(req_shape, dtype=torch.float32) + i = 0 + batch_data = [] + batch_locations = [] + while i < mix.shape[1]: + # print(i, i + C, mix.shape[1]) + part = mix[:, i:i + C].to(device) + length = part.shape[-1] + if length < C: + part = nn.functional.pad(input=part, pad=(0, C - length, 0, 0), mode='constant', value=0) + batch_data.append(part) + batch_locations.append((i, length)) + i += step + + if len(batch_data) >= batch_size or (i >= mix.shape[1]): + arr = torch.stack(batch_data, dim=0) + x = model(arr) + for j in range(len(batch_locations)): + start, l = batch_locations[j] + result[..., start:start+l] += x[j][..., :l].cpu() + counter[..., start:start+l] += 1. + batch_data = [] + batch_locations = [] + + estimated_sources = result / counter + estimated_sources = estimated_sources.cpu().numpy() + np.nan_to_num(estimated_sources, copy=False, nan=0.0) + + if S > 1: + return {k: v for k, v in zip(config.training.instruments, estimated_sources)} + else: + return estimated_sources + + +def sdr(references, estimates): + # compute SDR for one song + delta = 1e-7 # avoid numerical errors + num = np.sum(np.square(references), axis=(1, 2)) + den = np.sum(np.square(references - estimates), axis=(1, 2)) + num += delta + den += delta + return 10 * np.log10(num / den)
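
The `STFT` helper defined in `models/torchseg_models.py` (and duplicated in `models/upernet_swin_transformers.py`) stores the spectrogram as real/imaginary channels stacked on the channel axis and keeps only the first `dim_f` frequency bins; `inverse` zero-pads those bins back before `istft`. A quick shape check, assuming the repository layout added in this diff is importable and using made-up STFT parameters:

```python
import torch
from types import SimpleNamespace

# Assumes the models/torchseg_models.py file added in this diff is on the path.
from models.torchseg_models import STFT

# Hypothetical audio settings; only the fields the STFT helper reads are supplied.
audio_cfg = SimpleNamespace(n_fft=2048, hop_length=512, dim_f=1024)
stft = STFT(audio_cfg)

wave = torch.randn(2, 44100)   # stereo, one second at 44.1 kHz
spec = stft(wave)              # real/imag stacked on the channel axis, cropped to dim_f bins
print(spec.shape)              # torch.Size([4, 1024, 87])

recon = stft.inverse(spec)     # missing bins are zero-padded back before istft
print(recon.shape)             # torch.Size([2, 44032]) - istft drops the last partial hop
```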
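
`Segm_Models_Net`, `Torchseg_Net` and `Swin_UperNet_Model` all pack `num_subbands` frequency sub-bands into the channel dimension (`cac2cws`) before the 2-D backbone and unpack them afterwards (`cws2cac`). The reshapes are exactly invertible, which a toy tensor (sizes chosen only for illustration) confirms:

```python
import torch

b, c, f, t, k = 2, 4, 512, 100, 4   # batch, complex-as-channel channels, freq bins, frames, sub-bands
x = torch.randn(b, c, f, t)

# cac2cws: (b, c, f, t) -> (b, c*k, f//k, t)
packed = x.reshape(b, c, k, f // k, t).reshape(b, c * k, f // k, t)

# cws2cac: (b, c*k, f//k, t) -> (b, c, f, t)
unpacked = packed.reshape(b, c, k, f // k, t).reshape(b, c, f, t)

assert torch.equal(unpacked, x)   # the packing is lossless
print(packed.shape)               # torch.Size([2, 16, 128, 100])
```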
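
`demix_track` in `utils.py` runs the model on overlapping chunks, applies fade-in/fade-out windows at the chunk edges, and accumulates both the windowed outputs and the window weights (`counter`), dividing at the end. The sketch below is not the function itself; it reproduces just that overlap-add scheme with toy sizes and a pass-through "model" to show that the counter normalization is an identity wherever chunks overlap:

```python
import torch

C, step, fade = 16, 8, 4        # toy chunk size, hop and fade length (real values come from the config)
signal = torch.randn(2, 100)    # stands in for one separated stem

fadein = torch.linspace(0, 1, fade)
fadeout = torch.linspace(1, 0, fade)
window_start = torch.ones(C)
window_start[-fade:] *= fadeout   # first chunk: no fade-in
window_finish = torch.ones(C)
window_finish[:fade] *= fadein    # last chunk: no fade-out
window_middle = torch.ones(C)
window_middle[:fade] *= fadein
window_middle[-fade:] *= fadeout

result = torch.zeros_like(signal)
counter = torch.zeros_like(signal)
i = 0
while i < signal.shape[-1]:
    chunk = signal[..., i:i + C]   # pass-through "model" output for this chunk
    l = chunk.shape[-1]
    if i == 0:
        window = window_start
    elif i + step >= signal.shape[-1]:
        window = window_finish
    else:
        window = window_middle
    result[..., i:i + l] += chunk * window[:l]
    counter[..., i:i + l] += window[:l]
    i += step

assert torch.allclose(result / counter, signal, atol=1e-6)
```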
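
Putting the `utils.py` pieces together, separating one mixture amounts to building the model from its YAML config, loading a checkpoint, and handing a `(channels, samples)` tensor to `demix_track` (or `demix_track_demucs` for the Demucs-style models). A hedged sketch; the model type, config path and checkpoint path below are placeholders, not files this Space is known to ship:

```python
import librosa
import torch

from utils import demix_track, get_model_from_config

model_type = 'torchseg'                                # placeholder model type
config_path = 'configs/config_vocals_torchseg.yaml'    # hypothetical config path
checkpoint_path = 'weights/model_vocals.ckpt'          # hypothetical checkpoint path

model, config = get_model_from_config(model_type, config_path)
model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
model.eval()

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

# demix_track expects a float32 tensor shaped (channels, samples) at the config's sample rate.
mix, sr = librosa.load('mixture.wav', sr=44100, mono=False)
mixture = torch.tensor(mix, dtype=torch.float32)

separated = demix_track(config, model, mixture, device)  # dict: instrument -> (channels, samples) ndarray
for name, stem in separated.items():
    print(name, stem.shape)
```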
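
Finally, the `sdr` helper at the end of `utils.py` scores each source as `10 * log10(sum(ref^2) / sum((ref - est)^2))`, with `delta` guarding against division by zero. A quick sanity check on synthetic data, again assuming `utils.py` from this diff is importable:

```python
import numpy as np

from utils import sdr   # defined at the end of utils.py above

rng = np.random.default_rng(0)
references = rng.standard_normal((2, 2, 44100)).astype(np.float32)       # (sources, channels, samples)
noise = 0.1 * rng.standard_normal(references.shape).astype(np.float32)   # -20 dB relative to the signal

print(sdr(references, references + noise))   # ~20 dB for both sources
print(sdr(references, references))           # huge but finite, thanks to delta
```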