poiqazwsx committed on
Commit
51e2f90
1 Parent(s): 195b331

Upload 57 files

This view is limited to 50 of the 57 changed files because the commit contains too many changes; see the raw diff for the full set.
Files changed (50)
  1. app.py +22 -0
  2. inference.py +113 -0
  3. models/bandit/core/__init__.py +744 -0
  4. models/bandit/core/data/__init__.py +2 -0
  5. models/bandit/core/data/_types.py +18 -0
  6. models/bandit/core/data/augmentation.py +107 -0
  7. models/bandit/core/data/augmented.py +35 -0
  8. models/bandit/core/data/base.py +69 -0
  9. models/bandit/core/data/dnr/__init__.py +0 -0
  10. models/bandit/core/data/dnr/datamodule.py +74 -0
  11. models/bandit/core/data/dnr/dataset.py +392 -0
  12. models/bandit/core/data/dnr/preprocess.py +54 -0
  13. models/bandit/core/data/musdb/__init__.py +0 -0
  14. models/bandit/core/data/musdb/datamodule.py +77 -0
  15. models/bandit/core/data/musdb/dataset.py +280 -0
  16. models/bandit/core/data/musdb/preprocess.py +238 -0
  17. models/bandit/core/data/musdb/validation.yaml +15 -0
  18. models/bandit/core/loss/__init__.py +2 -0
  19. models/bandit/core/loss/_complex.py +34 -0
  20. models/bandit/core/loss/_multistem.py +45 -0
  21. models/bandit/core/loss/_timefreq.py +113 -0
  22. models/bandit/core/loss/snr.py +146 -0
  23. models/bandit/core/metrics/__init__.py +9 -0
  24. models/bandit/core/metrics/_squim.py +383 -0
  25. models/bandit/core/metrics/snr.py +150 -0
  26. models/bandit/core/model/__init__.py +3 -0
  27. models/bandit/core/model/_spectral.py +58 -0
  28. models/bandit/core/model/bsrnn/__init__.py +23 -0
  29. models/bandit/core/model/bsrnn/bandsplit.py +139 -0
  30. models/bandit/core/model/bsrnn/core.py +661 -0
  31. models/bandit/core/model/bsrnn/maskestim.py +347 -0
  32. models/bandit/core/model/bsrnn/tfmodel.py +317 -0
  33. models/bandit/core/model/bsrnn/utils.py +583 -0
  34. models/bandit/core/model/bsrnn/wrapper.py +882 -0
  35. models/bandit/core/utils/__init__.py +0 -0
  36. models/bandit/core/utils/audio.py +463 -0
  37. models/bandit/model_from_config.py +31 -0
  38. models/bs_roformer/__init__.py +2 -0
  39. models/bs_roformer/attend.py +120 -0
  40. models/bs_roformer/bs_roformer.py +577 -0
  41. models/bs_roformer/mel_band_roformer.py +637 -0
  42. models/demucs4ht.py +713 -0
  43. models/mdx23c_tfc_tdf_v3.py +242 -0
  44. models/scnet/__init__.py +1 -0
  45. models/scnet/scnet.py +373 -0
  46. models/scnet/separation.py +178 -0
  47. models/scnet_unofficial/__init__.py +1 -0
  48. models/scnet_unofficial/modules/__init__.py +3 -0
  49. models/scnet_unofficial/modules/dualpath_rnn.py +228 -0
  50. models/scnet_unofficial/modules/sd_encoder.py +285 -0
app.py ADDED
@@ -0,0 +1,22 @@
+ import gradio as gr
+ import os
+ DESCRIPTION = """
+ # audio sep
+ being made
+ """
+ 
+ theme = gr.themes.Base(
+     font=[gr.themes.GoogleFont('Libre Franklin'), gr.themes.GoogleFont('Public Sans'), 'system-ui', 'sans-serif'],
+ )
+ with gr.Blocks(css="footer{display:none !important}", theme=theme) as demo:
+     gr.Markdown(DESCRIPTION)
+     gr.DuplicateButton(
+         value="Duplicate Space for private use",
+         elem_id="duplicate-button",
+         visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
+     )
+ 
+ 
+ 
+ 
+ demo.queue(max_size=20, api_open=False).launch(show_api=False)
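The Blocks app above is still a stub (its description says "being made") and does not yet call any separation code. Purely as a hedged sketch, one way the Space could later wire a UI onto the `inference.py`-style utilities is shown below; the `separate` helper and the component layout are assumptions, not part of this commit.

```python
import gradio as gr


def separate(audio_path: str) -> str:
    # Hypothetical glue code: a real version would load a model via
    # get_model_from_config(...) and run demix_track(...) on the file.
    # Returning the input path keeps this placeholder runnable.
    return audio_path


with gr.Blocks() as sketch:
    mixture = gr.Audio(type="filepath", label="Mixture")
    stem = gr.Audio(type="filepath", label="Separated stem")
    gr.Button("Separate").click(separate, inputs=mixture, outputs=stem)
```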
inference.py ADDED
@@ -0,0 +1,113 @@
+ # coding: utf-8
+ __author__ = 'Roman Solovyev (ZFTurbo): https://github.com/ZFTurbo/'
+ 
+ import argparse
+ import time
+ import librosa
+ from tqdm import tqdm
+ import sys
+ import os
+ import glob
+ import torch
+ import numpy as np
+ import soundfile as sf
+ import torch.nn as nn
+ from utils import demix_track, demix_track_demucs, get_model_from_config
+ 
+ import warnings
+ warnings.filterwarnings("ignore")
+ 
+ 
+ def run_folder(model, args, config, device, verbose=False):
+     start_time = time.time()
+     model.eval()
+     all_mixtures_path = glob.glob(args.input_folder + '/*.*')
+     print('Total files found: {}'.format(len(all_mixtures_path)))
+ 
+     instruments = config.training.instruments
+     if config.training.target_instrument is not None:
+         instruments = [config.training.target_instrument]
+ 
+     if not os.path.isdir(args.store_dir):
+         os.mkdir(args.store_dir)
+ 
+     if not verbose:
+         all_mixtures_path = tqdm(all_mixtures_path)
+ 
+     for path in all_mixtures_path:
+         if not verbose:
+             all_mixtures_path.set_postfix({'track': os.path.basename(path)})
+         try:
+             # mix, sr = sf.read(path)
+             mix, sr = librosa.load(path, sr=44100, mono=False)
+             mix = mix.T
+         except Exception as e:
+             print('Cannot read track: {}'.format(path))
+             print('Error message: {}'.format(str(e)))
+             continue
+ 
+         # Convert mono to stereo if needed
+         if len(mix.shape) == 1:
+             mix = np.stack([mix, mix], axis=-1)
+ 
+         mixture = torch.tensor(mix.T, dtype=torch.float32)
+         if args.model_type == 'htdemucs':
+             res = demix_track_demucs(config, model, mixture, device)
+         else:
+             res = demix_track(config, model, mixture, device)
+         for instr in instruments:
+             sf.write("{}/{}_{}.wav".format(args.store_dir, os.path.basename(path)[:-4], instr), res[instr].T, sr, subtype='FLOAT')
+ 
+         if 'vocals' in instruments and args.extract_instrumental:
+             instrum_file_name = "{}/{}_{}.wav".format(args.store_dir, os.path.basename(path)[:-4], 'instrumental')
+             sf.write(instrum_file_name, mix - res['vocals'].T, sr, subtype='FLOAT')
+ 
+     time.sleep(1)
+     print("Elapsed time: {:.2f} sec".format(time.time() - start_time))
+ 
+ 
+ def proc_folder(args):
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--model_type", type=str, default='mdx23c', help="One of mdx23c, htdemucs, segm_models, mel_band_roformer, bs_roformer, swin_upernet, bandit")
+     parser.add_argument("--config_path", type=str, help="path to config file")
+     parser.add_argument("--start_check_point", type=str, default='', help="Initial checkpoint with valid weights")
+     parser.add_argument("--input_folder", type=str, help="folder with mixtures to process")
+     parser.add_argument("--store_dir", default="", type=str, help="path to store results as wav files")
+     parser.add_argument("--device_ids", nargs='+', type=int, default=0, help='list of gpu ids')
+     parser.add_argument("--extract_instrumental", action='store_true', help="invert vocals to get instrumental if provided")
+     if args is None:
+         args = parser.parse_args()
+     else:
+         args = parser.parse_args(args)
+ 
+     torch.backends.cudnn.benchmark = True
+ 
+     model, config = get_model_from_config(args.model_type, args.config_path)
+     if args.start_check_point != '':
+         print('Start from checkpoint: {}'.format(args.start_check_point))
+         state_dict = torch.load(args.start_check_point)
+         if args.model_type == 'htdemucs':
+             # Fix for htdemucs pretrained models
+             if 'state' in state_dict:
+                 state_dict = state_dict['state']
+         model.load_state_dict(state_dict)
+     print("Instruments: {}".format(config.training.instruments))
+ 
+     if torch.cuda.is_available():
+         device_ids = args.device_ids
+         if type(device_ids) == int:
+             device = torch.device(f'cuda:{device_ids}')
+             model = model.to(device)
+         else:
+             device = torch.device(f'cuda:{device_ids[0]}')
+             model = nn.DataParallel(model, device_ids=device_ids).to(device)
+     else:
+         device = 'cpu'
+         print('CUDA is not available. Running inference on CPU. It will be very slow...')
+         model = model.to(device)
+ 
+     run_folder(model, args, config, device, verbose=False)
+ 
+ 
+ if __name__ == "__main__":
+     proc_folder(None)
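`proc_folder` parses an explicit argument list when one is passed, so the script can be driven either from the command line or programmatically. A hedged example follows; every path, config name, and checkpoint name is a placeholder, and it assumes the repo's `utils` module and model weights are available.

```python
# Programmatic call into inference.proc_folder with an explicit argv list.
# All paths below are placeholders; substitute your own config, checkpoint,
# and input/output folders.
from inference import proc_folder

proc_folder([
    "--model_type", "bs_roformer",
    "--config_path", "configs/your_config.yaml",            # placeholder
    "--start_check_point", "checkpoints/your_model.ckpt",   # placeholder
    "--input_folder", "input_mixtures",                     # placeholder
    "--store_dir", "separated",                             # placeholder
    "--extract_instrumental",
])
```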
models/bandit/core/__init__.py ADDED
@@ -0,0 +1,744 @@
1
+ import os.path
2
+ from collections import defaultdict
3
+ from itertools import chain, combinations
4
+ from typing import (
5
+ Any,
6
+ Dict,
7
+ Iterator,
8
+ Mapping, Optional,
9
+ Tuple, Type,
10
+ TypedDict
11
+ )
12
+
13
+ import pytorch_lightning as pl
14
+ import torch
15
+ import torchaudio as ta
16
+ import torchmetrics as tm
17
+ from asteroid import losses as asteroid_losses
18
+ # from deepspeed.ops.adam import DeepSpeedCPUAdam
19
+ # from geoopt import optim as gooptim
20
+ from pytorch_lightning.utilities.types import STEP_OUTPUT
21
+ from torch import nn, optim
22
+ from torch.optim import lr_scheduler
23
+ from torch.optim.lr_scheduler import LRScheduler
24
+
25
+ from models.bandit.core import loss, metrics as metrics_, model
26
+ from models.bandit.core.data._types import BatchedDataDict
27
+ from models.bandit.core.data.augmentation import BaseAugmentor, StemAugmentor
28
+ from models.bandit.core.utils import audio as audio_
29
+ from models.bandit.core.utils.audio import BaseFader
30
+
31
+ # from pandas.io.json._normalize import nested_to_record
32
+
33
+ ConfigDict = TypedDict('ConfigDict', {'name': str, 'kwargs': Dict[str, Any]})
34
+
35
+
36
+ class SchedulerConfigDict(ConfigDict):
37
+ monitor: str
38
+
39
+
40
+ OptimizerSchedulerConfigDict = TypedDict(
41
+ 'OptimizerSchedulerConfigDict',
42
+ {"optimizer": ConfigDict, "scheduler": SchedulerConfigDict},
43
+ total=False
44
+ )
45
+
46
+
47
+ class LRSchedulerReturnDict(TypedDict, total=False):
48
+ scheduler: LRScheduler
49
+ monitor: str
50
+
51
+
52
+ class ConfigureOptimizerReturnDict(TypedDict, total=False):
53
+ optimizer: torch.optim.Optimizer
54
+ lr_scheduler: LRSchedulerReturnDict
55
+
56
+
57
+ OutputType = Dict[str, Any]
58
+ MetricsType = Dict[str, torch.Tensor]
59
+
60
+
61
+ def get_optimizer_class(name: str) -> Type[optim.Optimizer]:
+     # NOTE: DeepSpeedCPUAdam and geoopt's `gooptim` are referenced below, but their
+     # imports are commented out at the top of this file; those branches raise NameError as-is.
+     if name == "DeepSpeedCPUAdam":
+         return DeepSpeedCPUAdam
+ 
+     for module in [optim, gooptim]:
+         if name in module.__dict__:
+             return module.__dict__[name]
+ 
+     raise NameError
71
+
72
+
73
+ def parse_optimizer_config(
74
+ config: OptimizerSchedulerConfigDict,
75
+ parameters: Iterator[nn.Parameter]
76
+ ) -> ConfigureOptimizerReturnDict:
77
+ optim_class = get_optimizer_class(config["optimizer"]["name"])
78
+ optimizer = optim_class(parameters, **config["optimizer"]["kwargs"])
79
+
80
+ optim_dict: ConfigureOptimizerReturnDict = {
81
+ "optimizer": optimizer,
82
+ }
83
+
84
+ if "scheduler" in config:
85
+
86
+ lr_scheduler_class_ = config["scheduler"]["name"]
87
+ lr_scheduler_class = lr_scheduler.__dict__[lr_scheduler_class_]
88
+ lr_scheduler_dict: LRSchedulerReturnDict = {
89
+ "scheduler": lr_scheduler_class(
90
+ optimizer,
91
+ **config["scheduler"]["kwargs"]
92
+ )
93
+ }
94
+
95
+ if lr_scheduler_class_ == "ReduceLROnPlateau":
96
+ lr_scheduler_dict["monitor"] = config["scheduler"]["monitor"]
97
+
98
+ optim_dict["lr_scheduler"] = lr_scheduler_dict
99
+
100
+ return optim_dict
101
+
102
+
103
+ def parse_model_config(config: ConfigDict) -> Any:
104
+ name = config["name"]
105
+
106
+ for module in [model]:
107
+ if name in module.__dict__:
108
+ return module.__dict__[name](**config["kwargs"])
109
+
110
+ raise NameError
111
+
112
+
113
+ _LEGACY_LOSS_NAMES = ["HybridL1Loss"]
114
+
115
+
116
+ def _parse_legacy_loss_config(config: ConfigDict) -> nn.Module:
117
+ name = config["name"]
118
+
119
+ if name == "HybridL1Loss":
120
+ return loss.TimeFreqL1Loss(**config["kwargs"])
121
+
122
+ raise NameError
123
+
124
+
125
+ def parse_loss_config(config: ConfigDict) -> nn.Module:
126
+ name = config["name"]
127
+
128
+ if name in _LEGACY_LOSS_NAMES:
129
+ return _parse_legacy_loss_config(config)
130
+
131
+ for module in [loss, nn.modules.loss, asteroid_losses]:
132
+ if name in module.__dict__:
133
+ # print(config["kwargs"])
134
+ return module.__dict__[name](**config["kwargs"])
135
+
136
+ raise NameError
137
+
138
+
139
+ def get_metric(config: ConfigDict) -> tm.Metric:
140
+ name = config["name"]
141
+
142
+ for module in [tm, metrics_]:
143
+ if name in module.__dict__:
144
+ return module.__dict__[name](**config["kwargs"])
145
+ raise NameError
146
+
147
+
148
+ def parse_metric_config(config: Dict[str, ConfigDict]) -> tm.MetricCollection:
149
+ metrics = {}
150
+
151
+ for metric in config:
152
+ metrics[metric] = get_metric(config[metric])
153
+
154
+ return tm.MetricCollection(metrics)
155
+
156
+
157
+ def parse_fader_config(config: ConfigDict) -> BaseFader:
158
+ name = config["name"]
159
+
160
+ for module in [audio_]:
161
+ if name in module.__dict__:
162
+ return module.__dict__[name](**config["kwargs"])
163
+
164
+ raise NameError
165
+
166
+
167
+ class LightningSystem(pl.LightningModule):
168
+ _VOX_STEMS = ["speech", "vocals"]
169
+ _BG_STEMS = ["background", "effects", "mne"]
170
+
171
+ def __init__(
172
+ self,
173
+ config: Dict,
174
+ loss_adjustment: float = 1.0,
175
+ attach_fader: bool = False
176
+ ) -> None:
177
+ super().__init__()
178
+ self.optimizer_config = config["optimizer"]
179
+ self.model = parse_model_config(config["model"])
180
+ self.loss = parse_loss_config(config["loss"])
181
+ self.metrics = nn.ModuleDict(
182
+ {
183
+ stem: parse_metric_config(config["metrics"]["dev"])
184
+ for stem in self.model.stems
185
+ }
186
+ )
187
+
188
+ self.metrics.disallow_fsdp = True
189
+
190
+ self.test_metrics = nn.ModuleDict(
191
+ {
192
+ stem: parse_metric_config(config["metrics"]["test"])
193
+ for stem in self.model.stems
194
+ }
195
+ )
196
+
197
+ self.test_metrics.disallow_fsdp = True
198
+
199
+ self.fs = config["model"]["kwargs"]["fs"]
200
+
201
+ self.fader_config = config["inference"]["fader"]
202
+ if attach_fader:
203
+ self.fader = parse_fader_config(config["inference"]["fader"])
204
+ else:
205
+ self.fader = None
206
+
207
+ self.augmentation: Optional[BaseAugmentor]
208
+ if config.get("augmentation", None) is not None:
209
+ self.augmentation = StemAugmentor(**config["augmentation"])
210
+ else:
211
+ self.augmentation = None
212
+
213
+ self.predict_output_path: Optional[str] = None
214
+ self.loss_adjustment = loss_adjustment
215
+
216
+ self.val_prefix = None
217
+ self.test_prefix = None
218
+
219
+
220
+ def configure_optimizers(self) -> Any:
221
+ return parse_optimizer_config(
222
+ self.optimizer_config,
223
+ self.trainer.model.parameters()
224
+ )
225
+
226
+ def compute_loss(self, batch: BatchedDataDict, output: OutputType) -> Dict[
227
+ str, torch.Tensor]:
228
+ return {"loss": self.loss(output, batch)}
229
+
230
+ def update_metrics(
231
+ self,
232
+ batch: BatchedDataDict,
233
+ output: OutputType,
234
+ mode: str
235
+ ) -> None:
236
+
237
+ if mode == "test":
238
+ metrics = self.test_metrics
239
+ else:
240
+ metrics = self.metrics
241
+
242
+ for stem, metric in metrics.items():
243
+
244
+ if stem == "mne:+":
245
+ stem = "mne"
246
+
247
+ # print(f"matching for {stem}")
248
+ if mode == "train":
249
+ metric.update(
250
+ output["audio"][stem],#.cpu(),
251
+ batch["audio"][stem],#.cpu()
252
+ )
253
+ else:
254
+ if stem not in batch["audio"]:
255
+ matched = False
256
+ if stem in self._VOX_STEMS:
257
+ for bstem in self._VOX_STEMS:
258
+ if bstem in batch["audio"]:
259
+ batch["audio"][stem] = batch["audio"][bstem]
260
+ matched = True
261
+ break
262
+ elif stem in self._BG_STEMS:
263
+ for bstem in self._BG_STEMS:
264
+ if bstem in batch["audio"]:
265
+ batch["audio"][stem] = batch["audio"][bstem]
266
+ matched = True
267
+ break
268
+ else:
269
+ matched = True
270
+
271
+ # print(batch["audio"].keys())
272
+
273
+ if matched:
274
+ # print(f"matched {stem}!")
275
+ if stem == "mne" and "mne" not in output["audio"]:
276
+ output["audio"]["mne"] = output["audio"]["music"] + output["audio"]["effects"]
277
+
278
+ metric.update(
279
+ output["audio"][stem],#.cpu(),
280
+ batch["audio"][stem],#.cpu(),
281
+ )
282
+
283
+ # print(metric.compute())
284
+ def compute_metrics(self, mode: str="dev") -> Dict[
285
+ str, torch.Tensor]:
286
+
287
+ if mode == "test":
288
+ metrics = self.test_metrics
289
+ else:
290
+ metrics = self.metrics
291
+
292
+ metric_dict = {}
293
+
294
+ for stem, metric in metrics.items():
295
+ md = metric.compute()
296
+ metric_dict.update(
297
+ {f"{stem}/{k}": v for k, v in md.items()}
298
+ )
299
+
300
+ self.log_dict(metric_dict, prog_bar=True, logger=False)
301
+
302
+ return metric_dict
303
+
304
+ def reset_metrics(self, test_mode: bool = False) -> None:
305
+
306
+ if test_mode:
307
+ metrics = self.test_metrics
308
+ else:
309
+ metrics = self.metrics
310
+
311
+ for _, metric in metrics.items():
312
+ metric.reset()
313
+
314
+
315
+ def forward(self, batch: BatchedDataDict) -> Any:
316
+ batch, output = self.model(batch)
317
+
318
+
319
+ return batch, output
320
+
321
+ def common_step(self, batch: BatchedDataDict, mode: str) -> Any:
322
+ batch, output = self.forward(batch)
323
+ # print(batch)
324
+ # print(output)
325
+ loss_dict = self.compute_loss(batch, output)
326
+
327
+ with torch.no_grad():
328
+ self.update_metrics(batch, output, mode=mode)
329
+
330
+ if mode == "train":
331
+ self.log("loss", loss_dict["loss"], prog_bar=True)
332
+
333
+ return output, loss_dict
334
+
335
+
336
+ def training_step(self, batch: BatchedDataDict) -> Dict[str, Any]:
337
+
338
+ if self.augmentation is not None:
339
+ with torch.no_grad():
340
+ batch = self.augmentation(batch)
341
+
342
+ _, loss_dict = self.common_step(batch, mode="train")
343
+
344
+ with torch.inference_mode():
345
+ self.log_dict_with_prefix(
346
+ loss_dict,
347
+ "train",
348
+ batch_size=batch["audio"]["mixture"].shape[0]
349
+ )
350
+
351
+ loss_dict["loss"] *= self.loss_adjustment
352
+
353
+ return loss_dict
354
+
355
+ def on_train_batch_end(
356
+ self, outputs: STEP_OUTPUT, batch: BatchedDataDict, batch_idx: int
357
+ ) -> None:
358
+
359
+ metric_dict = self.compute_metrics()
360
+ self.log_dict_with_prefix(metric_dict, "train")
361
+ self.reset_metrics()
362
+
363
+ def validation_step(
364
+ self,
365
+ batch: BatchedDataDict,
366
+ batch_idx: int,
367
+ dataloader_idx: int = 0
368
+ ) -> Dict[str, Any]:
369
+
370
+ with torch.inference_mode():
371
+ curr_val_prefix = f"val{dataloader_idx}" if dataloader_idx > 0 else "val"
372
+
373
+ if curr_val_prefix != self.val_prefix:
374
+ # print(f"Switching to validation dataloader {dataloader_idx}")
375
+ if self.val_prefix is not None:
376
+ self._on_validation_epoch_end()
377
+ self.val_prefix = curr_val_prefix
378
+ _, loss_dict = self.common_step(batch, mode="val")
379
+
380
+ self.log_dict_with_prefix(
381
+ loss_dict,
382
+ self.val_prefix,
383
+ batch_size=batch["audio"]["mixture"].shape[0],
384
+ prog_bar=True,
385
+ add_dataloader_idx=False
386
+ )
387
+
388
+ return loss_dict
389
+
390
+ def on_validation_epoch_end(self) -> None:
391
+ self._on_validation_epoch_end()
392
+
393
+ def _on_validation_epoch_end(self) -> None:
394
+ metric_dict = self.compute_metrics()
395
+ self.log_dict_with_prefix(metric_dict, self.val_prefix, prog_bar=True,
396
+ add_dataloader_idx=False)
397
+ # self.logger.save()
398
+ # print(self.val_prefix, "Validation metrics:", metric_dict)
399
+ self.reset_metrics()
400
+
401
+
402
+ def old_predtest_step(
403
+ self,
404
+ batch: BatchedDataDict,
405
+ batch_idx: int,
406
+ dataloader_idx: int = 0
407
+ ) -> Tuple[BatchedDataDict, OutputType]:
408
+
409
+ audio_batch = batch["audio"]["mixture"]
410
+ track_batch = batch.get("track", ["" for _ in range(len(audio_batch))])
411
+
412
+ output_list_of_dicts = [
413
+ self.fader(
414
+ audio[None, ...],
415
+ lambda a: self.test_forward(a, track)
416
+ )
417
+ for audio, track in zip(audio_batch, track_batch)
418
+ ]
419
+
420
+ output_dict_of_lists = defaultdict(list)
421
+
422
+ for output_dict in output_list_of_dicts:
423
+ for stem, audio in output_dict.items():
424
+ output_dict_of_lists[stem].append(audio)
425
+
426
+ output = {
427
+ "audio": {
428
+ stem: torch.concat(output_list, dim=0)
429
+ for stem, output_list in output_dict_of_lists.items()
430
+ }
431
+ }
432
+
433
+ return batch, output
434
+
435
+ def predtest_step(
436
+ self,
437
+ batch: BatchedDataDict,
438
+ batch_idx: int = -1,
439
+ dataloader_idx: int = 0
440
+ ) -> Tuple[BatchedDataDict, OutputType]:
441
+
442
+ if getattr(self.model, "bypass_fader", False):
443
+ batch, output = self.model(batch)
444
+ else:
445
+ audio_batch = batch["audio"]["mixture"]
446
+ output = self.fader(
447
+ audio_batch,
448
+ lambda a: self.test_forward(a, "", batch=batch)
449
+ )
450
+
451
+ return batch, output
452
+
453
+ def test_forward(
454
+ self,
455
+ audio: torch.Tensor,
456
+ track: str = "",
457
+ batch: BatchedDataDict = None
458
+ ) -> torch.Tensor:
459
+
460
+ if self.fader is None:
461
+ self.attach_fader()
462
+
463
+ cond = batch.get("condition", None)
464
+
465
+ if cond is not None and cond.shape[0] == 1:
466
+ cond = cond.repeat(audio.shape[0], 1)
467
+
468
+ _, output = self.forward(
469
+ {"audio": {"mixture": audio},
470
+ "track": track,
471
+ "condition": cond,
472
+ }
473
+ ) # TODO: support track properly
474
+
475
+ return output["audio"]
476
+
477
+ def on_test_epoch_start(self) -> None:
478
+ self.attach_fader(force_reattach=True)
479
+
480
+ def test_step(
481
+ self,
482
+ batch: BatchedDataDict,
483
+ batch_idx: int,
484
+ dataloader_idx: int = 0
485
+ ) -> Any:
486
+ curr_test_prefix = f"test{dataloader_idx}"
487
+
488
+ # print(batch["audio"].keys())
489
+
490
+ if curr_test_prefix != self.test_prefix:
491
+ # print(f"Switching to test dataloader {dataloader_idx}")
492
+ if self.test_prefix is not None:
493
+ self._on_test_epoch_end()
494
+ self.test_prefix = curr_test_prefix
495
+
496
+ with torch.inference_mode():
497
+ _, output = self.predtest_step(batch, batch_idx, dataloader_idx)
498
+ # print(output)
499
+ self.update_metrics(batch, output, mode="test")
500
+
501
+ return output
502
+
503
+ def on_test_epoch_end(self) -> None:
504
+ self._on_test_epoch_end()
505
+
506
+ def _on_test_epoch_end(self) -> None:
507
+ metric_dict = self.compute_metrics(mode="test")
508
+ self.log_dict_with_prefix(metric_dict, self.test_prefix, prog_bar=True,
509
+ add_dataloader_idx=False)
510
+ # self.logger.save()
511
+ # print(self.test_prefix, "Test metrics:", metric_dict)
512
+ self.reset_metrics()
513
+
514
+ def predict_step(
515
+ self,
516
+ batch: BatchedDataDict,
517
+ batch_idx: int = 0,
518
+ dataloader_idx: int = 0,
519
+ include_track_name: Optional[bool] = None,
520
+ get_no_vox_combinations: bool = True,
521
+ get_residual: bool = False,
522
+ treat_batch_as_channels: bool = False,
523
+ fs: Optional[int] = None,
524
+ ) -> Any:
525
+ assert self.predict_output_path is not None
526
+
527
+ batch_size = batch["audio"]["mixture"].shape[0]
528
+
529
+ if include_track_name is None:
530
+ include_track_name = batch_size > 1
531
+
532
+ with torch.inference_mode():
533
+ batch, output = self.predtest_step(batch, batch_idx, dataloader_idx)
534
+ print('Pred test finished...')
535
+ torch.cuda.empty_cache()
536
+ metric_dict = {}
537
+
538
+ if get_residual:
539
+ mixture = batch["audio"]["mixture"]
540
+ extracted = sum([output["audio"][stem] for stem in output["audio"]])
541
+ residual = mixture - extracted
542
+ print(extracted.shape, mixture.shape, residual.shape)
543
+
544
+ output["audio"]["residual"] = residual
545
+
546
+ if get_no_vox_combinations:
547
+ no_vox_stems = [
548
+ stem for stem in output["audio"] if
549
+ stem not in self._VOX_STEMS
550
+ ]
551
+ no_vox_combinations = chain.from_iterable(
552
+ combinations(no_vox_stems, r) for r in
553
+ range(2, len(no_vox_stems) + 1)
554
+ )
555
+
556
+ for combination in no_vox_combinations:
557
+ combination_ = list(combination)
558
+ output["audio"]["+".join(combination_)] = sum(
559
+ [output["audio"][stem] for stem in combination_]
560
+ )
561
+
562
+ if treat_batch_as_channels:
563
+ for stem in output["audio"]:
564
+ output["audio"][stem] = output["audio"][stem].reshape(
565
+ 1, -1, output["audio"][stem].shape[-1]
566
+ )
567
+ batch_size = 1
568
+
569
+ for b in range(batch_size):
570
+ print("!!", b)
571
+ for stem in output["audio"]:
572
+ print(f"Saving audio for {stem} to {self.predict_output_path}")
573
+ track_name = batch["track"][b].split("/")[-1]
574
+
575
+ if batch.get("audio", {}).get(stem, None) is not None:
576
+ self.test_metrics[stem].reset()
577
+ metrics = self.test_metrics[stem](
578
+ batch["audio"][stem][[b], ...],
579
+ output["audio"][stem][[b], ...]
580
+ )
581
+ snr = metrics["snr"]
582
+ sisnr = metrics["sisnr"]
583
+ sdr = metrics["sdr"]
584
+ metric_dict[stem] = metrics
585
+ print(
586
+ track_name,
587
+ f"snr={snr:2.2f} dB",
588
+ f"sisnr={sisnr:2.2f}",
589
+ f"sdr={sdr:2.2f} dB",
590
+ )
591
+ filename = f"{stem} - snr={snr:2.2f}dB - sdr={sdr:2.2f}dB.wav"
592
+ else:
593
+ filename = f"{stem}.wav"
594
+
595
+ if include_track_name:
596
+ output_dir = os.path.join(
597
+ self.predict_output_path,
598
+ track_name
599
+ )
600
+ else:
601
+ output_dir = self.predict_output_path
602
+
603
+ os.makedirs(output_dir, exist_ok=True)
604
+
605
+ if fs is None:
606
+ fs = self.fs
607
+
608
+ ta.save(
609
+ os.path.join(output_dir, filename),
610
+ output["audio"][stem][b, ...].cpu(),
611
+ fs,
612
+ )
613
+
614
+ return metric_dict
615
+
616
+ def get_stems(
617
+ self,
618
+ batch: BatchedDataDict,
619
+ batch_idx: int = 0,
620
+ dataloader_idx: int = 0,
621
+ include_track_name: Optional[bool] = None,
622
+ get_no_vox_combinations: bool = True,
623
+ get_residual: bool = False,
624
+ treat_batch_as_channels: bool = False,
625
+ fs: Optional[int] = None,
626
+ ) -> Any:
627
+ assert self.predict_output_path is not None
628
+
629
+ batch_size = batch["audio"]["mixture"].shape[0]
630
+
631
+ if include_track_name is None:
632
+ include_track_name = batch_size > 1
633
+
634
+ with torch.inference_mode():
635
+ batch, output = self.predtest_step(batch, batch_idx, dataloader_idx)
636
+ torch.cuda.empty_cache()
637
+ metric_dict = {}
638
+
639
+ if get_residual:
640
+ mixture = batch["audio"]["mixture"]
641
+ extracted = sum([output["audio"][stem] for stem in output["audio"]])
642
+ residual = mixture - extracted
643
+ # print(extracted.shape, mixture.shape, residual.shape)
644
+
645
+ output["audio"]["residual"] = residual
646
+
647
+ if get_no_vox_combinations:
648
+ no_vox_stems = [
649
+ stem for stem in output["audio"] if
650
+ stem not in self._VOX_STEMS
651
+ ]
652
+ no_vox_combinations = chain.from_iterable(
653
+ combinations(no_vox_stems, r) for r in
654
+ range(2, len(no_vox_stems) + 1)
655
+ )
656
+
657
+ for combination in no_vox_combinations:
658
+ combination_ = list(combination)
659
+ output["audio"]["+".join(combination_)] = sum(
660
+ [output["audio"][stem] for stem in combination_]
661
+ )
662
+
663
+ if treat_batch_as_channels:
664
+ for stem in output["audio"]:
665
+ output["audio"][stem] = output["audio"][stem].reshape(
666
+ 1, -1, output["audio"][stem].shape[-1]
667
+ )
668
+ batch_size = 1
669
+
670
+ result = {}
671
+ for b in range(batch_size):
672
+ for stem in output["audio"]:
673
+ track_name = batch["track"][b].split("/")[-1]
674
+
675
+ if batch.get("audio", {}).get(stem, None) is not None:
676
+ self.test_metrics[stem].reset()
677
+ metrics = self.test_metrics[stem](
678
+ batch["audio"][stem][[b], ...],
679
+ output["audio"][stem][[b], ...]
680
+ )
681
+ snr = metrics["snr"]
682
+ sisnr = metrics["sisnr"]
683
+ sdr = metrics["sdr"]
684
+ metric_dict[stem] = metrics
685
+ print(
686
+ track_name,
687
+ f"snr={snr:2.2f} dB",
688
+ f"sisnr={sisnr:2.2f}",
689
+ f"sdr={sdr:2.2f} dB",
690
+ )
691
+ filename = f"{stem} - snr={snr:2.2f}dB - sdr={sdr:2.2f}dB.wav"
692
+ else:
693
+ filename = f"{stem}.wav"
694
+
695
+ if include_track_name:
696
+ output_dir = os.path.join(
697
+ self.predict_output_path,
698
+ track_name
699
+ )
700
+ else:
701
+ output_dir = self.predict_output_path
702
+
703
+ os.makedirs(output_dir, exist_ok=True)
704
+
705
+ if fs is None:
706
+ fs = self.fs
707
+
708
+ result[stem] = output["audio"][stem][b, ...].cpu().numpy()
709
+
710
+ return result
711
+
712
+ def load_state_dict(
713
+ self, state_dict: Mapping[str, Any], strict: bool = False
714
+ ) -> Any:
715
+
716
+ return super().load_state_dict(state_dict, strict=False)
717
+
718
+
719
+ def set_predict_output_path(self, path: str) -> None:
720
+ self.predict_output_path = path
721
+ os.makedirs(self.predict_output_path, exist_ok=True)
722
+
723
+ self.attach_fader()
724
+
725
+ def attach_fader(self, force_reattach=False) -> None:
726
+ if self.fader is None or force_reattach:
727
+ self.fader = parse_fader_config(self.fader_config)
728
+ self.fader.to(self.device)
729
+
730
+
731
+ def log_dict_with_prefix(
732
+ self,
733
+ dict_: Dict[str, torch.Tensor],
734
+ prefix: str,
735
+ batch_size: Optional[int] = None,
736
+ **kwargs: Any
737
+ ) -> None:
738
+ self.log_dict(
739
+ {f"{prefix}/{k}": v for k, v in dict_.items()},
740
+ batch_size=batch_size,
741
+ logger=True,
742
+ sync_dist=True,
743
+ **kwargs,
744
+ )
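For orientation, `parse_optimizer_config` above expects the nested `{"optimizer": ..., "scheduler": ...}` shape declared by the TypedDicts at the top of the file, with `monitor` only read for `ReduceLROnPlateau`. A hedged sketch follows; the hyperparameter values are placeholders and it assumes the repo's dependencies are installed.

```python
# Sketch of the config shape consumed by parse_optimizer_config.
# Values are placeholders; "monitor" is only used for ReduceLROnPlateau.
from torch import nn

from models.bandit.core import parse_optimizer_config

module = nn.Linear(4, 2)  # stand-in for the separation model

optimizer_config = {
    "optimizer": {"name": "AdamW", "kwargs": {"lr": 1e-3}},
    "scheduler": {
        "name": "ReduceLROnPlateau",
        "kwargs": {"factor": 0.5, "patience": 5},
        "monitor": "val/loss",
    },
}

optim_dict = parse_optimizer_config(optimizer_config, module.parameters())
# -> {"optimizer": AdamW(...),
#     "lr_scheduler": {"scheduler": ReduceLROnPlateau(...), "monitor": "val/loss"}}
```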
models/bandit/core/data/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .dnr.datamodule import DivideAndRemasterDataModule
2
+ from .musdb.datamodule import MUSDB18DataModule
models/bandit/core/data/_types.py ADDED
@@ -0,0 +1,18 @@
1
+ from typing import Dict, Sequence, TypedDict
2
+
3
+ import torch
4
+
5
+ AudioDict = Dict[str, torch.Tensor]
6
+
7
+ DataDict = TypedDict('DataDict', {'audio': AudioDict, 'track': str})
8
+
9
+ BatchedDataDict = TypedDict(
10
+ 'BatchedDataDict',
11
+ {'audio': AudioDict, 'track': Sequence[str]}
12
+ )
13
+
14
+
15
+ class DataDictWithLanguage(TypedDict):
16
+ audio: AudioDict
17
+ track: str
18
+ language: str
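A minimal item matching `DataDict` (the tensor shapes and track name below are placeholders) would look like:

```python
import torch

from models.bandit.core.data._types import DataDict

item: DataDict = {
    "audio": {
        "mixture": torch.zeros(2, 44100),  # placeholder: 2 channels, 1 s at 44.1 kHz
        "speech": torch.zeros(2, 44100),
    },
    "track": "train/track_00001",  # placeholder identifier
}
```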
models/bandit/core/data/augmentation.py ADDED
@@ -0,0 +1,107 @@
1
+ from abc import ABC
2
+ from typing import Any, Dict, Union
3
+
4
+ import torch
5
+ import torch_audiomentations as tam
6
+ from torch import nn
7
+
8
+ from models.bandit.core.data._types import BatchedDataDict, DataDict
9
+
10
+
11
+ class BaseAugmentor(nn.Module, ABC):
12
+ def forward(self, item: Union[DataDict, BatchedDataDict]) -> Union[
13
+ DataDict, BatchedDataDict]:
14
+ raise NotImplementedError
15
+
16
+
17
+ class StemAugmentor(BaseAugmentor):
18
+ def __init__(
19
+ self,
20
+ audiomentations: Dict[str, Dict[str, Any]],
21
+ fix_clipping: bool = True,
22
+ scaler_margin: float = 0.5,
23
+ apply_both_default_and_common: bool = False,
24
+ ) -> None:
25
+ super().__init__()
26
+
27
+ augmentations = {}
28
+
29
+ self.has_default = "[default]" in audiomentations
30
+ self.has_common = "[common]" in audiomentations
31
+ self.apply_both_default_and_common = apply_both_default_and_common
32
+
33
+ for stem in audiomentations:
34
+ if audiomentations[stem]["name"] == "Compose":
35
+ augmentations[stem] = getattr(
36
+ tam,
37
+ audiomentations[stem]["name"]
38
+ )(
39
+ [
40
+ getattr(tam, aug["name"])(**aug["kwargs"])
41
+ for aug in
42
+ audiomentations[stem]["kwargs"]["transforms"]
43
+ ],
44
+ **audiomentations[stem]["kwargs"]["kwargs"],
45
+ )
46
+ else:
47
+ augmentations[stem] = getattr(
48
+ tam,
49
+ audiomentations[stem]["name"]
50
+ )(
51
+ **audiomentations[stem]["kwargs"]
52
+ )
53
+
54
+ self.augmentations = nn.ModuleDict(augmentations)
55
+ self.fix_clipping = fix_clipping
56
+ self.scaler_margin = scaler_margin
57
+
58
+ def check_and_fix_clipping(
59
+ self, item: Union[DataDict, BatchedDataDict]
60
+ ) -> Union[DataDict, BatchedDataDict]:
61
+ max_abs = []
62
+
63
+ for stem in item["audio"]:
64
+ max_abs.append(item["audio"][stem].abs().max().item())
65
+
66
+ if max(max_abs) > 1.0:
67
+ scaler = 1.0 / (max(max_abs) + torch.rand(
68
+ (1,),
69
+ device=item["audio"]["mixture"].device
70
+ ) * self.scaler_margin)
71
+
72
+ for stem in item["audio"]:
73
+ item["audio"][stem] *= scaler
74
+
75
+ return item
76
+
77
+ def forward(self, item: Union[DataDict, BatchedDataDict]) -> Union[
78
+ DataDict, BatchedDataDict]:
79
+
80
+ for stem in item["audio"]:
81
+ if stem == "mixture":
82
+ continue
83
+
84
+ if self.has_common:
85
+ item["audio"][stem] = self.augmentations["[common]"](
86
+ item["audio"][stem]
87
+ ).samples
88
+
89
+ if stem in self.augmentations:
90
+ item["audio"][stem] = self.augmentations[stem](
91
+ item["audio"][stem]
92
+ ).samples
93
+ elif self.has_default:
94
+ if not self.has_common or self.apply_both_default_and_common:
95
+ item["audio"][stem] = self.augmentations["[default]"](
96
+ item["audio"][stem]
97
+ ).samples
98
+
99
+ item["audio"]["mixture"] = sum(
100
+ [item["audio"][stem] for stem in item["audio"]
101
+ if stem != "mixture"]
102
+ ) # type: ignore[call-overload, assignment]
103
+
104
+ if self.fix_clipping:
105
+ item = self.check_and_fix_clipping(item)
106
+
107
+ return item
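`StemAugmentor` resolves each `"name"` against `torch_audiomentations` via `getattr`, and treats `"[default]"` and `"[common]"` as fallback/shared entries. A hedged construction example follows; the transform choices and parameters are placeholders, not a recommended recipe.

```python
# Hedged example of the `audiomentations` mapping StemAugmentor expects.
# Every "name" must be a torch_audiomentations class; kwargs are placeholders.
from models.bandit.core.data.augmentation import StemAugmentor

augmentor = StemAugmentor(
    audiomentations={
        # applied to any stem without its own entry
        "[default]": {
            "name": "Gain",
            "kwargs": {"min_gain_in_db": -6.0, "max_gain_in_db": 6.0, "p": 0.5},
        },
        # stem-specific Compose pipeline
        "speech": {
            "name": "Compose",
            "kwargs": {
                "transforms": [
                    {"name": "Gain", "kwargs": {"min_gain_in_db": -3.0, "max_gain_in_db": 3.0, "p": 0.5}},
                    {"name": "PolarityInversion", "kwargs": {"p": 0.2}},
                ],
                "kwargs": {},  # kwargs passed to Compose itself
            },
        },
    },
)
```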
models/bandit/core/data/augmented.py ADDED
@@ -0,0 +1,35 @@
1
+ import warnings
2
+ from typing import Dict, Optional, Union
3
+
4
+ import torch
5
+ from torch import nn
6
+ from torch.utils import data
7
+
8
+
9
+ class AugmentedDataset(data.Dataset):
10
+ def __init__(
11
+ self,
12
+ dataset: data.Dataset,
13
+ augmentation: nn.Module = nn.Identity(),
14
+ target_length: Optional[int] = None,
15
+ ) -> None:
16
+ warnings.warn(
17
+ "This class is no longer used. Attach augmentation to "
18
+ "the LightningSystem instead.",
19
+ DeprecationWarning,
20
+ )
21
+
22
+ self.dataset = dataset
23
+ self.augmentation = augmentation
24
+
25
+ self.ds_length: int = len(dataset) # type: ignore[arg-type]
26
+ self.length = target_length if target_length is not None else self.ds_length
27
+
28
+ def __getitem__(self, index: int) -> Dict[str, Union[str, Dict[str,
29
+ torch.Tensor]]]:
30
+ item = self.dataset[index % self.ds_length]
31
+ item = self.augmentation(item)
32
+ return item
33
+
34
+ def __len__(self) -> int:
35
+ return self.length
models/bandit/core/data/base.py ADDED
@@ -0,0 +1,69 @@
1
+ import os
2
+ from abc import ABC, abstractmethod
3
+ from typing import Any, Dict, List, Optional
4
+
5
+ import numpy as np
6
+ import pedalboard as pb
7
+ import torch
8
+ import torchaudio as ta
9
+ from torch.utils import data
10
+
11
+ from models.bandit.core.data._types import AudioDict, DataDict
12
+
13
+
14
+ class BaseSourceSeparationDataset(data.Dataset, ABC):
15
+ def __init__(
16
+ self, split: str,
17
+ stems: List[str],
18
+ files: List[str],
19
+ data_path: str,
20
+ fs: int,
21
+ npy_memmap: bool,
22
+ recompute_mixture: bool
23
+ ):
24
+ self.split = split
25
+ self.stems = stems
26
+ self.stems_no_mixture = [s for s in stems if s != "mixture"]
27
+ self.files = files
28
+ self.data_path = data_path
29
+ self.fs = fs
30
+ self.npy_memmap = npy_memmap
31
+ self.recompute_mixture = recompute_mixture
32
+
33
+ @abstractmethod
34
+ def get_stem(
35
+ self,
36
+ *,
37
+ stem: str,
38
+ identifier: Dict[str, Any]
39
+ ) -> torch.Tensor:
40
+ raise NotImplementedError
41
+
42
+ def _get_audio(self, stems, identifier: Dict[str, Any]):
43
+ audio = {}
44
+ for stem in stems:
45
+ audio[stem] = self.get_stem(stem=stem, identifier=identifier)
46
+
47
+ return audio
48
+
49
+ def get_audio(self, identifier: Dict[str, Any]) -> AudioDict:
50
+
51
+ if self.recompute_mixture:
52
+ audio = self._get_audio(
53
+ self.stems_no_mixture,
54
+ identifier=identifier
55
+ )
56
+ audio["mixture"] = self.compute_mixture(audio)
57
+ return audio
58
+ else:
59
+ return self._get_audio(self.stems, identifier=identifier)
60
+
61
+ @abstractmethod
62
+ def get_identifier(self, index: int) -> Dict[str, Any]:
63
+ pass
64
+
65
+ def compute_mixture(self, audio: AudioDict) -> torch.Tensor:
66
+
67
+ return sum(
68
+ audio[stem] for stem in audio if stem != "mixture"
69
+ )
models/bandit/core/data/dnr/__init__.py ADDED
File without changes
models/bandit/core/data/dnr/datamodule.py ADDED
@@ -0,0 +1,74 @@
1
+ import os
2
+ from typing import Mapping, Optional
3
+
4
+ import pytorch_lightning as pl
5
+
6
+ from .dataset import (
7
+ DivideAndRemasterDataset,
8
+ DivideAndRemasterDeterministicChunkDataset,
9
+ DivideAndRemasterRandomChunkDataset,
10
+ DivideAndRemasterRandomChunkDatasetWithSpeechReverb
11
+ )
12
+
13
+
14
+ def DivideAndRemasterDataModule(
15
+ data_root: str = "$DATA_ROOT/DnR/v2",
16
+ batch_size: int = 2,
17
+ num_workers: int = 8,
18
+ train_kwargs: Optional[Mapping] = None,
19
+ val_kwargs: Optional[Mapping] = None,
20
+ test_kwargs: Optional[Mapping] = None,
21
+ datamodule_kwargs: Optional[Mapping] = None,
22
+ use_speech_reverb: bool = False
23
+ # augmentor=None
24
+ ) -> pl.LightningDataModule:
25
+ if train_kwargs is None:
26
+ train_kwargs = {}
27
+
28
+ if val_kwargs is None:
29
+ val_kwargs = {}
30
+
31
+ if test_kwargs is None:
32
+ test_kwargs = {}
33
+
34
+ if datamodule_kwargs is None:
35
+ datamodule_kwargs = {}
36
+
37
+ if num_workers is None:
38
+ num_workers = os.cpu_count()
39
+
40
+ if num_workers is None:
41
+ num_workers = 32
42
+
43
+ num_workers = min(num_workers, 64)
44
+
45
+ if use_speech_reverb:
46
+ train_cls = DivideAndRemasterRandomChunkDatasetWithSpeechReverb
47
+ else:
48
+ train_cls = DivideAndRemasterRandomChunkDataset
49
+
50
+ train_dataset = train_cls(
51
+ data_root, "train", **train_kwargs
52
+ )
53
+
54
+ # if augmentor is not None:
55
+ # train_dataset = AugmentedDataset(train_dataset, augmentor)
56
+
57
+ datamodule = pl.LightningDataModule.from_datasets(
58
+ train_dataset=train_dataset,
59
+ val_dataset=DivideAndRemasterDeterministicChunkDataset(
60
+ data_root, "val", **val_kwargs
61
+ ),
62
+ test_dataset=DivideAndRemasterDataset(
63
+ data_root,
64
+ "test",
65
+ **test_kwargs
66
+ ),
67
+ batch_size=batch_size,
68
+ num_workers=num_workers,
69
+ **datamodule_kwargs
70
+ )
71
+
72
+ datamodule.predict_dataloader = datamodule.test_dataloader # type: ignore[method-assign]
73
+
74
+ return datamodule
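A hedged construction example for the factory above: the keyword arguments mirror the dataset constructors in `dataset.py`, the path is a placeholder, and it assumes the DnR v2 data has already been preprocessed into the expected `tr`/`cv`/`tt` layout.

```python
# Sketch of building the DnR datamodule; values are placeholders.
from models.bandit.core.data.dnr.datamodule import DivideAndRemasterDataModule

datamodule = DivideAndRemasterDataModule(
    data_root="/data/DnR/v2np",  # placeholder (default is "$DATA_ROOT/DnR/v2")
    batch_size=2,
    num_workers=8,
    # forwarded to DivideAndRemasterRandomChunkDataset
    train_kwargs={"target_length": 20000, "chunk_size_second": 6.0},
    # forwarded to DivideAndRemasterDeterministicChunkDataset
    val_kwargs={"chunk_size_second": 6.0, "hop_size_second": 6.0},
    use_speech_reverb=False,
)
```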
models/bandit/core/data/dnr/dataset.py ADDED
@@ -0,0 +1,392 @@
1
+ import os
2
+ from abc import ABC
3
+ from typing import Any, Dict, List, Optional
4
+
5
+ import numpy as np
6
+ import pedalboard as pb
7
+ import torch
8
+ import torchaudio as ta
9
+ from torch.utils import data
10
+
11
+ from models.bandit.core.data._types import AudioDict, DataDict
12
+ from models.bandit.core.data.base import BaseSourceSeparationDataset
13
+
14
+
15
+ class DivideAndRemasterBaseDataset(BaseSourceSeparationDataset, ABC):
16
+ ALLOWED_STEMS = ["mixture", "speech", "music", "effects", "mne"]
17
+ STEM_NAME_MAP = {
18
+ "mixture": "mix",
19
+ "speech": "speech",
20
+ "music": "music",
21
+ "effects": "sfx",
22
+ }
23
+ SPLIT_NAME_MAP = {"train": "tr", "val": "cv", "test": "tt"}
24
+
25
+ FULL_TRACK_LENGTH_SECOND = 60
26
+ FULL_TRACK_LENGTH_SAMPLES = FULL_TRACK_LENGTH_SECOND * 44100
27
+
28
+ def __init__(
29
+ self,
30
+ split: str,
31
+ stems: List[str],
32
+ files: List[str],
33
+ data_path: str,
34
+ fs: int = 44100,
35
+ npy_memmap: bool = True,
36
+ recompute_mixture: bool = False,
37
+ ) -> None:
38
+ super().__init__(
39
+ split=split,
40
+ stems=stems,
41
+ files=files,
42
+ data_path=data_path,
43
+ fs=fs,
44
+ npy_memmap=npy_memmap,
45
+ recompute_mixture=recompute_mixture
46
+ )
47
+
48
+ def get_stem(
49
+ self,
50
+ *,
51
+ stem: str,
52
+ identifier: Dict[str, Any]
53
+ ) -> torch.Tensor:
54
+
55
+ if stem == "mne":
56
+ return self.get_stem(
57
+ stem="music",
58
+ identifier=identifier) + self.get_stem(
59
+ stem="effects",
60
+ identifier=identifier)
61
+
62
+ track = identifier["track"]
63
+ path = os.path.join(self.data_path, track)
64
+
65
+ if self.npy_memmap:
66
+ audio = np.load(
67
+ os.path.join(path, f"{self.STEM_NAME_MAP[stem]}.npy"),
68
+ mmap_mode="r"
69
+ )
70
+ else:
71
+ # noinspection PyUnresolvedReferences
72
+ audio, _ = ta.load(
73
+ os.path.join(path, f"{self.STEM_NAME_MAP[stem]}.wav")
74
+ )
75
+
76
+ return audio
77
+
78
+ def get_identifier(self, index):
79
+ return dict(track=self.files[index])
80
+
81
+ def __getitem__(self, index: int) -> DataDict:
82
+ identifier = self.get_identifier(index)
83
+ audio = self.get_audio(identifier)
84
+
85
+ return {"audio": audio, "track": f"{self.split}/{identifier['track']}"}
86
+
87
+
88
+ class DivideAndRemasterDataset(DivideAndRemasterBaseDataset):
89
+ def __init__(
90
+ self,
91
+ data_root: str,
92
+ split: str,
93
+ stems: Optional[List[str]] = None,
94
+ fs: int = 44100,
95
+ npy_memmap: bool = True,
96
+ ) -> None:
97
+
98
+ if stems is None:
99
+ stems = self.ALLOWED_STEMS
100
+ self.stems = stems
101
+
102
+ data_path = os.path.join(data_root, self.SPLIT_NAME_MAP[split])
103
+
104
+ files = sorted(os.listdir(data_path))
105
+ files = [
106
+ f
107
+ for f in files
108
+ if (not f.startswith(".")) and os.path.isdir(
109
+ os.path.join(data_path, f)
110
+ )
111
+ ]
112
+ # pprint(list(enumerate(files)))
113
+ if split == "train":
114
+ assert len(files) == 3406, len(files)
115
+ elif split == "val":
116
+ assert len(files) == 487, len(files)
117
+ elif split == "test":
118
+ assert len(files) == 973, len(files)
119
+
120
+ self.n_tracks = len(files)
121
+
122
+ super().__init__(
123
+ data_path=data_path,
124
+ split=split,
125
+ stems=stems,
126
+ files=files,
127
+ fs=fs,
128
+ npy_memmap=npy_memmap,
129
+ )
130
+
131
+ def __len__(self) -> int:
132
+ return self.n_tracks
133
+
134
+
135
+ class DivideAndRemasterRandomChunkDataset(DivideAndRemasterBaseDataset):
136
+ def __init__(
137
+ self,
138
+ data_root: str,
139
+ split: str,
140
+ target_length: int,
141
+ chunk_size_second: float,
142
+ stems: Optional[List[str]] = None,
143
+ fs: int = 44100,
144
+ npy_memmap: bool = True,
145
+ ) -> None:
146
+
147
+ if stems is None:
148
+ stems = self.ALLOWED_STEMS
149
+ self.stems = stems
150
+
151
+ data_path = os.path.join(data_root, self.SPLIT_NAME_MAP[split])
152
+
153
+ files = sorted(os.listdir(data_path))
154
+ files = [
155
+ f
156
+ for f in files
157
+ if (not f.startswith(".")) and os.path.isdir(
158
+ os.path.join(data_path, f)
159
+ )
160
+ ]
161
+
162
+ if split == "train":
163
+ assert len(files) == 3406, len(files)
164
+ elif split == "val":
165
+ assert len(files) == 487, len(files)
166
+ elif split == "test":
167
+ assert len(files) == 973, len(files)
168
+
169
+ self.n_tracks = len(files)
170
+
171
+ self.target_length = target_length
172
+ self.chunk_size = int(chunk_size_second * fs)
173
+
174
+ super().__init__(
175
+ data_path=data_path,
176
+ split=split,
177
+ stems=stems,
178
+ files=files,
179
+ fs=fs,
180
+ npy_memmap=npy_memmap,
181
+ )
182
+
183
+ def __len__(self) -> int:
184
+ return self.target_length
185
+
186
+ def get_identifier(self, index):
187
+ return super().get_identifier(index % self.n_tracks)
188
+
189
+ def get_stem(
190
+ self,
191
+ *,
192
+ stem: str,
193
+ identifier: Dict[str, Any],
194
+ chunk_here: bool = False,
195
+ ) -> torch.Tensor:
196
+
197
+ stem = super().get_stem(
198
+ stem=stem,
199
+ identifier=identifier
200
+ )
201
+
202
+ if chunk_here:
203
+ start = np.random.randint(
204
+ 0,
205
+ self.FULL_TRACK_LENGTH_SAMPLES - self.chunk_size
206
+ )
207
+ end = start + self.chunk_size
208
+
209
+ stem = stem[:, start:end]
210
+
211
+ return stem
212
+
213
+ def __getitem__(self, index: int) -> DataDict:
214
+ identifier = self.get_identifier(index)
215
+ # self.index_lock = index
216
+ audio = self.get_audio(identifier)
217
+ # self.index_lock = None
218
+
219
+ start = np.random.randint(
220
+ 0,
221
+ self.FULL_TRACK_LENGTH_SAMPLES - self.chunk_size
222
+ )
223
+ end = start + self.chunk_size
224
+
225
+ audio = {
226
+ k: v[:, start:end] for k, v in audio.items()
227
+ }
228
+
229
+ return {"audio": audio, "track": f"{self.split}/{identifier['track']}"}
230
+
231
+
232
+ class DivideAndRemasterDeterministicChunkDataset(DivideAndRemasterBaseDataset):
233
+ def __init__(
234
+ self,
235
+ data_root: str,
236
+ split: str,
237
+ chunk_size_second: float,
238
+ hop_size_second: float,
239
+ stems: Optional[List[str]] = None,
240
+ fs: int = 44100,
241
+ npy_memmap: bool = True,
242
+ ) -> None:
243
+
244
+ if stems is None:
245
+ stems = self.ALLOWED_STEMS
246
+ self.stems = stems
247
+
248
+ data_path = os.path.join(data_root, self.SPLIT_NAME_MAP[split])
249
+
250
+ files = sorted(os.listdir(data_path))
251
+ files = [
252
+ f
253
+ for f in files
254
+ if (not f.startswith(".")) and os.path.isdir(
255
+ os.path.join(data_path, f)
256
+ )
257
+ ]
258
+ # pprint(list(enumerate(files)))
259
+ if split == "train":
260
+ assert len(files) == 3406, len(files)
261
+ elif split == "val":
262
+ assert len(files) == 487, len(files)
263
+ elif split == "test":
264
+ assert len(files) == 973, len(files)
265
+
266
+ self.n_tracks = len(files)
267
+
268
+ self.chunk_size = int(chunk_size_second * fs)
269
+ self.hop_size = int(hop_size_second * fs)
270
+ self.n_chunks_per_track = int(
271
+ (
272
+ self.FULL_TRACK_LENGTH_SECOND - chunk_size_second) / hop_size_second
273
+ )
274
+
275
+ self.length = self.n_tracks * self.n_chunks_per_track
276
+
277
+ super().__init__(
278
+ data_path=data_path,
279
+ split=split,
280
+ stems=stems,
281
+ files=files,
282
+ fs=fs,
283
+ npy_memmap=npy_memmap,
284
+ )
285
+
286
+ def get_identifier(self, index):
287
+ return super().get_identifier(index % self.n_tracks)
288
+
289
+ def __len__(self) -> int:
290
+ return self.length
291
+
292
+ def __getitem__(self, item: int) -> DataDict:
293
+
294
+ index = item % self.n_tracks
295
+ chunk = item // self.n_tracks
296
+
297
+ data_ = super().__getitem__(index)
298
+
299
+ audio = data_["audio"]
300
+
301
+ start = chunk * self.hop_size
302
+ end = start + self.chunk_size
303
+
304
+ for stem in self.stems:
305
+ data_["audio"][stem] = audio[stem][:, start:end]
306
+
307
+ return data_
308
+
309
+
310
+ class DivideAndRemasterRandomChunkDatasetWithSpeechReverb(
311
+ DivideAndRemasterRandomChunkDataset
312
+ ):
313
+ def __init__(
314
+ self,
315
+ data_root: str,
316
+ split: str,
317
+ target_length: int,
318
+ chunk_size_second: float,
319
+ stems: Optional[List[str]] = None,
320
+ fs: int = 44100,
321
+ npy_memmap: bool = True,
322
+ ) -> None:
323
+
324
+ if stems is None:
325
+ stems = self.ALLOWED_STEMS
326
+
327
+ stems_no_mixture = [s for s in stems if s != "mixture"]
328
+
329
+ super().__init__(
330
+ data_root=data_root,
331
+ split=split,
332
+ target_length=target_length,
333
+ chunk_size_second=chunk_size_second,
334
+ stems=stems_no_mixture,
335
+ fs=fs,
336
+ npy_memmap=npy_memmap,
337
+ )
338
+
339
+ self.stems = stems
340
+ self.stems_no_mixture = stems_no_mixture
341
+
342
+ def __getitem__(self, index: int) -> DataDict:
343
+
344
+ data_ = super().__getitem__(index)
345
+
346
+ dry = data_["audio"]["speech"][:]
347
+ n_samples = dry.shape[-1]
348
+
349
+ wet_level = np.random.rand()
350
+
351
+ speech = pb.Reverb(
352
+ room_size=np.random.rand(),
353
+ damping=np.random.rand(),
354
+ wet_level=wet_level,
355
+ dry_level=(1 - wet_level),
356
+ width=np.random.rand()
357
+ ).process(dry, self.fs, buffer_size=8192 * 4)[..., :n_samples]
358
+
359
+ data_["audio"]["speech"] = speech
360
+
361
+ data_["audio"]["mixture"] = sum(
362
+ [data_["audio"][s] for s in self.stems_no_mixture]
363
+ )
364
+
365
+ return data_
366
+
367
+ def __len__(self) -> int:
368
+ return super().__len__()
369
+
370
+
371
+ if __name__ == "__main__":
372
+
373
+ from pprint import pprint
374
+ from tqdm import tqdm
375
+
376
+ for split_ in ["train", "val", "test"]:
377
+ ds = DivideAndRemasterRandomChunkDatasetWithSpeechReverb(
378
+ data_root="$DATA_ROOT/DnR/v2np",
379
+ split=split_,
380
+ target_length=100,
381
+ chunk_size_second=6.0
382
+ )
383
+
384
+ print(split_, len(ds))
385
+
386
+ for track_ in tqdm(ds): # type: ignore
387
+ pprint(track_)
388
+ track_["audio"] = {k: v.shape for k, v in track_["audio"].items()}
389
+ pprint(track_)
390
+ # break
391
+
392
+ break
models/bandit/core/data/dnr/preprocess.py ADDED
@@ -0,0 +1,54 @@
1
+ import glob
2
+ import os
3
+ from typing import Tuple
4
+
5
+ import numpy as np
6
+ import torchaudio as ta
7
+ from tqdm.contrib.concurrent import process_map
8
+
9
+
10
+ def process_one(inputs: Tuple[str, str, int]) -> None:
11
+ infile, outfile, target_fs = inputs
12
+
13
+ dir = os.path.dirname(outfile)
14
+ os.makedirs(dir, exist_ok=True)
15
+
16
+ data, fs = ta.load(infile)
17
+
18
+ if fs != target_fs:
19
+ data = ta.functional.resample(data, fs, target_fs, resampling_method="sinc_interp_kaiser")
20
+ fs = target_fs
21
+
22
+ data = data.numpy()
23
+ data = data.astype(np.float32)
24
+
25
+ if os.path.exists(outfile):
26
+ data_ = np.load(outfile)
27
+ if np.allclose(data, data_):
28
+ return
29
+
30
+ np.save(outfile, data)
31
+
32
+
33
+ def preprocess(
34
+ data_path: str,
35
+ output_path: str,
36
+ fs: int
37
+ ) -> None:
38
+ files = glob.glob(os.path.join(data_path, "**", "*.wav"), recursive=True)
39
+ print(files)
40
+ outfiles = [
41
+ f.replace(data_path, output_path).replace(".wav", ".npy") for f in
42
+ files
43
+ ]
44
+
45
+ os.makedirs(output_path, exist_ok=True)
46
+ inputs = list(zip(files, outfiles, [fs] * len(files)))
47
+
48
+ process_map(process_one, inputs, chunksize=32)
49
+
50
+
51
+ if __name__ == "__main__":
52
+ import fire
53
+
54
+ fire.Fire()
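Since the module exposes its functions through `fire.Fire()`, preprocessing can be run from the CLI (e.g. `python preprocess.py preprocess --data_path ... --output_path ... --fs 44100`) or called directly. A hedged example with placeholder paths:

```python
# Converts every .wav under data_path into a float32 .npy mirror under output_path,
# resampling to `fs` where needed. Paths are placeholders.
from models.bandit.core.data.dnr.preprocess import preprocess

preprocess(
    data_path="/data/DnR/v2",      # placeholder: source .wav tree
    output_path="/data/DnR/v2np",  # placeholder: destination .npy tree
    fs=44100,
)
```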
models/bandit/core/data/musdb/__init__.py ADDED
File without changes
models/bandit/core/data/musdb/datamodule.py ADDED
@@ -0,0 +1,77 @@
1
+ import os.path
2
+ from typing import Mapping, Optional
3
+
4
+ import pytorch_lightning as pl
5
+
6
+ from models.bandit.core.data.musdb.dataset import (
7
+ MUSDB18BaseDataset,
8
+ MUSDB18FullTrackDataset,
9
+ MUSDB18SadDataset,
10
+ MUSDB18SadOnTheFlyAugmentedDataset
11
+ )
12
+
13
+
14
+ def MUSDB18DataModule(
15
+ data_root: str = "$DATA_ROOT/MUSDB18/HQ",
16
+ target_stem: str = "vocals",
17
+ batch_size: int = 2,
18
+ num_workers: int = 8,
19
+ train_kwargs: Optional[Mapping] = None,
20
+ val_kwargs: Optional[Mapping] = None,
21
+ test_kwargs: Optional[Mapping] = None,
22
+ datamodule_kwargs: Optional[Mapping] = None,
23
+ use_on_the_fly: bool = True,
24
+ npy_memmap: bool = True
25
+ ) -> pl.LightningDataModule:
26
+ if train_kwargs is None:
27
+ train_kwargs = {}
28
+
29
+ if val_kwargs is None:
30
+ val_kwargs = {}
31
+
32
+ if test_kwargs is None:
33
+ test_kwargs = {}
34
+
35
+ if datamodule_kwargs is None:
36
+ datamodule_kwargs = {}
37
+
38
+ train_dataset: MUSDB18BaseDataset
39
+
40
+ if use_on_the_fly:
41
+ train_dataset = MUSDB18SadOnTheFlyAugmentedDataset(
42
+ data_root=os.path.join(data_root, "saded-np"),
43
+ split="train",
44
+ target_stem=target_stem,
45
+ **train_kwargs
46
+ )
47
+ else:
48
+ train_dataset = MUSDB18SadDataset(
49
+ data_root=os.path.join(data_root, "saded-np"),
50
+ split="train",
51
+ target_stem=target_stem,
52
+ **train_kwargs
53
+ )
54
+
55
+ datamodule = pl.LightningDataModule.from_datasets(
56
+ train_dataset=train_dataset,
57
+ val_dataset=MUSDB18SadDataset(
58
+ data_root=os.path.join(data_root, "saded-np"),
59
+ split="val",
60
+ target_stem=target_stem,
61
+ **val_kwargs
62
+ ),
63
+ test_dataset=MUSDB18FullTrackDataset(
64
+ data_root=os.path.join(data_root, "canonical"),
65
+ split="test",
66
+ **test_kwargs
67
+ ),
68
+ batch_size=batch_size,
69
+ num_workers=num_workers,
70
+ **datamodule_kwargs
71
+ )
72
+
73
+ datamodule.predict_dataloader = ( # type: ignore[method-assign]
74
+ datamodule.test_dataloader
75
+ )
76
+
77
+ return datamodule
models/bandit/core/data/musdb/dataset.py ADDED
@@ -0,0 +1,280 @@
1
+ import os
2
+ from abc import ABC
3
+ from typing import List, Optional, Tuple
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torchaudio as ta
8
+ from torch.utils import data
9
+
10
+ from models.bandit.core.data._types import AudioDict, DataDict
11
+ from models.bandit.core.data.base import BaseSourceSeparationDataset
12
+
13
+
14
+ class MUSDB18BaseDataset(BaseSourceSeparationDataset, ABC):
15
+
16
+ ALLOWED_STEMS = ["mixture", "vocals", "bass", "drums", "other"]
17
+
18
+ def __init__(
19
+ self,
20
+ split: str,
21
+ stems: List[str],
22
+ files: List[str],
23
+ data_path: str,
24
+ fs: int = 44100,
25
+ npy_memmap=False,
26
+ ) -> None:
27
+ super().__init__(
28
+ split=split,
29
+ stems=stems,
30
+ files=files,
31
+ data_path=data_path,
32
+ fs=fs,
33
+ npy_memmap=npy_memmap,
34
+ recompute_mixture=False
35
+ )
36
+
37
+ def get_stem(self, *, stem: str, identifier) -> torch.Tensor:
38
+ track = identifier["track"]
39
+ path = os.path.join(self.data_path, track)
40
+ # noinspection PyUnresolvedReferences
41
+
42
+ if self.npy_memmap:
43
+ audio = np.load(os.path.join(path, f"{stem}.wav.npy"), mmap_mode="r")
44
+ else:
45
+ audio, _ = ta.load(os.path.join(path, f"{stem}.wav"))
46
+
47
+ return audio
48
+
49
+ def get_identifier(self, index):
50
+ return dict(track=self.files[index])
51
+
52
+ def __getitem__(self, index: int) -> DataDict:
53
+ identifier = self.get_identifier(index)
54
+ audio = self.get_audio(identifier)
55
+
56
+ return {"audio": audio, "track": f"{self.split}/{identifier['track']}"}
57
+
58
+
59
+ class MUSDB18FullTrackDataset(MUSDB18BaseDataset):
60
+
61
+ N_TRAIN_TRACKS = 100
62
+ N_TEST_TRACKS = 50
63
+ VALIDATION_FILES = [
64
+ "Actions - One Minute Smile",
65
+ "Clara Berry And Wooldog - Waltz For My Victims",
66
+ "Johnny Lokke - Promises & Lies",
67
+ "Patrick Talbot - A Reason To Leave",
68
+ "Triviul - Angelsaint",
69
+ "Alexander Ross - Goodbye Bolero",
70
+ "Fergessen - Nos Palpitants",
71
+ "Leaf - Summerghost",
72
+ "Skelpolu - Human Mistakes",
73
+ "Young Griffo - Pennies",
74
+ "ANiMAL - Rockshow",
75
+ "James May - On The Line",
76
+ "Meaxic - Take A Step",
77
+ "Traffic Experiment - Sirens",
78
+ ]
79
+
80
+ def __init__(
81
+ self, data_root: str, split: str, stems: Optional[List[str]] = None
83
+ ) -> None:
84
+
85
+ if stems is None:
86
+ stems = self.ALLOWED_STEMS
87
+ self.stems = stems
88
+
89
+ if split == "test":
90
+ subset = "test"
91
+ elif split in ["train", "val"]:
92
+ subset = "train"
93
+ else:
94
+ raise NameError(f"Unknown split: {split}")
95
+
96
+ data_path = os.path.join(data_root, subset)
97
+
98
+ files = sorted(os.listdir(data_path))
99
+ files = [f for f in files if not f.startswith(".")]
100
+ # pprint(list(enumerate(files)))
101
+ if subset == "train":
102
+ assert len(files) == 100, len(files)
103
+ if split == "train":
104
+ files = [f for f in files if f not in self.VALIDATION_FILES]
105
+ assert len(files) == 100 - len(self.VALIDATION_FILES)
106
+ else:
107
+ files = [f for f in files if f in self.VALIDATION_FILES]
108
+ assert len(files) == len(self.VALIDATION_FILES)
109
+ else:
110
+ split = "test"
111
+ assert len(files) == 50
112
+
113
+ self.n_tracks = len(files)
114
+
115
+ super().__init__(
116
+ data_path=data_path,
117
+ split=split,
118
+ stems=stems,
119
+ files=files
120
+ )
121
+
122
+ def __len__(self) -> int:
123
+ return self.n_tracks
124
+
125
+ class MUSDB18SadDataset(MUSDB18BaseDataset):
126
+ def __init__(
127
+ self,
128
+ data_root: str,
129
+ split: str,
130
+ target_stem: str,
131
+ stems: Optional[List[str]] = None,
132
+ target_length: Optional[int] = None,
133
+ npy_memmap=False,
134
+ ) -> None:
135
+
136
+ if stems is None:
137
+ stems = self.ALLOWED_STEMS
138
+
139
+ data_path = os.path.join(data_root, target_stem, split)
140
+
141
+ files = sorted(os.listdir(data_path))
142
+ files = [f for f in files if not f.startswith(".")]
143
+
144
+ super().__init__(
145
+ data_path=data_path,
146
+ split=split,
147
+ stems=stems,
148
+ files=files,
149
+ npy_memmap=npy_memmap
150
+ )
151
+ self.n_segments = len(files)
152
+ self.target_stem = target_stem
153
+ self.target_length = (
154
+ target_length if target_length is not None else self.n_segments
155
+ )
156
+
157
+ def __len__(self) -> int:
158
+ return self.target_length
159
+
160
+ def __getitem__(self, index: int) -> DataDict:
161
+
162
+ index = index % self.n_segments
163
+
164
+ return super().__getitem__(index)
165
+
166
+ def get_identifier(self, index):
167
+ return super().get_identifier(index % self.n_segments)
168
+
169
+
170
+ class MUSDB18SadOnTheFlyAugmentedDataset(MUSDB18SadDataset):
171
+ def __init__(
172
+ self,
173
+ data_root: str,
174
+ split: str,
175
+ target_stem: str,
176
+ stems: Optional[List[str]] = None,
177
+ target_length: int = 20000,
178
+ apply_probability: Optional[float] = None,
179
+ chunk_size_second: float = 3.0,
180
+ random_scale_range_db: Tuple[float, float] = (-10, 10),
181
+ drop_probability: float = 0.1,
182
+ rescale: bool = True,
183
+ ) -> None:
184
+ super().__init__(data_root, split, target_stem, stems)
185
+
186
+ if apply_probability is None:
187
+ apply_probability = (
188
+ target_length - self.n_segments) / target_length
189
+
190
+ self.apply_probability = apply_probability
191
+ self.drop_probability = drop_probability
192
+ self.chunk_size_second = chunk_size_second
193
+ self.random_scale_range_db = random_scale_range_db
194
+ self.rescale = rescale
195
+
196
+ self.chunk_size_sample = int(self.chunk_size_second * self.fs)
197
+ self.target_length = target_length
198
+
199
+ def __len__(self) -> int:
200
+ return self.target_length
201
+
202
+ def __getitem__(self, index: int) -> DataDict:
203
+
204
+ index = index % self.n_segments
205
+
206
+ # if np.random.rand() > self.apply_probability:
207
+ # return super().__getitem__(index)
208
+
209
+ audio = {}
210
+ identifier = self.get_identifier(index)
211
+
212
+ # assert self.target_stem in self.stems_no_mixture
213
+ for stem in self.stems_no_mixture:
214
+ if stem == self.target_stem:
215
+ identifier_ = identifier
216
+ else:
217
+ if np.random.rand() < self.apply_probability:
218
+ index_ = np.random.randint(self.n_segments)
219
+ identifier_ = self.get_identifier(index_)
220
+ else:
221
+ identifier_ = identifier
222
+
223
+ audio[stem] = self.get_stem(stem=stem, identifier=identifier_)
224
+
225
+ # if stem == self.target_stem:
226
+
227
+ if self.chunk_size_sample < audio[stem].shape[-1]:
228
+ chunk_start = np.random.randint(
229
+ audio[stem].shape[-1] - self.chunk_size_sample
230
+ )
231
+ else:
232
+ chunk_start = 0
233
+
234
+ if np.random.rand() < self.drop_probability:
235
+ # db_scale = "-inf"
236
+ linear_scale = 0.0
237
+ else:
238
+ db_scale = np.random.uniform(*self.random_scale_range_db)
239
+ linear_scale = np.power(10, db_scale / 20)
240
+ # db_scale = f"{db_scale:+2.1f}"
241
+ # print(linear_scale)
242
+ chunk_slice = slice(chunk_start, chunk_start + self.chunk_size_sample)
+ audio[stem][..., chunk_slice] = linear_scale * audio[stem][..., chunk_slice]
248
+
249
+ audio["mixture"] = self.compute_mixture(audio)
250
+
251
+ if self.rescale:
252
+ max_abs_val = max(
253
+ [torch.max(torch.abs(audio[stem])) for stem in self.stems]
254
+ ) # type: ignore[type-var]
255
+ if max_abs_val > 1:
256
+ audio = {k: v / max_abs_val for k, v in audio.items()}
257
+
258
+ track = identifier["track"]
259
+
260
+ return {"audio": audio, "track": f"{self.split}/{track}"}
261
+
262
+ # if __name__ == "__main__":
263
+ #
264
+ # from pprint import pprint
265
+ # from tqdm import tqdm
266
+ #
267
+ # for split_ in ["train", "val", "test"]:
268
+ # ds = MUSDB18SadOnTheFlyAugmentedDataset(
269
+ # data_root="$DATA_ROOT/MUSDB18/HQ/saded",
270
+ # split=split_,
271
+ # target_stem="vocals"
272
+ # )
273
+ #
274
+ # print(split_, len(ds))
275
+ #
276
+ # for track_ in tqdm(ds):
277
+ # track_["audio"] = {
278
+ # k: v.shape for k, v in track_["audio"].items()
279
+ # }
280
+ # pprint(track_)
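As a quick illustration of the on-the-fly augmentation above: the per-stem gain is drawn in decibels and applied as a linear amplitude factor, and the result is peak-rescaled only if it would clip. A standalone sketch of those two steps (all values are arbitrary placeholders):

import numpy as np
import torch

db_scale = np.random.uniform(-10, 10)   # default random_scale_range_db
linear_scale = 10 ** (db_scale / 20)    # e.g. +6 dB is roughly a 2x amplitude gain

chunk = torch.randn(2, 3 * 44100)       # a 3-second stereo chunk (placeholder signal)
chunk = linear_scale * chunk

max_abs_val = chunk.abs().max()         # rescale=True: renormalise only when clipping
if max_abs_val > 1:
    chunk = chunk / max_abs_val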
models/bandit/core/data/musdb/preprocess.py ADDED
@@ -0,0 +1,238 @@
1
+ import glob
2
+ import os
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torchaudio as ta
7
+ from torch import nn
8
+ from torch.nn import functional as F
9
+ from tqdm import tqdm
+ from tqdm.contrib.concurrent import process_map
10
+
11
+ from models.bandit.core.data._types import DataDict
12
+ from models.bandit.core.data.musdb.dataset import MUSDB18FullTrackDataset
13
+ import pyloudnorm as pyln
14
+
15
+ class SourceActivityDetector(nn.Module):
16
+ def __init__(
17
+ self,
18
+ analysis_stem: str,
19
+ output_path: str,
20
+ fs: int = 44100,
21
+ segment_length_second: float = 6.0,
22
+ hop_length_second: float = 3.0,
23
+ n_chunks: int = 10,
24
+ chunk_epsilon: float = 1e-5,
25
+ energy_threshold_quantile: float = 0.15,
26
+ segment_epsilon: float = 1e-3,
27
+ salient_proportion_threshold: float = 0.5,
28
+ target_lufs: float = -24
29
+ ) -> None:
30
+ super().__init__()
31
+
32
+ self.fs = fs
33
+ self.segment_length = int(segment_length_second * self.fs)
34
+ self.hop_length = int(hop_length_second * self.fs)
35
+ self.n_chunks = n_chunks
36
+ assert self.segment_length % self.n_chunks == 0
37
+ self.chunk_size = self.segment_length // self.n_chunks
38
+ self.chunk_epsilon = chunk_epsilon
39
+ self.energy_threshold_quantile = energy_threshold_quantile
40
+ self.segment_epsilon = segment_epsilon
41
+ self.salient_proportion_threshold = salient_proportion_threshold
42
+ self.analysis_stem = analysis_stem
43
+
44
+ self.meter = pyln.Meter(self.fs)
45
+ self.target_lufs = target_lufs
46
+
47
+ self.output_path = output_path
48
+
49
+ def forward(self, data: DataDict) -> None:
50
+
51
+ stem_ = self.analysis_stem if (
52
+ self.analysis_stem != "none") else "mixture"
53
+
54
+ x = data["audio"][stem_]
55
+
56
+ xnp = x.numpy()
57
+ loudness = self.meter.integrated_loudness(xnp.T)
58
+
59
+ for stem in data["audio"]:
60
+ s = data["audio"][stem]
61
+ s = pyln.normalize.loudness(s.numpy().T, loudness, self.target_lufs).T
62
+ s = torch.as_tensor(s)
63
+ data["audio"][stem] = s
64
+
65
+ if x.ndim == 3:
66
+ assert x.shape[0] == 1
67
+ x = x[0]
68
+
69
+ n_chan, n_samples = x.shape
70
+
71
+ n_segments = (
72
+ int(
73
+ np.ceil((n_samples - self.segment_length) / self.hop_length)
74
+ ) + 1
75
+ )
76
+
77
+ segments = torch.zeros((n_segments, n_chan, self.segment_length))
78
+ for i in range(n_segments):
79
+ start = i * self.hop_length
80
+ end = start + self.segment_length
81
+ end = min(end, n_samples)
82
+
83
+ xseg = x[:, start:end]
84
+
85
+ if end - start < self.segment_length:
86
+ xseg = F.pad(
87
+ xseg,
88
+ pad=(0, self.segment_length - (end - start)),
89
+ value=torch.nan
90
+ )
91
+
92
+ segments[i, :, :] = xseg
93
+
94
+ chunks = segments.reshape(
95
+ (n_segments, n_chan, self.n_chunks, self.chunk_size)
96
+ )
97
+
98
+ if self.analysis_stem != "none":
99
+ chunk_energies = torch.mean(torch.square(chunks), dim=(1, 3))
100
+ chunk_energies = torch.nan_to_num(chunk_energies, nan=0)
101
+ chunk_energies[chunk_energies == 0] = self.chunk_epsilon
102
+
103
+ energy_threshold = torch.nanquantile(
104
+ chunk_energies, q=self.energy_threshold_quantile
105
+ )
106
+
107
+ if energy_threshold < self.segment_epsilon:
108
+ energy_threshold = self.segment_epsilon # type: ignore[assignment]
109
+
110
+ chunks_above_threshold = chunk_energies > energy_threshold
111
+ n_chunks_above_threshold = torch.mean(
112
+ chunks_above_threshold.to(torch.float), dim=-1
113
+ )
114
+
115
+ segment_above_threshold = (
116
+ n_chunks_above_threshold > self.salient_proportion_threshold
117
+ )
118
+
119
+ if torch.sum(segment_above_threshold) == 0:
120
+ return
121
+
122
+ else:
123
+ segment_above_threshold = torch.ones((n_segments,))
124
+
125
+ for i in range(n_segments):
126
+ if not segment_above_threshold[i]:
127
+ continue
128
+
129
+ outpath = os.path.join(
130
+ self.output_path,
131
+ self.analysis_stem,
132
+ f"{data['track']} - {self.analysis_stem}{i:03d}",
133
+ )
134
+ os.makedirs(outpath, exist_ok=True)
135
+
136
+ for stem in data["audio"]:
137
+ if stem == self.analysis_stem:
138
+ segment = torch.nan_to_num(segments[i, :, :], nan=0)
139
+ else:
140
+ start = i * self.hop_length
141
+ end = start + self.segment_length
142
+ end = min(n_samples, end)
143
+
144
+ segment = data["audio"][stem][:, start:end]
145
+
146
+ if end - start < self.segment_length:
147
+ segment = F.pad(
148
+ segment,
149
+ (0, self.segment_length - (end - start))
150
+ )
151
+
152
+ assert segment.shape[-1] == self.segment_length, segment.shape
153
+
154
+ # ta.save(os.path.join(outpath, f"{stem}.wav"), segment, self.fs)
155
+
156
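+ # np.save appends ".npy", so this writes "{stem}.wav.npy", matching the files read by MUSDB18SadDataset when npy_memmap=True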
+ np.save(os.path.join(outpath, f"{stem}.wav"), segment)
157
+
158
+
159
+ def preprocess(
160
+ analysis_stem: str,
161
+ output_path: str = "/data/MUSDB18/HQ/saded-np",
162
+ fs: int = 44100,
163
+ segment_length_second: float = 6.0,
164
+ hop_length_second: float = 3.0,
165
+ n_chunks: int = 10,
166
+ chunk_epsilon: float = 1e-5,
167
+ energy_threshold_quantile: float = 0.15,
168
+ segment_epsilon: float = 1e-3,
169
+ salient_proportion_threshold: float = 0.5,
170
+ ) -> None:
171
+
172
+ sad = SourceActivityDetector(
173
+ analysis_stem=analysis_stem,
174
+ output_path=output_path,
175
+ fs=fs,
176
+ segment_length_second=segment_length_second,
177
+ hop_length_second=hop_length_second,
178
+ n_chunks=n_chunks,
179
+ chunk_epsilon=chunk_epsilon,
180
+ energy_threshold_quantile=energy_threshold_quantile,
181
+ segment_epsilon=segment_epsilon,
182
+ salient_proportion_threshold=salient_proportion_threshold,
183
+ )
184
+
185
+ for split in ["train", "val", "test"]:
186
+ ds = MUSDB18FullTrackDataset(
187
+ data_root="/data/MUSDB18/HQ/canonical",
188
+ split=split,
189
+ )
190
+
191
+ tracks = []
192
+ for i, track in enumerate(tqdm(ds, total=len(ds))):
193
+ if i % 32 == 0 and tracks:
194
+ process_map(sad, tracks, max_workers=8)
195
+ tracks = []
196
+ tracks.append(track)
197
+ process_map(sad, tracks, max_workers=8)
198
+
199
+ def loudness_norm_one(
200
+ inputs
201
+ ):
202
+ infile, outfile, target_lufs = inputs
203
+
204
+ audio, fs = ta.load(infile)
205
+ audio = audio.mean(dim=0, keepdim=True).numpy().T
206
+
207
+ meter = pyln.Meter(fs)
208
+ loudness = meter.integrated_loudness(audio)
209
+ audio = pyln.normalize.loudness(audio, loudness, target_lufs)
210
+
211
+ os.makedirs(os.path.dirname(outfile), exist_ok=True)
212
+ np.save(outfile, audio.T)
213
+
214
+ def loudness_norm(
215
+ data_path: str,
216
+ # output_path: str,
217
+ target_lufs = -17.0,
218
+ ):
219
+ files = glob.glob(
220
+ os.path.join(data_path, "**", "*.wav"), recursive=True
221
+ )
222
+
223
+ outfiles = [
224
+ f.replace(".wav", ".npy").replace("saded", "saded-np") for f in files
225
+ ]
226
+
227
+ files = [(f, o, target_lufs) for f, o in zip(files, outfiles)]
228
+
229
+ process_map(loudness_norm_one, files, chunksize=2)
230
+
231
+
232
+
233
+ if __name__ == "__main__":
234
+
235
+ from tqdm import tqdm
236
+ import fire
237
+
238
+ fire.Fire()
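Since the script exposes its functions through fire.Fire(), the two stages can also be driven programmatically; the stem and paths below are placeholders that mirror the defaults hard-coded above, shown only as a sketch.

from models.bandit.core.data.musdb.preprocess import preprocess, loudness_norm

# Source-activity detection for one analysis stem; repeat per stem as needed.
preprocess(analysis_stem="vocals", output_path="/data/MUSDB18/HQ/saded-np")

# Optional loudness normalisation of a wav-based "saded" export into .npy files.
loudness_norm(data_path="/data/MUSDB18/HQ/saded", target_lufs=-17.0)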
models/bandit/core/data/musdb/validation.yaml ADDED
@@ -0,0 +1,15 @@
1
+ validation:
2
+ - 'Actions - One Minute Smile'
3
+ - 'Clara Berry And Wooldog - Waltz For My Victims'
4
+ - 'Johnny Lokke - Promises & Lies'
5
+ - 'Patrick Talbot - A Reason To Leave'
6
+ - 'Triviul - Angelsaint'
7
+ - 'Alexander Ross - Goodbye Bolero'
8
+ - 'Fergessen - Nos Palpitants'
9
+ - 'Leaf - Summerghost'
10
+ - 'Skelpolu - Human Mistakes'
11
+ - 'Young Griffo - Pennies'
12
+ - 'ANiMAL - Rockshow'
13
+ - 'James May - On The Line'
14
+ - 'Meaxic - Take A Step'
15
+ - 'Traffic Experiment - Sirens'
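This list mirrors the VALIDATION_FILES hard-coded in MUSDB18FullTrackDataset; a minimal sketch of reading it (the path is relative to the repository root and requires PyYAML):

import yaml

with open("models/bandit/core/data/musdb/validation.yaml") as f:
    validation_tracks = yaml.safe_load(f)["validation"]
print(len(validation_tracks))  # 14 held-out tracks from the MUSDB18 train subset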
models/bandit/core/loss/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from ._multistem import MultiStemWrapperFromConfig
2
+ from ._timefreq import ReImL1Loss, ReImL2Loss, TimeFreqL1Loss, TimeFreqL2Loss, TimeFreqSignalNoisePNormRatioLoss
models/bandit/core/loss/_complex.py ADDED
@@ -0,0 +1,34 @@
1
+ from typing import Any
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn.modules import loss as _loss
6
+ from torch.nn.modules.loss import _Loss
7
+
8
+
9
+ class ReImLossWrapper(_Loss):
10
+ def __init__(self, module: _Loss) -> None:
11
+ super().__init__()
12
+ self.module = module
13
+
14
+ def forward(
15
+ self,
16
+ preds: torch.Tensor,
17
+ target: torch.Tensor
18
+ ) -> torch.Tensor:
19
+ return self.module(
20
+ torch.view_as_real(preds),
21
+ torch.view_as_real(target)
22
+ )
23
+
24
+
25
+ class ReImL1Loss(ReImLossWrapper):
26
+ def __init__(self, **kwargs: Any) -> None:
27
+ l1_loss = _loss.L1Loss(**kwargs)
28
+ super().__init__(module=(l1_loss))
29
+
30
+
31
+ class ReImL2Loss(ReImLossWrapper):
32
+ def __init__(self, **kwargs: Any) -> None:
33
+ l2_loss = _loss.MSELoss(**kwargs)
34
+ super().__init__(module=(l2_loss))
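For illustration, these wrappers take complex tensors and compare their stacked real/imaginary parts; a minimal sketch with random spectrogram-shaped inputs (the shapes are placeholders):

import torch
from models.bandit.core.loss._complex import ReImL1Loss

loss_fn = ReImL1Loss()
preds = torch.randn(4, 2, 1025, 87, dtype=torch.complex64)   # (batch, chan, freq, time)
target = torch.randn(4, 2, 1025, 87, dtype=torch.complex64)
loss = loss_fn(preds, target)  # scalar L1 over view_as_real(preds) vs view_as_real(target)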
models/bandit/core/loss/_multistem.py ADDED
@@ -0,0 +1,45 @@
1
+ from typing import Any, Dict
2
+
3
+ import torch
4
+ from asteroid import losses as asteroid_losses
5
+ from torch import nn
6
+ from torch.nn.modules.loss import _Loss
7
+
8
+ from . import snr
9
+
10
+
11
+ def parse_loss(name: str, kwargs: Dict[str, Any]) -> _Loss:
12
+
13
+ for module in [nn.modules.loss, snr, asteroid_losses, asteroid_losses.sdr]:
14
+ if name in module.__dict__:
15
+ return module.__dict__[name](**kwargs)
16
+
17
+ raise NameError(f"Loss '{name}' not found in any known loss module")
18
+
19
+
20
+ class MultiStemWrapper(_Loss):
21
+ def __init__(self, module: _Loss, modality: str = "audio") -> None:
22
+ super().__init__()
23
+ self.loss = module
24
+ self.modality = modality
25
+
26
+ def forward(
27
+ self,
28
+ preds: Dict[str, Dict[str, torch.Tensor]],
29
+ target: Dict[str, Dict[str, torch.Tensor]],
30
+ ) -> torch.Tensor:
31
+ loss = {
32
+ stem: self.loss(
33
+ preds[self.modality][stem],
34
+ target[self.modality][stem]
35
+ )
36
+ for stem in preds[self.modality] if stem in target[self.modality]
37
+ }
38
+
39
+ return sum(list(loss.values()))
40
+
41
+
42
+ class MultiStemWrapperFromConfig(MultiStemWrapper):
43
+ def __init__(self, name: str, kwargs: Any, modality: str = "audio") -> None:
44
+ loss = parse_loss(name, kwargs)
45
+ super().__init__(module=loss, modality=modality)
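A sketch of how the config-driven wrapper resolves a loss by name and sums it over stems; the stem names and tensor shapes below are placeholders:

import torch
from models.bandit.core.loss._multistem import MultiStemWrapperFromConfig

# "L1Loss" is found in torch.nn.modules.loss by parse_loss above.
loss_fn = MultiStemWrapperFromConfig(name="L1Loss", kwargs={})

preds = {"audio": {"speech": torch.randn(2, 1, 44100), "music": torch.randn(2, 1, 44100)}}
target = {"audio": {"speech": torch.randn(2, 1, 44100), "music": torch.randn(2, 1, 44100)}}
loss = loss_fn(preds, target)  # sum of the per-stem L1 losses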
models/bandit/core/loss/_timefreq.py ADDED
@@ -0,0 +1,113 @@
1
+ from typing import Any, Dict, Optional
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn.modules.loss import _Loss
6
+
7
+ from models.bandit.core.loss._multistem import MultiStemWrapper
8
+ from models.bandit.core.loss._complex import ReImL1Loss, ReImL2Loss, ReImLossWrapper
9
+ from models.bandit.core.loss.snr import SignalNoisePNormRatio
10
+
11
+ class TimeFreqWrapper(_Loss):
12
+ def __init__(
13
+ self,
14
+ time_module: _Loss,
15
+ freq_module: Optional[_Loss] = None,
16
+ time_weight: float = 1.0,
17
+ freq_weight: float = 1.0,
18
+ multistem: bool = True,
19
+ ) -> None:
20
+ super().__init__()
21
+
22
+ if freq_module is None:
23
+ freq_module = time_module
24
+
25
+ if multistem:
26
+ time_module = MultiStemWrapper(time_module, modality="audio")
27
+ freq_module = MultiStemWrapper(freq_module, modality="spectrogram")
28
+
29
+ self.time_module = time_module
30
+ self.freq_module = freq_module
31
+
32
+ self.time_weight = time_weight
33
+ self.freq_weight = freq_weight
34
+
35
+ # TODO: add better type hints
36
+ def forward(self, preds: Any, target: Any) -> torch.Tensor:
37
+
38
+ return self.time_weight * self.time_module(
39
+ preds, target
40
+ ) + self.freq_weight * self.freq_module(preds, target)
41
+
42
+
43
+ class TimeFreqL1Loss(TimeFreqWrapper):
44
+ def __init__(
45
+ self,
46
+ time_weight: float = 1.0,
47
+ freq_weight: float = 1.0,
48
+ tkwargs: Optional[Dict[str, Any]] = None,
49
+ fkwargs: Optional[Dict[str, Any]] = None,
50
+ multistem: bool = True,
51
+ ) -> None:
52
+ if tkwargs is None:
53
+ tkwargs = {}
54
+ if fkwargs is None:
55
+ fkwargs = {}
56
+ time_module = (nn.L1Loss(**tkwargs))
57
+ freq_module = ReImL1Loss(**fkwargs)
58
+ super().__init__(
59
+ time_module,
60
+ freq_module,
61
+ time_weight,
62
+ freq_weight,
63
+ multistem
64
+ )
65
+
66
+
67
+ class TimeFreqL2Loss(TimeFreqWrapper):
68
+ def __init__(
69
+ self,
70
+ time_weight: float = 1.0,
71
+ freq_weight: float = 1.0,
72
+ tkwargs: Optional[Dict[str, Any]] = None,
73
+ fkwargs: Optional[Dict[str, Any]] = None,
74
+ multistem: bool = True,
75
+ ) -> None:
76
+ if tkwargs is None:
77
+ tkwargs = {}
78
+ if fkwargs is None:
79
+ fkwargs = {}
80
+ time_module = nn.MSELoss(**tkwargs)
81
+ freq_module = ReImL2Loss(**fkwargs)
82
+ super().__init__(
83
+ time_module,
84
+ freq_module,
85
+ time_weight,
86
+ freq_weight,
87
+ multistem
88
+ )
89
+
90
+
91
+
92
+ class TimeFreqSignalNoisePNormRatioLoss(TimeFreqWrapper):
93
+ def __init__(
94
+ self,
95
+ time_weight: float = 1.0,
96
+ freq_weight: float = 1.0,
97
+ tkwargs: Optional[Dict[str, Any]] = None,
98
+ fkwargs: Optional[Dict[str, Any]] = None,
99
+ multistem: bool = True,
100
+ ) -> None:
101
+ if tkwargs is None:
102
+ tkwargs = {}
103
+ if fkwargs is None:
104
+ fkwargs = {}
105
+ time_module = SignalNoisePNormRatio(**tkwargs)
106
+ freq_module = SignalNoisePNormRatio(**fkwargs)
107
+ super().__init__(
108
+ time_module,
109
+ freq_module,
110
+ time_weight,
111
+ freq_weight,
112
+ multistem
113
+ )
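A sketch of the nested input layout the combined time/frequency losses expect with multistem=True; the waveform and complex spectrogram entries are placeholders:

import torch
from models.bandit.core.loss._timefreq import TimeFreqL1Loss

loss_fn = TimeFreqL1Loss(time_weight=1.0, freq_weight=1.0)

def fake_output():
    return {
        "audio": {"speech": torch.randn(2, 1, 44100)},
        "spectrogram": {"speech": torch.randn(2, 1, 1025, 87, dtype=torch.complex64)},
    }

loss = loss_fn(fake_output(), fake_output())  # time-domain L1 + real/imag L1, weighted sum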
models/bandit/core/loss/snr.py ADDED
@@ -0,0 +1,146 @@
1
+ import torch
2
+ from torch.nn.modules.loss import _Loss
3
+ from torch.nn import functional as F
4
+
5
+ class SignalNoisePNormRatio(_Loss):
6
+ def __init__(
7
+ self,
8
+ p: float = 1.0,
9
+ scale_invariant: bool = False,
10
+ zero_mean: bool = False,
11
+ take_log: bool = True,
12
+ reduction: str = "mean",
13
+ EPS: float = 1e-3,
14
+ ) -> None:
15
+ assert reduction != "sum", NotImplementedError
16
+ super().__init__(reduction=reduction)
17
+ assert not zero_mean
18
+
19
+ self.p = p
20
+
21
+ self.EPS = EPS
22
+ self.take_log = take_log
23
+
24
+ self.scale_invariant = scale_invariant
25
+
26
+ def forward(
27
+ self,
28
+ est_target: torch.Tensor,
29
+ target: torch.Tensor
30
+ ) -> torch.Tensor:
31
+
32
+ target_ = target
33
+ if self.scale_invariant:
34
+ ndim = target.ndim
35
+ dot = torch.sum(est_target * torch.conj(target), dim=-1, keepdim=True)
36
+ s_target_energy = (
37
+ torch.sum(target * torch.conj(target), dim=-1, keepdim=True)
38
+ )
39
+
40
+ if ndim > 2:
41
+ dot = torch.sum(dot, dim=list(range(1, ndim)), keepdim=True)
42
+ s_target_energy = torch.sum(s_target_energy, dim=list(range(1, ndim)), keepdim=True)
43
+
44
+ target_scaler = (dot + 1e-8) / (s_target_energy + 1e-8)
45
+ target = target_ * target_scaler
46
+
47
+ if torch.is_complex(est_target):
48
+ est_target = torch.view_as_real(est_target)
49
+ target = torch.view_as_real(target)
50
+
51
+
52
+ batch_size = est_target.shape[0]
53
+ est_target = est_target.reshape(batch_size, -1)
54
+ target = target.reshape(batch_size, -1)
55
+ # target_ = target_.reshape(batch_size, -1)
56
+
57
+ if self.p == 1:
58
+ e_error = torch.abs(est_target-target).mean(dim=-1)
59
+ e_target = torch.abs(target).mean(dim=-1)
60
+ elif self.p == 2:
61
+ e_error = torch.square(est_target-target).mean(dim=-1)
62
+ e_target = torch.square(target).mean(dim=-1)
63
+ else:
64
+ raise NotImplementedError
65
+
66
+ if self.take_log:
67
+ loss = 10*(torch.log10(e_error + self.EPS) - torch.log10(e_target + self.EPS))
68
+ else:
69
+ loss = (e_error + self.EPS)/(e_target + self.EPS)
70
+
71
+ if self.reduction == "mean":
72
+ loss = loss.mean()
73
+ elif self.reduction == "sum":
74
+ loss = loss.sum()
75
+
76
+ return loss
77
+
78
+
79
+
80
+ class MultichannelSingleSrcNegSDR(_Loss):
81
+ def __init__(
82
+ self,
83
+ sdr_type: str,
84
+ p: float = 2.0,
85
+ zero_mean: bool = True,
86
+ take_log: bool = True,
87
+ reduction: str = "mean",
88
+ EPS: float = 1e-8,
89
+ ) -> None:
90
+ assert reduction != "sum", NotImplementedError
91
+ super().__init__(reduction=reduction)
92
+
93
+ assert sdr_type in ["snr", "sisdr", "sdsdr"]
94
+ self.sdr_type = sdr_type
95
+ self.zero_mean = zero_mean
96
+ self.take_log = take_log
97
+ self.EPS = 1e-8
98
+
99
+ self.p = p
100
+
101
+ def forward(
102
+ self,
103
+ est_target: torch.Tensor,
104
+ target: torch.Tensor
105
+ ) -> torch.Tensor:
106
+ if target.size() != est_target.size() or target.ndim != 3:
107
+ raise TypeError(
108
+ f"Inputs must be of shape [batch, time], got {target.size()} and {est_target.size()} instead"
109
+ )
110
+ # Step 1. Zero-mean norm
111
+ if self.zero_mean:
112
+ mean_source = torch.mean(target, dim=[1, 2], keepdim=True)
113
+ mean_estimate = torch.mean(est_target, dim=[1, 2], keepdim=True)
114
+ target = target - mean_source
115
+ est_target = est_target - mean_estimate
116
+ # Step 2. Pair-wise SI-SDR.
117
+ if self.sdr_type in ["sisdr", "sdsdr"]:
118
+ # [batch, 1]
119
+ dot = torch.sum(est_target * target, dim=[1, 2], keepdim=True)
120
+ # [batch, 1]
121
+ s_target_energy = (
122
+ torch.sum(target ** 2, dim=[1, 2], keepdim=True) + self.EPS
123
+ )
124
+ # [batch, time]
125
+ scaled_target = dot * target / s_target_energy
126
+ else:
127
+ # [batch, time]
128
+ scaled_target = target
129
+ if self.sdr_type in ["sdsdr", "snr"]:
130
+ e_noise = est_target - target
131
+ else:
132
+ e_noise = est_target - scaled_target
133
+ # [batch]
134
+
135
+ if self.p == 2.0:
136
+ losses = torch.sum(scaled_target ** 2, dim=[1, 2]) / (
137
+ torch.sum(e_noise ** 2, dim=[1, 2]) + self.EPS
138
+ )
139
+ else:
140
+ losses = torch.norm(scaled_target, p=self.p, dim=[1, 2]) / (
141
+ torch.linalg.vector_norm(e_noise, p=self.p, dim=[1, 2]) + self.EPS
142
+ )
143
+ if self.take_log:
144
+ losses = 10 * torch.log10(losses + self.EPS)
145
+ losses = losses.mean() if self.reduction == "mean" else losses
146
+ return -losses
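A short sketch of SignalNoisePNormRatio; with take_log=True it behaves like an error-to-signal ratio in dB (more negative is better), shown here on waveform-shaped tensors with placeholder shapes:

import torch
from models.bandit.core.loss.snr import SignalNoisePNormRatio

loss_fn = SignalNoisePNormRatio(p=1.0, take_log=True)
target = torch.randn(4, 2, 44100)
est = target + 0.1 * torch.randn_like(target)  # small residual error
loss = loss_fn(est, target)                    # strongly negative scalar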
models/bandit/core/metrics/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ from .snr import (
2
+ ChunkMedianScaleInvariantSignalDistortionRatio,
3
+ ChunkMedianScaleInvariantSignalNoiseRatio,
4
+ ChunkMedianSignalDistortionRatio,
5
+ ChunkMedianSignalNoiseRatio,
6
+ SafeSignalDistortionRatio,
7
+ )
8
+
9
+ # from .mushra import EstimatedMushraScore
models/bandit/core/metrics/_squim.py ADDED
@@ -0,0 +1,383 @@
1
+ from dataclasses import dataclass
2
+
3
+ from torchaudio._internal import load_state_dict_from_url
4
+
5
+ import math
6
+ from typing import List, Optional, Tuple
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+
12
+
13
+ def transform_wb_pesq_range(x: float) -> float:
14
+ """The metric defined by ITU-T P.862 is often called 'PESQ score', which is defined
15
+ for narrow-band signals and has a value range of [-0.5, 4.5] exactly. Here, we use the metric
16
+ defined by ITU-T P.862.2, commonly known as 'wide-band PESQ' and will be referred to as "PESQ score".
17
+
18
+ Args:
19
+ x (float): Narrow-band PESQ score.
20
+
21
+ Returns:
22
+ (float): Wide-band PESQ score.
23
+ """
24
+ return 0.999 + (4.999 - 0.999) / (1 + math.exp(-1.3669 * x + 3.8224))
25
+
26
+
27
+ PESQRange: Tuple[float, float] = (
28
+ 1.0, # P.862.2 uses a different input filter than P.862, and the lower bound of
29
+ # the raw score is not -0.5 anymore. It's hard to figure out the true lower bound.
30
+ # We are using 1.0 as a reasonable approximation.
31
+ transform_wb_pesq_range(4.5),
32
+ )
33
+
34
+
35
+ class RangeSigmoid(nn.Module):
36
+ def __init__(self, val_range: Tuple[float, float] = (0.0, 1.0)) -> None:
37
+ super(RangeSigmoid, self).__init__()
38
+ assert isinstance(val_range, tuple) and len(val_range) == 2
39
+ self.val_range: Tuple[float, float] = val_range
40
+ self.sigmoid: nn.modules.Module = nn.Sigmoid()
41
+
42
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
43
+ out = self.sigmoid(x) * (self.val_range[1] - self.val_range[0]) + self.val_range[0]
44
+ return out
45
+
46
+
47
+ class Encoder(nn.Module):
48
+ """Encoder module that transform 1D waveform to 2D representations.
49
+
50
+ Args:
51
+ feat_dim (int, optional): The feature dimension after Encoder module. (Default: 512)
52
+ win_len (int, optional): kernel size in the Conv1D layer. (Default: 32)
53
+ """
54
+
55
+ def __init__(self, feat_dim: int = 512, win_len: int = 32) -> None:
56
+ super(Encoder, self).__init__()
57
+
58
+ self.conv1d = nn.Conv1d(1, feat_dim, win_len, stride=win_len // 2, bias=False)
59
+
60
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
61
+ """Apply waveforms to convolutional layer and ReLU layer.
62
+
63
+ Args:
64
+ x (torch.Tensor): Input waveforms. Tensor with dimensions `(batch, time)`.
65
+
66
+ Returns:
67
+ (torch.Tensor): Feature Tensor with dimensions `(batch, channel, frame)`.
68
+ """
69
+ out = x.unsqueeze(dim=1)
70
+ out = F.relu(self.conv1d(out))
71
+ return out
72
+
73
+
74
+ class SingleRNN(nn.Module):
75
+ def __init__(self, rnn_type: str, input_size: int, hidden_size: int, dropout: float = 0.0) -> None:
76
+ super(SingleRNN, self).__init__()
77
+
78
+ self.rnn_type = rnn_type
79
+ self.input_size = input_size
80
+ self.hidden_size = hidden_size
81
+
82
+ self.rnn: nn.modules.Module = getattr(nn, rnn_type)(
83
+ input_size,
84
+ hidden_size,
85
+ 1,
86
+ dropout=dropout,
87
+ batch_first=True,
88
+ bidirectional=True,
89
+ )
90
+
91
+ self.proj = nn.Linear(hidden_size * 2, input_size)
92
+
93
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
94
+ # input shape: batch, seq, dim
95
+ out, _ = self.rnn(x)
96
+ out = self.proj(out)
97
+ return out
98
+
99
+
100
+ class DPRNN(nn.Module):
101
+ """*Dual-path recurrent neural networks (DPRNN)* :cite:`luo2020dual`.
102
+
103
+ Args:
104
+ feat_dim (int, optional): The feature dimension after Encoder module. (Default: 64)
105
+ hidden_dim (int, optional): Hidden dimension in the RNN layer of DPRNN. (Default: 128)
106
+ num_blocks (int, optional): Number of DPRNN layers. (Default: 6)
107
+ rnn_type (str, optional): Type of RNN in DPRNN. Valid options are ["RNN", "LSTM", "GRU"]. (Default: "LSTM")
108
+ d_model (int, optional): The number of expected features in the input. (Default: 256)
109
+ chunk_size (int, optional): Chunk size of input for DPRNN. (Default: 100)
110
+ chunk_stride (int, optional): Stride of chunk input for DPRNN. (Default: 50)
111
+ """
112
+
113
+ def __init__(
114
+ self,
115
+ feat_dim: int = 64,
116
+ hidden_dim: int = 128,
117
+ num_blocks: int = 6,
118
+ rnn_type: str = "LSTM",
119
+ d_model: int = 256,
120
+ chunk_size: int = 100,
121
+ chunk_stride: int = 50,
122
+ ) -> None:
123
+ super(DPRNN, self).__init__()
124
+
125
+ self.num_blocks = num_blocks
126
+
127
+ self.row_rnn = nn.ModuleList([])
128
+ self.col_rnn = nn.ModuleList([])
129
+ self.row_norm = nn.ModuleList([])
130
+ self.col_norm = nn.ModuleList([])
131
+ for _ in range(num_blocks):
132
+ self.row_rnn.append(SingleRNN(rnn_type, feat_dim, hidden_dim))
133
+ self.col_rnn.append(SingleRNN(rnn_type, feat_dim, hidden_dim))
134
+ self.row_norm.append(nn.GroupNorm(1, feat_dim, eps=1e-8))
135
+ self.col_norm.append(nn.GroupNorm(1, feat_dim, eps=1e-8))
136
+ self.conv = nn.Sequential(
137
+ nn.Conv2d(feat_dim, d_model, 1),
138
+ nn.PReLU(),
139
+ )
140
+ self.chunk_size = chunk_size
141
+ self.chunk_stride = chunk_stride
142
+
143
+ def pad_chunk(self, x: torch.Tensor) -> Tuple[torch.Tensor, int]:
144
+ # input shape: (B, N, T)
145
+ seq_len = x.shape[-1]
146
+
147
+ rest = self.chunk_size - (self.chunk_stride + seq_len % self.chunk_size) % self.chunk_size
148
+ out = F.pad(x, [self.chunk_stride, rest + self.chunk_stride])
149
+
150
+ return out, rest
151
+
152
+ def chunking(self, x: torch.Tensor) -> Tuple[torch.Tensor, int]:
153
+ out, rest = self.pad_chunk(x)
154
+ batch_size, feat_dim, seq_len = out.shape
155
+
156
+ segments1 = out[:, :, : -self.chunk_stride].contiguous().view(batch_size, feat_dim, -1, self.chunk_size)
157
+ segments2 = out[:, :, self.chunk_stride :].contiguous().view(batch_size, feat_dim, -1, self.chunk_size)
158
+ out = torch.cat([segments1, segments2], dim=3)
159
+ out = out.view(batch_size, feat_dim, -1, self.chunk_size).transpose(2, 3).contiguous()
160
+
161
+ return out, rest
162
+
163
+ def merging(self, x: torch.Tensor, rest: int) -> torch.Tensor:
164
+ batch_size, dim, _, _ = x.shape
165
+ out = x.transpose(2, 3).contiguous().view(batch_size, dim, -1, self.chunk_size * 2)
166
+ out1 = out[:, :, :, : self.chunk_size].contiguous().view(batch_size, dim, -1)[:, :, self.chunk_stride :]
167
+ out2 = out[:, :, :, self.chunk_size :].contiguous().view(batch_size, dim, -1)[:, :, : -self.chunk_stride]
168
+ out = out1 + out2
169
+ if rest > 0:
170
+ out = out[:, :, :-rest]
171
+ out = out.contiguous()
172
+ return out
173
+
174
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
175
+ x, rest = self.chunking(x)
176
+ batch_size, _, dim1, dim2 = x.shape
177
+ out = x
178
+ for row_rnn, row_norm, col_rnn, col_norm in zip(self.row_rnn, self.row_norm, self.col_rnn, self.col_norm):
179
+ row_in = out.permute(0, 3, 2, 1).contiguous().view(batch_size * dim2, dim1, -1).contiguous()
180
+ row_out = row_rnn(row_in)
181
+ row_out = row_out.view(batch_size, dim2, dim1, -1).permute(0, 3, 2, 1).contiguous()
182
+ row_out = row_norm(row_out)
183
+ out = out + row_out
184
+
185
+ col_in = out.permute(0, 2, 3, 1).contiguous().view(batch_size * dim1, dim2, -1).contiguous()
186
+ col_out = col_rnn(col_in)
187
+ col_out = col_out.view(batch_size, dim1, dim2, -1).permute(0, 3, 1, 2).contiguous()
188
+ col_out = col_norm(col_out)
189
+ out = out + col_out
190
+ out = self.conv(out)
191
+ out = self.merging(out, rest)
192
+ out = out.transpose(1, 2).contiguous()
193
+ return out
194
+
195
+
196
+ class AutoPool(nn.Module):
197
+ def __init__(self, pool_dim: int = 1) -> None:
198
+ super(AutoPool, self).__init__()
199
+ self.pool_dim: int = pool_dim
200
+ self.softmax: nn.modules.Module = nn.Softmax(dim=pool_dim)
201
+ self.register_parameter("alpha", nn.Parameter(torch.ones(1)))
202
+
203
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
204
+ weight = self.softmax(torch.mul(x, self.alpha))
205
+ out = torch.sum(torch.mul(x, weight), dim=self.pool_dim)
206
+ return out
207
+
208
+
209
+ class SquimObjective(nn.Module):
210
+ """Speech Quality and Intelligibility Measures (SQUIM) model that predicts **objective** metric scores
211
+ for speech enhancement (e.g., STOI, PESQ, and SI-SDR).
212
+
213
+ Args:
214
+ encoder (torch.nn.Module): Encoder module to transform 1D waveform to 2D feature representation.
215
+ dprnn (torch.nn.Module): DPRNN module to model sequential feature.
216
+ branches (torch.nn.ModuleList): Transformer branches in which each branch estimates one objective metric score.
217
+ """
218
+
219
+ def __init__(
220
+ self,
221
+ encoder: nn.Module,
222
+ dprnn: nn.Module,
223
+ branches: nn.ModuleList,
224
+ ):
225
+ super(SquimObjective, self).__init__()
226
+ self.encoder = encoder
227
+ self.dprnn = dprnn
228
+ self.branches = branches
229
+
230
+ def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
231
+ """
232
+ Args:
233
+ x (torch.Tensor): Input waveforms. Tensor with dimensions `(batch, time)`.
234
+
235
+ Returns:
236
+ List(torch.Tensor): List of score Tensors. Each Tensor has dimension `(batch,)`.
237
+ """
238
+ if x.ndim != 2:
239
+ raise ValueError(f"The input must be a 2D Tensor. Found dimension {x.ndim}.")
240
+ x = x / (torch.mean(x**2, dim=1, keepdim=True) ** 0.5 * 20)
241
+ out = self.encoder(x)
242
+ out = self.dprnn(out)
243
+ scores = []
244
+ for branch in self.branches:
245
+ scores.append(branch(out).squeeze(dim=1))
246
+ return scores
247
+
248
+
249
+ def _create_branch(d_model: int, nhead: int, metric: str) -> nn.modules.Module:
250
+ """Create branch module after DPRNN model for predicting metric score.
251
+
252
+ Args:
253
+ d_model (int): The number of expected features in the input.
254
+ nhead (int): Number of heads in the multi-head attention model.
255
+ metric (str): The metric name to predict.
256
+
257
+ Returns:
258
+ (nn.Module): Returned module to predict corresponding metric score.
259
+ """
260
+ layer1 = nn.TransformerEncoderLayer(d_model, nhead, d_model * 4, dropout=0.0, batch_first=True)
261
+ layer2 = AutoPool()
262
+ if metric == "stoi":
263
+ layer3 = nn.Sequential(
264
+ nn.Linear(d_model, d_model),
265
+ nn.PReLU(),
266
+ nn.Linear(d_model, 1),
267
+ RangeSigmoid(),
268
+ )
269
+ elif metric == "pesq":
270
+ layer3 = nn.Sequential(
271
+ nn.Linear(d_model, d_model),
272
+ nn.PReLU(),
273
+ nn.Linear(d_model, 1),
274
+ RangeSigmoid(val_range=PESQRange),
275
+ )
276
+ else:
277
+ layer3: nn.modules.Module = nn.Sequential(nn.Linear(d_model, d_model), nn.PReLU(), nn.Linear(d_model, 1))
278
+ return nn.Sequential(layer1, layer2, layer3)
279
+
280
+
281
+ def squim_objective_model(
282
+ feat_dim: int,
283
+ win_len: int,
284
+ d_model: int,
285
+ nhead: int,
286
+ hidden_dim: int,
287
+ num_blocks: int,
288
+ rnn_type: str,
289
+ chunk_size: int,
290
+ chunk_stride: Optional[int] = None,
291
+ ) -> SquimObjective:
292
+ """Build a custome :class:`torchaudio.prototype.models.SquimObjective` model.
293
+
294
+ Args:
295
+ feat_dim (int, optional): The feature dimension after Encoder module.
296
+ win_len (int): Kernel size in the Encoder module.
297
+ d_model (int): The number of expected features in the input.
298
+ nhead (int): Number of heads in the multi-head attention model.
299
+ hidden_dim (int): Hidden dimension in the RNN layer of DPRNN.
300
+ num_blocks (int): Number of DPRNN layers.
301
+ rnn_type (str): Type of RNN in DPRNN. Valid options are ["RNN", "LSTM", "GRU"].
302
+ chunk_size (int): Chunk size of input for DPRNN.
303
+ chunk_stride (int or None, optional): Stride of chunk input for DPRNN.
304
+ """
305
+ if chunk_stride is None:
306
+ chunk_stride = chunk_size // 2
307
+ encoder = Encoder(feat_dim, win_len)
308
+ dprnn = DPRNN(feat_dim, hidden_dim, num_blocks, rnn_type, d_model, chunk_size, chunk_stride)
309
+ branches = nn.ModuleList(
310
+ [
311
+ _create_branch(d_model, nhead, "stoi"),
312
+ _create_branch(d_model, nhead, "pesq"),
313
+ _create_branch(d_model, nhead, "sisdr"),
314
+ ]
315
+ )
316
+ return SquimObjective(encoder, dprnn, branches)
317
+
318
+
319
+ def squim_objective_base() -> SquimObjective:
320
+ """Build :class:`torchaudio.prototype.models.SquimObjective` model with default arguments."""
321
+ return squim_objective_model(
322
+ feat_dim=256,
323
+ win_len=64,
324
+ d_model=256,
325
+ nhead=4,
326
+ hidden_dim=256,
327
+ num_blocks=2,
328
+ rnn_type="LSTM",
329
+ chunk_size=71,
330
+ )
331
+
332
+ @dataclass
333
+ class SquimObjectiveBundle:
334
+
335
+ _path: str
336
+ _sample_rate: float
337
+
338
+ def _get_state_dict(self, dl_kwargs):
339
+ url = f"https://download.pytorch.org/torchaudio/models/{self._path}"
340
+ dl_kwargs = {} if dl_kwargs is None else dl_kwargs
341
+ state_dict = load_state_dict_from_url(url, **dl_kwargs)
342
+ return state_dict
343
+
344
+ def get_model(self, *, dl_kwargs=None) -> SquimObjective:
345
+ """Construct the SquimObjective model, and load the pretrained weight.
346
+
347
+ The weight file is downloaded from the internet and cached with
348
+ :func:`torch.hub.load_state_dict_from_url`
349
+
350
+ Args:
351
+ dl_kwargs (dictionary of keyword arguments): Passed to :func:`torch.hub.load_state_dict_from_url`.
352
+
353
+ Returns:
354
+ Variation of :py:class:`~torchaudio.models.SquimObjective`.
355
+ """
356
+ model = squim_objective_base()
357
+ model.load_state_dict(self._get_state_dict(dl_kwargs))
358
+ model.eval()
359
+ return model
360
+
361
+ @property
362
+ def sample_rate(self):
363
+ """Sample rate of the audio that the model is trained on.
364
+
365
+ :type: float
366
+ """
367
+ return self._sample_rate
368
+
369
+
370
+ SQUIM_OBJECTIVE = SquimObjectiveBundle(
371
+ "squim_objective_dns2020.pth",
372
+ _sample_rate=16000,
373
+ )
374
+ SQUIM_OBJECTIVE.__doc__ = """SquimObjective pipeline trained using approach described in
375
+ :cite:`kumar2023torchaudio` on the *DNS 2020 Dataset* :cite:`reddy2020interspeech`.
376
+
377
+ The underlying model is constructed by :py:func:`torchaudio.models.squim_objective_base`.
378
+ The weights are under `Creative Commons Attribution 4.0 International License
379
+ <https://github.com/microsoft/DNS-Challenge/blob/interspeech2020/master/LICENSE>`__.
380
+
381
+ Please refer to :py:class:`SquimObjectiveBundle` for usage instructions.
382
+ """
383
+
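A usage sketch of the bundled pipeline above (the random waveform and sample rates are placeholders; get_model() downloads the pretrained weights, so network access is required):

import torch
import torchaudio.functional as taF
from models.bandit.core.metrics._squim import SQUIM_OBJECTIVE

model = SQUIM_OBJECTIVE.get_model()
waveform = torch.randn(1, 44100)  # 1 second of placeholder audio at 44.1 kHz
waveform = taF.resample(waveform, 44100, int(SQUIM_OBJECTIVE.sample_rate))
with torch.no_grad():
    stoi, pesq, si_sdr = model(waveform)  # each a (batch,) tensor of predicted scores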
models/bandit/core/metrics/snr.py ADDED
@@ -0,0 +1,150 @@
1
+ from typing import Any, Callable
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torchmetrics as tm
6
+ from torch._C import _LinAlgError
7
+ from torchmetrics import functional as tmF
8
+
9
+
10
+ class SafeSignalDistortionRatio(tm.SignalDistortionRatio):
11
+ def __init__(self, **kwargs) -> None:
12
+ super().__init__(**kwargs)
13
+
14
+ def update(self, *args, **kwargs) -> Any:
15
+ try:
16
+ super().update(*args, **kwargs)
17
+ except Exception:
18
+ pass
19
+
20
+ def compute(self) -> Any:
21
+ if self.total == 0:
22
+ return torch.tensor(torch.nan)
23
+ return super().compute()
24
+
25
+
26
+ class BaseChunkMedianSignalRatio(tm.Metric):
27
+ def __init__(
28
+ self,
29
+ func: Callable,
30
+ window_size: int,
31
+ hop_size: int = None,
32
+ zero_mean: bool = False,
33
+ ) -> None:
34
+ super().__init__()
35
+
36
+ # self.zero_mean = zero_mean
37
+ self.func = func
38
+ self.window_size = window_size
39
+ if hop_size is None:
40
+ hop_size = window_size
41
+ self.hop_size = hop_size
42
+
43
+ self.add_state(
44
+ "sum_snr",
45
+ default=torch.tensor(0.0),
46
+ dist_reduce_fx="sum"
47
+ )
48
+ self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
49
+
50
+ def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
51
+
52
+ n_samples = target.shape[-1]
53
+
54
+ n_chunks = int(
55
+ np.ceil((n_samples - self.window_size) / self.hop_size) + 1
56
+ )
57
+
58
+ snr_chunk = []
59
+
60
+ for i in range(n_chunks):
61
+ start = i * self.hop_size
62
+
63
+ if n_samples - start < self.window_size:
64
+ continue
65
+
66
+ end = start + self.window_size
67
+
68
+ try:
69
+ chunk_snr = self.func(
70
+ preds[..., start:end],
71
+ target[..., start:end]
72
+ )
73
+
74
+ # print(preds.shape, chunk_snr.shape)
75
+
76
+ if torch.all(torch.isfinite(chunk_snr)):
77
+ snr_chunk.append(chunk_snr)
78
+ except _LinAlgError:
79
+ pass
80
+
81
+ snr_chunk = torch.stack(snr_chunk, dim=-1)
82
+ snr_batch, _ = torch.nanmedian(snr_chunk, dim=-1)
83
+
84
+ self.sum_snr += snr_batch.sum()
85
+ self.total += snr_batch.numel()
86
+
87
+ def compute(self) -> Any:
88
+ return self.sum_snr / self.total
89
+
90
+
91
+ class ChunkMedianSignalNoiseRatio(BaseChunkMedianSignalRatio):
92
+ def __init__(
93
+ self,
94
+ window_size: int,
95
+ hop_size: int = None,
96
+ zero_mean: bool = False
97
+ ) -> None:
98
+ super().__init__(
99
+ func=tmF.signal_noise_ratio,
100
+ window_size=window_size,
101
+ hop_size=hop_size,
102
+ zero_mean=zero_mean,
103
+ )
104
+
105
+
106
+ class ChunkMedianScaleInvariantSignalNoiseRatio(BaseChunkMedianSignalRatio):
107
+ def __init__(
108
+ self,
109
+ window_size: int,
110
+ hop_size: int = None,
111
+ zero_mean: bool = False
112
+ ) -> None:
113
+ super().__init__(
114
+ func=tmF.scale_invariant_signal_noise_ratio,
115
+ window_size=window_size,
116
+ hop_size=hop_size,
117
+ zero_mean=zero_mean,
118
+ )
119
+
120
+
121
+ class ChunkMedianSignalDistortionRatio(BaseChunkMedianSignalRatio):
122
+ def __init__(
123
+ self,
124
+ window_size: int,
125
+ hop_size: int = None,
126
+ zero_mean: bool = False
127
+ ) -> None:
128
+ super().__init__(
129
+ func=tmF.signal_distortion_ratio,
130
+ window_size=window_size,
131
+ hop_size=hop_size,
132
+ zero_mean=zero_mean,
133
+ )
134
+
135
+
136
+ class ChunkMedianScaleInvariantSignalDistortionRatio(
137
+ BaseChunkMedianSignalRatio
138
+ ):
139
+ def __init__(
140
+ self,
141
+ window_size: int,
142
+ hop_size: int = None,
143
+ zero_mean: bool = False
144
+ ) -> None:
145
+ super().__init__(
146
+ func=tmF.scale_invariant_signal_distortion_ratio,
147
+ window_size=window_size,
148
+ hop_size=hop_size,
149
+ zero_mean=zero_mean,
150
+ )
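A minimal sketch of the chunked metrics above, following the usual torchmetrics update/compute pattern; window and hop sizes are placeholder values (one-second windows at 44.1 kHz):

import torch
from models.bandit.core.metrics.snr import ChunkMedianSignalNoiseRatio

metric = ChunkMedianSignalNoiseRatio(window_size=44100, hop_size=22050)
target = torch.randn(2, 10 * 44100)
preds = target + 0.05 * torch.randn_like(target)
metric.update(preds, target)
print(metric.compute())  # batch mean of the per-item median chunk SNR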
models/bandit/core/model/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .bsrnn.wrapper import (
2
+ MultiMaskMultiSourceBandSplitRNNSimple,
3
+ )
models/bandit/core/model/_spectral.py ADDED
@@ -0,0 +1,58 @@
1
+ from typing import Dict, Optional
2
+
3
+ import torch
4
+ import torchaudio as ta
5
+ from torch import nn
6
+
7
+
8
+ class _SpectralComponent(nn.Module):
9
+ def __init__(
10
+ self,
11
+ n_fft: int = 2048,
12
+ win_length: Optional[int] = 2048,
13
+ hop_length: int = 512,
14
+ window_fn: str = "hann_window",
15
+ wkwargs: Optional[Dict] = None,
16
+ power: Optional[int] = None,
17
+ center: bool = True,
18
+ normalized: bool = True,
19
+ pad_mode: str = "constant",
20
+ onesided: bool = True,
21
+ **kwargs,
22
+ ) -> None:
23
+ super().__init__()
24
+
25
+ assert power is None
26
+
27
+ window_fn = torch.__dict__[window_fn]
28
+
29
+ self.stft = (
30
+ ta.transforms.Spectrogram(
31
+ n_fft=n_fft,
32
+ win_length=win_length,
33
+ hop_length=hop_length,
34
+ pad_mode=pad_mode,
35
+ pad=0,
36
+ window_fn=window_fn,
37
+ wkwargs=wkwargs,
38
+ power=power,
39
+ normalized=normalized,
40
+ center=center,
41
+ onesided=onesided,
42
+ )
43
+ )
44
+
45
+ self.istft = (
46
+ ta.transforms.InverseSpectrogram(
47
+ n_fft=n_fft,
48
+ win_length=win_length,
49
+ hop_length=hop_length,
50
+ pad_mode=pad_mode,
51
+ pad=0,
52
+ window_fn=window_fn,
53
+ wkwargs=wkwargs,
54
+ normalized=normalized,
55
+ center=center,
56
+ onesided=onesided,
57
+ )
58
+ )
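For reference, a round-trip sketch using the same torchaudio transforms and defaults as the component above (the subclasses that wire this into a model live elsewhere; the audio here is random):

import torch
import torchaudio as ta

stft = ta.transforms.Spectrogram(
    n_fft=2048, win_length=2048, hop_length=512,
    power=None, normalized=True, center=True, pad_mode="constant",
)
istft = ta.transforms.InverseSpectrogram(
    n_fft=2048, win_length=2048, hop_length=512,
    normalized=True, center=True,
)
x = torch.randn(1, 2, 44100)             # (batch, channel, time), placeholder
spec = stft(x)                           # complex (batch, channel, freq, frames) since power=None
x_hat = istft(spec, length=x.shape[-1])  # near-perfect reconstruction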
models/bandit/core/model/bsrnn/__init__.py ADDED
@@ -0,0 +1,23 @@
1
+ from abc import ABC
2
+ from typing import Iterable, Mapping, Union
3
+
4
+ from torch import nn
5
+
6
+ from models.bandit.core.model.bsrnn.bandsplit import BandSplitModule
7
+ from models.bandit.core.model.bsrnn.tfmodel import (
8
+ SeqBandModellingModule,
9
+ TransformerTimeFreqModule,
10
+ )
11
+
12
+
13
+ class BandsplitCoreBase(nn.Module, ABC):
14
+ band_split: nn.Module
15
+ tf_model: nn.Module
16
+ mask_estim: Union[nn.Module, Mapping[str, nn.Module], Iterable[nn.Module]]
17
+
18
+ def __init__(self) -> None:
19
+ super().__init__()
20
+
21
+ @staticmethod
22
+ def mask(x, m):
23
+ return x * m
models/bandit/core/model/bsrnn/bandsplit.py ADDED
@@ -0,0 +1,139 @@
1
+ from typing import List, Tuple
2
+
3
+ import torch
4
+ from torch import nn
5
+
6
+ from models.bandit.core.model.bsrnn.utils import (
7
+ band_widths_from_specs,
8
+ check_no_gap,
9
+ check_no_overlap,
10
+ check_nonzero_bandwidth,
11
+ )
12
+
13
+
14
+ class NormFC(nn.Module):
15
+ def __init__(
16
+ self,
17
+ emb_dim: int,
18
+ bandwidth: int,
19
+ in_channel: int,
20
+ normalize_channel_independently: bool = False,
21
+ treat_channel_as_feature: bool = True,
22
+ ) -> None:
23
+ super().__init__()
24
+
25
+ self.treat_channel_as_feature = treat_channel_as_feature
26
+
27
+ if normalize_channel_independently:
28
+ raise NotImplementedError
29
+
30
+ reim = 2
31
+
32
+ self.norm = nn.LayerNorm(in_channel * bandwidth * reim)
33
+
34
+ fc_in = bandwidth * reim
35
+
36
+ if treat_channel_as_feature:
37
+ fc_in *= in_channel
38
+ else:
39
+ assert emb_dim % in_channel == 0
40
+ emb_dim = emb_dim // in_channel
41
+
42
+ self.fc = nn.Linear(fc_in, emb_dim)
43
+
44
+ def forward(self, xb):
45
+ # xb = (batch, n_time, in_chan, reim * band_width)
46
+
47
+ batch, n_time, in_chan, ribw = xb.shape
48
+ xb = self.norm(xb.reshape(batch, n_time, in_chan * ribw))
49
+ # (batch, n_time, in_chan * reim * band_width)
50
+
51
+ if not self.treat_channel_as_feature:
52
+ xb = xb.reshape(batch, n_time, in_chan, ribw)
53
+ # (batch, n_time, in_chan, reim * band_width)
54
+
55
+ zb = self.fc(xb)
56
+ # (batch, n_time, emb_dim)
57
+ # OR
58
+ # (batch, n_time, in_chan, emb_dim_per_chan)
59
+
60
+ if not self.treat_channel_as_feature:
61
+ batch, n_time, in_chan, emb_dim_per_chan = zb.shape
62
+ # (batch, n_time, in_chan, emb_dim_per_chan)
63
+ zb = zb.reshape((batch, n_time, in_chan * emb_dim_per_chan))
64
+
65
+ return zb # (batch, n_time, emb_dim)
66
+
67
+
68
+ class BandSplitModule(nn.Module):
69
+ def __init__(
70
+ self,
71
+ band_specs: List[Tuple[float, float]],
72
+ emb_dim: int,
73
+ in_channel: int,
74
+ require_no_overlap: bool = False,
75
+ require_no_gap: bool = True,
76
+ normalize_channel_independently: bool = False,
77
+ treat_channel_as_feature: bool = True,
78
+ ) -> None:
79
+ super().__init__()
80
+
81
+ check_nonzero_bandwidth(band_specs)
82
+
83
+ if require_no_gap:
84
+ check_no_gap(band_specs)
85
+
86
+ if require_no_overlap:
87
+ check_no_overlap(band_specs)
88
+
89
+ self.band_specs = band_specs
90
+ # list of [fstart, fend) in index.
91
+ # Note that fend is exclusive.
92
+ self.band_widths = band_widths_from_specs(band_specs)
93
+ self.n_bands = len(band_specs)
94
+ self.emb_dim = emb_dim
95
+
96
+ self.norm_fc_modules = nn.ModuleList(
97
+ [ # type: ignore
98
+ (
99
+ NormFC(
100
+ emb_dim=emb_dim,
101
+ bandwidth=bw,
102
+ in_channel=in_channel,
103
+ normalize_channel_independently=normalize_channel_independently,
104
+ treat_channel_as_feature=treat_channel_as_feature,
105
+ )
106
+ )
107
+ for bw in self.band_widths
108
+ ]
109
+ )
110
+
111
+ def forward(self, x: torch.Tensor):
112
+ # x = complex spectrogram (batch, in_chan, n_freq, n_time)
113
+
114
+ batch, in_chan, _, n_time = x.shape
115
+
116
+ z = torch.zeros(
117
+ size=(batch, self.n_bands, n_time, self.emb_dim),
118
+ device=x.device
119
+ )
120
+
121
+ xr = torch.view_as_real(x) # batch, in_chan, n_freq, n_time, 2
122
+ xr = torch.permute(
123
+ xr,
124
+ (0, 3, 1, 4, 2)
125
+ ) # batch, n_time, in_chan, 2, n_freq
126
+ batch, n_time, in_chan, reim, band_width = xr.shape
127
+ for i, nfm in enumerate(self.norm_fc_modules):
128
+ # print(f"bandsplit/band{i:02d}")
129
+ fstart, fend = self.band_specs[i]
130
+ xb = xr[..., fstart:fend]
131
+ # (batch, n_time, in_chan, reim, band_width)
132
+ xb = torch.reshape(xb, (batch, n_time, in_chan, -1))
133
+ # (batch, n_time, in_chan, reim * band_width)
134
+ # z.append(nfm(xb)) # (batch, n_time, emb_dim)
135
+ z[:, i, :, :] = nfm(xb.contiguous())
136
+
137
+ # z = torch.stack(z, dim=1)
138
+
139
+ return z
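A shape-level sketch of the band-split front end above; the band specification here is a toy, contiguous, non-overlapping split of a 1025-bin spectrogram (real configurations come from the bsrnn utils module), and all shapes are placeholders:

import torch
from models.bandit.core.model.bsrnn.bandsplit import BandSplitModule

band_specs = [(i * 25, (i + 1) * 25) for i in range(41)]  # 41 bands x 25 bins = 1025 bins
module = BandSplitModule(band_specs=band_specs, emb_dim=128, in_channel=2)

x = torch.randn(4, 2, 1025, 87, dtype=torch.complex64)  # complex (batch, chan, freq, time)
z = module(x)
print(z.shape)  # torch.Size([4, 41, 87, 128]) -> (batch, n_bands, n_time, emb_dim)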
models/bandit/core/model/bsrnn/core.py ADDED
@@ -0,0 +1,661 @@
1
+ from typing import Dict, List, Optional, Tuple
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+
7
+ from models.bandit.core.model.bsrnn import BandsplitCoreBase
8
+ from models.bandit.core.model.bsrnn.bandsplit import BandSplitModule
9
+ from models.bandit.core.model.bsrnn.maskestim import (
10
+ MaskEstimationModule,
11
+ OverlappingMaskEstimationModule
12
+ )
13
+ from models.bandit.core.model.bsrnn.tfmodel import (
14
+ ConvolutionalTimeFreqModule,
15
+ SeqBandModellingModule,
16
+ TransformerTimeFreqModule
17
+ )
18
+
19
+
20
+ class MultiMaskBandSplitCoreBase(BandsplitCoreBase):
21
+ def __init__(self) -> None:
22
+ super().__init__()
23
+
24
+ def forward(self, x, cond=None, compute_residual: bool = True):
25
+ # x = complex spectrogram (batch, in_chan, n_freq, n_time)
26
+ # print(x.shape)
27
+ batch, in_chan, n_freq, n_time = x.shape
28
+ x = torch.reshape(x, (-1, 1, n_freq, n_time))
29
+
30
+ z = self.band_split(x) # (batch, emb_dim, n_band, n_time)
31
+
32
+ # if torch.any(torch.isnan(z)):
33
+ # raise ValueError("z nan")
34
+
35
+ # print(z)
36
+ q = self.tf_model(z) # (batch, emb_dim, n_band, n_time)
37
+ # print(q)
38
+
39
+
40
+ # if torch.any(torch.isnan(q)):
41
+ # raise ValueError("q nan")
42
+
43
+ out = {}
44
+
45
+ for stem, mem in self.mask_estim.items():
46
+ m = mem(q, cond=cond)
47
+
48
+ # if torch.any(torch.isnan(m)):
49
+ # raise ValueError("m nan", stem)
50
+
51
+ s = self.mask(x, m)
52
+ s = torch.reshape(s, (batch, in_chan, n_freq, n_time))
53
+ out[stem] = s
54
+
55
+ return {"spectrogram": out}
56
+
57
+
58
+
59
+ def instantiate_mask_estim(self,
60
+ in_channel: int,
61
+ stems: List[str],
62
+ band_specs: List[Tuple[float, float]],
63
+ emb_dim: int,
64
+ mlp_dim: int,
65
+ cond_dim: int,
66
+ hidden_activation: str,
67
+
68
+ hidden_activation_kwargs: Optional[Dict] = None,
69
+ complex_mask: bool = True,
70
+ overlapping_band: bool = False,
71
+ freq_weights: Optional[List[torch.Tensor]] = None,
72
+ n_freq: Optional[int] = None,
73
+ use_freq_weights: bool = True,
74
+ mult_add_mask: bool = False
75
+ ):
76
+ if hidden_activation_kwargs is None:
77
+ hidden_activation_kwargs = {}
78
+
79
+ if "mne:+" in stems:
80
+ stems = [s for s in stems if s != "mne:+"]
81
+
82
+ if overlapping_band:
83
+ assert freq_weights is not None
84
+ assert n_freq is not None
85
+
86
+ if mult_add_mask:
87
+
88
+ self.mask_estim = nn.ModuleDict(
89
+ {
90
+ stem: MultAddMaskEstimationModule(
91
+ band_specs=band_specs,
92
+ freq_weights=freq_weights,
93
+ n_freq=n_freq,
94
+ emb_dim=emb_dim,
95
+ mlp_dim=mlp_dim,
96
+ in_channel=in_channel,
97
+ hidden_activation=hidden_activation,
98
+ hidden_activation_kwargs=hidden_activation_kwargs,
99
+ complex_mask=complex_mask,
100
+ use_freq_weights=use_freq_weights,
101
+ )
102
+ for stem in stems
103
+ }
104
+ )
105
+ else:
106
+ self.mask_estim = nn.ModuleDict(
107
+ {
108
+ stem: OverlappingMaskEstimationModule(
109
+ band_specs=band_specs,
110
+ freq_weights=freq_weights,
111
+ n_freq=n_freq,
112
+ emb_dim=emb_dim,
113
+ mlp_dim=mlp_dim,
114
+ in_channel=in_channel,
115
+ hidden_activation=hidden_activation,
116
+ hidden_activation_kwargs=hidden_activation_kwargs,
117
+ complex_mask=complex_mask,
118
+ use_freq_weights=use_freq_weights,
119
+ )
120
+ for stem in stems
121
+ }
122
+ )
123
+ else:
124
+ self.mask_estim = nn.ModuleDict(
125
+ {
126
+ stem: MaskEstimationModule(
127
+ band_specs=band_specs,
128
+ emb_dim=emb_dim,
129
+ mlp_dim=mlp_dim,
130
+ cond_dim=cond_dim,
131
+ in_channel=in_channel,
132
+ hidden_activation=hidden_activation,
133
+ hidden_activation_kwargs=hidden_activation_kwargs,
134
+ complex_mask=complex_mask,
135
+ )
136
+ for stem in stems
137
+ }
138
+ )
139
+
140
+ def instantiate_bandsplit(self,
141
+ in_channel: int,
142
+ band_specs: List[Tuple[float, float]],
143
+ require_no_overlap: bool = False,
144
+ require_no_gap: bool = True,
145
+ normalize_channel_independently: bool = False,
146
+ treat_channel_as_feature: bool = True,
147
+ emb_dim: int = 128
148
+ ):
149
+ self.band_split = BandSplitModule(
150
+ in_channel=in_channel,
151
+ band_specs=band_specs,
152
+ require_no_overlap=require_no_overlap,
153
+ require_no_gap=require_no_gap,
154
+ normalize_channel_independently=normalize_channel_independently,
155
+ treat_channel_as_feature=treat_channel_as_feature,
156
+ emb_dim=emb_dim,
157
+ )
158
+
159
+ class SingleMaskBandsplitCoreBase(BandsplitCoreBase):
160
+ def __init__(self, **kwargs) -> None:
161
+ super().__init__()
162
+
163
+ def forward(self, x):
164
+ # x = complex spectrogram (batch, in_chan, n_freq, n_time)
165
+ z = self.band_split(x) # (batch, emb_dim, n_band, n_time)
166
+ q = self.tf_model(z) # (batch, emb_dim, n_band, n_time)
167
+ m = self.mask_estim(q) # (batch, in_chan, n_freq, n_time)
168
+
169
+ s = self.mask(x, m)
170
+
171
+ return s
172
+
173
+
174
+ class SingleMaskBandsplitCoreRNN(
175
+ SingleMaskBandsplitCoreBase,
176
+ ):
177
+ def __init__(
178
+ self,
179
+ in_channel: int,
180
+ band_specs: List[Tuple[float, float]],
181
+ require_no_overlap: bool = False,
182
+ require_no_gap: bool = True,
183
+ normalize_channel_independently: bool = False,
184
+ treat_channel_as_feature: bool = True,
185
+ n_sqm_modules: int = 12,
186
+ emb_dim: int = 128,
187
+ rnn_dim: int = 256,
188
+ bidirectional: bool = True,
189
+ rnn_type: str = "LSTM",
190
+ mlp_dim: int = 512,
191
+ hidden_activation: str = "Tanh",
192
+ hidden_activation_kwargs: Optional[Dict] = None,
193
+ complex_mask: bool = True,
194
+ ) -> None:
195
+ super().__init__()
196
+ self.band_split = (BandSplitModule(
197
+ in_channel=in_channel,
198
+ band_specs=band_specs,
199
+ require_no_overlap=require_no_overlap,
200
+ require_no_gap=require_no_gap,
201
+ normalize_channel_independently=normalize_channel_independently,
202
+ treat_channel_as_feature=treat_channel_as_feature,
203
+ emb_dim=emb_dim,
204
+ ))
205
+ self.tf_model = (SeqBandModellingModule(
206
+ n_modules=n_sqm_modules,
207
+ emb_dim=emb_dim,
208
+ rnn_dim=rnn_dim,
209
+ bidirectional=bidirectional,
210
+ rnn_type=rnn_type,
211
+ ))
212
+ self.mask_estim = (MaskEstimationModule(
213
+ in_channel=in_channel,
214
+ band_specs=band_specs,
215
+ emb_dim=emb_dim,
216
+ mlp_dim=mlp_dim,
217
+ hidden_activation=hidden_activation,
218
+ hidden_activation_kwargs=hidden_activation_kwargs,
219
+ complex_mask=complex_mask,
220
+ ))
221
+
222
+
223
+ class SingleMaskBandsplitCoreTransformer(
224
+ SingleMaskBandsplitCoreBase,
225
+ ):
226
+ def __init__(
227
+ self,
228
+ in_channel: int,
229
+ band_specs: List[Tuple[float, float]],
230
+ require_no_overlap: bool = False,
231
+ require_no_gap: bool = True,
232
+ normalize_channel_independently: bool = False,
233
+ treat_channel_as_feature: bool = True,
234
+ n_sqm_modules: int = 12,
235
+ emb_dim: int = 128,
236
+ rnn_dim: int = 256,
237
+ bidirectional: bool = True,
238
+ tf_dropout: float = 0.0,
239
+ mlp_dim: int = 512,
240
+ hidden_activation: str = "Tanh",
241
+ hidden_activation_kwargs: Optional[Dict] = None,
242
+ complex_mask: bool = True,
243
+ ) -> None:
244
+ super().__init__()
245
+ self.band_split = BandSplitModule(
246
+ in_channel=in_channel,
247
+ band_specs=band_specs,
248
+ require_no_overlap=require_no_overlap,
249
+ require_no_gap=require_no_gap,
250
+ normalize_channel_independently=normalize_channel_independently,
251
+ treat_channel_as_feature=treat_channel_as_feature,
252
+ emb_dim=emb_dim,
253
+ )
254
+ self.tf_model = TransformerTimeFreqModule(
255
+ n_modules=n_sqm_modules,
256
+ emb_dim=emb_dim,
257
+ rnn_dim=rnn_dim,
258
+ bidirectional=bidirectional,
259
+ dropout=tf_dropout,
260
+ )
261
+ self.mask_estim = MaskEstimationModule(
262
+ in_channel=in_channel,
263
+ band_specs=band_specs,
264
+ emb_dim=emb_dim,
265
+ mlp_dim=mlp_dim,
266
+ hidden_activation=hidden_activation,
267
+ hidden_activation_kwargs=hidden_activation_kwargs,
268
+ complex_mask=complex_mask,
269
+ )
270
+
271
+
272
+ class MultiSourceMultiMaskBandSplitCoreRNN(MultiMaskBandSplitCoreBase):
273
+ def __init__(
274
+ self,
275
+ in_channel: int,
276
+ stems: List[str],
277
+ band_specs: List[Tuple[float, float]],
278
+ require_no_overlap: bool = False,
279
+ require_no_gap: bool = True,
280
+ normalize_channel_independently: bool = False,
281
+ treat_channel_as_feature: bool = True,
282
+ n_sqm_modules: int = 12,
283
+ emb_dim: int = 128,
284
+ rnn_dim: int = 256,
285
+ bidirectional: bool = True,
286
+ rnn_type: str = "LSTM",
287
+ mlp_dim: int = 512,
288
+ cond_dim: int = 0,
289
+ hidden_activation: str = "Tanh",
290
+ hidden_activation_kwargs: Optional[Dict] = None,
291
+ complex_mask: bool = True,
292
+ overlapping_band: bool = False,
293
+ freq_weights: Optional[List[torch.Tensor]] = None,
294
+ n_freq: Optional[int] = None,
295
+ use_freq_weights: bool = True,
296
+ mult_add_mask: bool = False
297
+ ) -> None:
298
+
299
+ super().__init__()
300
+ self.instantiate_bandsplit(
301
+ in_channel=in_channel,
302
+ band_specs=band_specs,
303
+ require_no_overlap=require_no_overlap,
304
+ require_no_gap=require_no_gap,
305
+ normalize_channel_independently=normalize_channel_independently,
306
+ treat_channel_as_feature=treat_channel_as_feature,
307
+ emb_dim=emb_dim
308
+ )
309
+
310
+
311
+ self.tf_model = (
312
+ SeqBandModellingModule(
313
+ n_modules=n_sqm_modules,
314
+ emb_dim=emb_dim,
315
+ rnn_dim=rnn_dim,
316
+ bidirectional=bidirectional,
317
+ rnn_type=rnn_type,
318
+ )
319
+ )
320
+
321
+ self.mult_add_mask = mult_add_mask
322
+
323
+ self.instantiate_mask_estim(
324
+ in_channel=in_channel,
325
+ stems=stems,
326
+ band_specs=band_specs,
327
+ emb_dim=emb_dim,
328
+ mlp_dim=mlp_dim,
329
+ cond_dim=cond_dim,
330
+ hidden_activation=hidden_activation,
331
+ hidden_activation_kwargs=hidden_activation_kwargs,
332
+ complex_mask=complex_mask,
333
+ overlapping_band=overlapping_band,
334
+ freq_weights=freq_weights,
335
+ n_freq=n_freq,
336
+ use_freq_weights=use_freq_weights,
337
+ mult_add_mask=mult_add_mask
338
+ )
339
+
340
+ @staticmethod
341
+ def _mult_add_mask(x, m):
342
+
343
+ assert m.ndim == 5
344
+
345
+ mm = m[..., 0]
346
+ am = m[..., 1]
347
+
348
+ # print(mm.shape, am.shape, x.shape, m.shape)
349
+
350
+ return x * mm + am
351
+
352
+ def mask(self, x, m):
353
+ if self.mult_add_mask:
354
+
355
+ return self._mult_add_mask(x, m)
356
+ else:
357
+ return super().mask(x, m)
358
+
359
+
360
+ class MultiSourceMultiMaskBandSplitCoreTransformer(
361
+ MultiMaskBandSplitCoreBase,
362
+ ):
363
+ def __init__(
364
+ self,
365
+ in_channel: int,
366
+ stems: List[str],
367
+ band_specs: List[Tuple[float, float]],
368
+ require_no_overlap: bool = False,
369
+ require_no_gap: bool = True,
370
+ normalize_channel_independently: bool = False,
371
+ treat_channel_as_feature: bool = True,
372
+ n_sqm_modules: int = 12,
373
+ emb_dim: int = 128,
374
+ rnn_dim: int = 256,
375
+ bidirectional: bool = True,
376
+ tf_dropout: float = 0.0,
377
+ mlp_dim: int = 512,
378
+ hidden_activation: str = "Tanh",
379
+ hidden_activation_kwargs: Optional[Dict] = None,
380
+ complex_mask: bool = True,
381
+ overlapping_band: bool = False,
382
+ freq_weights: Optional[List[torch.Tensor]] = None,
383
+ n_freq: Optional[int] = None,
384
+ use_freq_weights:bool=True,
385
+ rnn_type: str = "LSTM",
386
+ cond_dim: int = 0,
387
+ mult_add_mask: bool = False
388
+ ) -> None:
389
+ super().__init__()
390
+ self.instantiate_bandsplit(
391
+ in_channel=in_channel,
392
+ band_specs=band_specs,
393
+ require_no_overlap=require_no_overlap,
394
+ require_no_gap=require_no_gap,
395
+ normalize_channel_independently=normalize_channel_independently,
396
+ treat_channel_as_feature=treat_channel_as_feature,
397
+ emb_dim=emb_dim
398
+ )
399
+ self.tf_model = TransformerTimeFreqModule(
400
+ n_modules=n_sqm_modules,
401
+ emb_dim=emb_dim,
402
+ rnn_dim=rnn_dim,
403
+ bidirectional=bidirectional,
404
+ dropout=tf_dropout,
405
+ )
406
+
407
+ self.instantiate_mask_estim(
408
+ in_channel=in_channel,
409
+ stems=stems,
410
+ band_specs=band_specs,
411
+ emb_dim=emb_dim,
412
+ mlp_dim=mlp_dim,
413
+ cond_dim=cond_dim,
414
+ hidden_activation=hidden_activation,
415
+ hidden_activation_kwargs=hidden_activation_kwargs,
416
+ complex_mask=complex_mask,
417
+ overlapping_band=overlapping_band,
418
+ freq_weights=freq_weights,
419
+ n_freq=n_freq,
420
+ use_freq_weights=use_freq_weights,
421
+ mult_add_mask=mult_add_mask
422
+ )
423
+
424
+
425
+
426
+ class MultiSourceMultiMaskBandSplitCoreConv(
427
+ MultiMaskBandSplitCoreBase,
428
+ ):
429
+ def __init__(
430
+ self,
431
+ in_channel: int,
432
+ stems: List[str],
433
+ band_specs: List[Tuple[float, float]],
434
+ require_no_overlap: bool = False,
435
+ require_no_gap: bool = True,
436
+ normalize_channel_independently: bool = False,
437
+ treat_channel_as_feature: bool = True,
438
+ n_sqm_modules: int = 12,
439
+ emb_dim: int = 128,
440
+ rnn_dim: int = 256,
441
+ bidirectional: bool = True,
442
+ tf_dropout: float = 0.0,
443
+ mlp_dim: int = 512,
444
+ hidden_activation: str = "Tanh",
445
+ hidden_activation_kwargs: Optional[Dict] = None,
446
+ complex_mask: bool = True,
447
+ overlapping_band: bool = False,
448
+ freq_weights: Optional[List[torch.Tensor]] = None,
449
+ n_freq: Optional[int] = None,
450
+ use_freq_weights:bool=True,
451
+ rnn_type: str = "LSTM",
452
+ cond_dim: int = 0,
453
+ mult_add_mask: bool = False
454
+ ) -> None:
455
+ super().__init__()
456
+ self.instantiate_bandsplit(
457
+ in_channel=in_channel,
458
+ band_specs=band_specs,
459
+ require_no_overlap=require_no_overlap,
460
+ require_no_gap=require_no_gap,
461
+ normalize_channel_independently=normalize_channel_independently,
462
+ treat_channel_as_feature=treat_channel_as_feature,
463
+ emb_dim=emb_dim
464
+ )
465
+ self.tf_model = ConvolutionalTimeFreqModule(
466
+ n_modules=n_sqm_modules,
467
+ emb_dim=emb_dim,
468
+ rnn_dim=rnn_dim,
469
+ bidirectional=bidirectional,
470
+ dropout=tf_dropout,
471
+ )
472
+
473
+ self.instantiate_mask_estim(
474
+ in_channel=in_channel,
475
+ stems=stems,
476
+ band_specs=band_specs,
477
+ emb_dim=emb_dim,
478
+ mlp_dim=mlp_dim,
479
+ cond_dim=cond_dim,
480
+ hidden_activation=hidden_activation,
481
+ hidden_activation_kwargs=hidden_activation_kwargs,
482
+ complex_mask=complex_mask,
483
+ overlapping_band=overlapping_band,
484
+ freq_weights=freq_weights,
485
+ n_freq=n_freq,
486
+ use_freq_weights=use_freq_weights,
487
+ mult_add_mask=mult_add_mask
488
+ )
489
+
490
+
491
+ class PatchingMaskBandsplitCoreBase(MultiMaskBandSplitCoreBase):
492
+ def __init__(self) -> None:
493
+ super().__init__()
494
+
495
+ def mask(self, x, m):
496
+ # x.shape = (batch, n_channel, n_freq, n_time)
497
+ # m.shape = (batch, n_channel, kernel_freq, kernel_time, n_freq, n_time)
498
+
499
+ _, n_channel, kernel_freq, kernel_time, n_freq, n_time = m.shape
500
+ padding = ((kernel_freq - 1) // 2, (kernel_time - 1) // 2)
501
+
502
+ xf = F.unfold(
503
+ x,
504
+ kernel_size=(kernel_freq, kernel_time),
505
+ padding=padding,
506
+ stride=(1, 1),
507
+ )
508
+
509
+ xf = xf.view(
510
+ -1,
511
+ n_channel,
512
+ kernel_freq,
513
+ kernel_time,
514
+ n_freq,
515
+ n_time,
516
+ )
517
+
518
+ sf = xf * m
519
+
520
+ sf = sf.view(
521
+ -1,
522
+ n_channel * kernel_freq * kernel_time,
523
+ n_freq * n_time,
524
+ )
525
+
526
+ s = F.fold(
527
+ sf,
528
+ output_size=(n_freq, n_time),
529
+ kernel_size=(kernel_freq, kernel_time),
530
+ padding=padding,
531
+ stride=(1, 1),
532
+ ).view(
533
+ -1,
534
+ n_channel,
535
+ n_freq,
536
+ n_time,
537
+ )
538
+
539
+ return s
540
+
541
+ def old_mask(self, x, m):
542
+ # x.shape = (batch, n_channel, n_freq, n_time)
543
+ # m.shape = (kernel_freq, kernel_time, batch, n_channel, n_freq, n_time)
544
+
545
+ s = torch.zeros_like(x)
546
+
547
+ _, n_channel, n_freq, n_time = x.shape
548
+ kernel_freq, kernel_time, _, _, _, _ = m.shape
549
+
550
+ # print(x.shape, m.shape)
551
+
552
+ kernel_freq_half = (kernel_freq - 1) // 2
553
+ kernel_time_half = (kernel_time - 1) // 2
554
+
555
+ for ifreq in range(kernel_freq):
556
+ for itime in range(kernel_time):
557
+ df, dt = kernel_freq_half - ifreq, kernel_time_half - itime
558
+ x = x.roll(shifts=(df, dt), dims=(2, 3))
559
+
560
+ # if `df` > 0:
561
+ # x[:, :, :df, :] = 0
562
+ # elif `df` < 0:
563
+ # x[:, :, df:, :] = 0
564
+
565
+ # if `dt` > 0:
566
+ # x[:, :, :, :dt] = 0
567
+ # elif `dt` < 0:
568
+ # x[:, :, :, dt:] = 0
569
+
570
+ fslice = slice(max(0, df), min(n_freq, n_freq + df))
571
+ tslice = slice(max(0, dt), min(n_time, n_time + dt))
572
+
573
+ s[:, :, fslice, tslice] += x[:, :, fslice, tslice] * m[ifreq,
574
+ itime, :,
575
+ :, fslice,
576
+ tslice]
577
+
578
+ return s
579
+
580
+
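# Standalone sketch (not part of the commit) of the unfold/fold bookkeeping used by
# PatchingMaskBandsplitCoreBase.mask above, with small hypothetical sizes: every
# time-frequency bin gets a (kernel_freq x kernel_time) patch mask, and F.fold
# scatter-adds the masked patches back onto the (n_freq, n_time) grid.
import torch
import torch.nn.functional as F

batch, n_channel, n_freq, n_time = 1, 2, 16, 10
kernel_freq, kernel_time = 3, 3
padding = ((kernel_freq - 1) // 2, (kernel_time - 1) // 2)

x = torch.randn(batch, n_channel, n_freq, n_time)
m = torch.randn(batch, n_channel, kernel_freq, kernel_time, n_freq, n_time)

# gather a patch around every bin, apply the per-patch mask, then fold back
xf = F.unfold(x, kernel_size=(kernel_freq, kernel_time), padding=padding)
xf = xf.view(batch, n_channel, kernel_freq, kernel_time, n_freq, n_time)
sf = (xf * m).view(batch, n_channel * kernel_freq * kernel_time, n_freq * n_time)
s = F.fold(sf, output_size=(n_freq, n_time),
           kernel_size=(kernel_freq, kernel_time), padding=padding)
print(s.shape)   # torch.Size([1, 2, 16, 10])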
581
+ class MultiSourceMultiPatchingMaskBandSplitCoreRNN(
582
+ PatchingMaskBandsplitCoreBase
583
+ ):
584
+ def __init__(
585
+ self,
586
+ in_channel: int,
587
+ stems: List[str],
588
+ band_specs: List[Tuple[float, float]],
589
+ mask_kernel_freq: int,
590
+ mask_kernel_time: int,
591
+ conv_kernel_freq: int,
592
+ conv_kernel_time: int,
593
+ kernel_norm_mlp_version: int,
594
+ require_no_overlap: bool = False,
595
+ require_no_gap: bool = True,
596
+ normalize_channel_independently: bool = False,
597
+ treat_channel_as_feature: bool = True,
598
+ n_sqm_modules: int = 12,
599
+ emb_dim: int = 128,
600
+ rnn_dim: int = 256,
601
+ bidirectional: bool = True,
602
+ rnn_type: str = "LSTM",
603
+ mlp_dim: int = 512,
604
+ hidden_activation: str = "Tanh",
605
+ hidden_activation_kwargs: Optional[Dict] = None,
606
+ complex_mask: bool = True,
607
+ overlapping_band: bool = False,
608
+ freq_weights: Optional[List[torch.Tensor]] = None,
609
+ n_freq: Optional[int] = None,
610
+ ) -> None:
611
+
612
+ super().__init__()
613
+ self.band_split = BandSplitModule(
614
+ in_channel=in_channel,
615
+ band_specs=band_specs,
616
+ require_no_overlap=require_no_overlap,
617
+ require_no_gap=require_no_gap,
618
+ normalize_channel_independently=normalize_channel_independently,
619
+ treat_channel_as_feature=treat_channel_as_feature,
620
+ emb_dim=emb_dim,
621
+ )
622
+
623
+ self.tf_model = (
624
+ SeqBandModellingModule(
625
+ n_modules=n_sqm_modules,
626
+ emb_dim=emb_dim,
627
+ rnn_dim=rnn_dim,
628
+ bidirectional=bidirectional,
629
+ rnn_type=rnn_type,
630
+ )
631
+ )
632
+
633
+ if hidden_activation_kwargs is None:
634
+ hidden_activation_kwargs = {}
635
+
636
+ if overlapping_band:
637
+ assert freq_weights is not None
638
+ assert n_freq is not None
639
+ self.mask_estim = nn.ModuleDict(
640
+ {
641
+ stem: PatchingMaskEstimationModule(
642
+ band_specs=band_specs,
643
+ freq_weights=freq_weights,
644
+ n_freq=n_freq,
645
+ emb_dim=emb_dim,
646
+ mlp_dim=mlp_dim,
647
+ in_channel=in_channel,
648
+ hidden_activation=hidden_activation,
649
+ hidden_activation_kwargs=hidden_activation_kwargs,
650
+ complex_mask=complex_mask,
651
+ mask_kernel_freq=mask_kernel_freq,
652
+ mask_kernel_time=mask_kernel_time,
653
+ conv_kernel_freq=conv_kernel_freq,
654
+ conv_kernel_time=conv_kernel_time,
655
+ kernel_norm_mlp_version=kernel_norm_mlp_version
656
+ )
657
+ for stem in stems
658
+ }
659
+ )
660
+ else:
661
+ raise NotImplementedError
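A minimal usage sketch for the RNN core above (not part of the commit). It assumes the base-class forward accepts a complex spectrogram `x` of shape (batch, in_chan, n_freq, n_time) with an optional `cond`, as the method body at the top of this file suggests; all sizes below are illustrative only.

import torch
from models.bandit.core.model.bsrnn.core import MultiSourceMultiMaskBandSplitCoreRNN
from models.bandit.core.model.bsrnn.utils import VocalBandsplitSpecification

n_fft, fs = 2048, 44100
band_specs = VocalBandsplitSpecification(nfft=n_fft, fs=fs).get_band_specs()

core = MultiSourceMultiMaskBandSplitCoreRNN(
    in_channel=2,
    stems=["speech", "music", "effects"],
    band_specs=band_specs,
    n_sqm_modules=2,   # kept small for the sketch; the default is 12
    emb_dim=32,
    rnn_dim=64,
    mlp_dim=128,
)

# complex spectrogram: (batch, in_chan, n_freq, n_time), n_freq = n_fft // 2 + 1
x = torch.randn(1, 2, n_fft // 2 + 1, 100, dtype=torch.complex64)
out = core(x)["spectrogram"]          # per-stem masked spectrograms
print({stem: s.shape for stem, s in out.items()})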
models/bandit/core/model/bsrnn/maskestim.py ADDED
@@ -0,0 +1,347 @@
1
+ import warnings
2
+ from typing import Dict, List, Optional, Tuple, Type
3
+
4
+ import torch
5
+ from torch import nn
6
+ from torch.nn.modules import activation
7
+
8
+ from models.bandit.core.model.bsrnn.utils import (
9
+ band_widths_from_specs,
10
+ check_no_gap,
11
+ check_no_overlap,
12
+ check_nonzero_bandwidth,
13
+ )
14
+
15
+
16
+ class BaseNormMLP(nn.Module):
17
+ def __init__(
18
+ self,
19
+ emb_dim: int,
20
+ mlp_dim: int,
21
+ bandwidth: int,
22
+ in_channel: Optional[int],
23
+ hidden_activation: str = "Tanh",
24
+ hidden_activation_kwargs=None,
25
+ complex_mask: bool = True, ):
26
+
27
+ super().__init__()
28
+ if hidden_activation_kwargs is None:
29
+ hidden_activation_kwargs = {}
30
+ self.hidden_activation_kwargs = hidden_activation_kwargs
31
+ self.norm = nn.LayerNorm(emb_dim)
32
+ self.hidden = torch.jit.script(nn.Sequential(
33
+ nn.Linear(in_features=emb_dim, out_features=mlp_dim),
34
+ activation.__dict__[hidden_activation](
35
+ **self.hidden_activation_kwargs
36
+ ),
37
+ ))
38
+
39
+ self.bandwidth = bandwidth
40
+ self.in_channel = in_channel
41
+
42
+ self.complex_mask = complex_mask
43
+ self.reim = 2 if complex_mask else 1
44
+ self.glu_mult = 2
45
+
46
+
47
+ class NormMLP(BaseNormMLP):
48
+ def __init__(
49
+ self,
50
+ emb_dim: int,
51
+ mlp_dim: int,
52
+ bandwidth: int,
53
+ in_channel: Optional[int],
54
+ hidden_activation: str = "Tanh",
55
+ hidden_activation_kwargs=None,
56
+ complex_mask: bool = True,
57
+ ) -> None:
58
+ super().__init__(
59
+ emb_dim=emb_dim,
60
+ mlp_dim=mlp_dim,
61
+ bandwidth=bandwidth,
62
+ in_channel=in_channel,
63
+ hidden_activation=hidden_activation,
64
+ hidden_activation_kwargs=hidden_activation_kwargs,
65
+ complex_mask=complex_mask,
66
+ )
67
+
68
+ self.output = torch.jit.script(
69
+ nn.Sequential(
70
+ nn.Linear(
71
+ in_features=mlp_dim,
72
+ out_features=bandwidth * in_channel * self.reim * 2,
73
+ ),
74
+ nn.GLU(dim=-1),
75
+ )
76
+ )
77
+
78
+ def reshape_output(self, mb):
79
+ # print(mb.shape)
80
+ batch, n_time, _ = mb.shape
81
+ if self.complex_mask:
82
+ mb = mb.reshape(
83
+ batch,
84
+ n_time,
85
+ self.in_channel,
86
+ self.bandwidth,
87
+ self.reim
88
+ ).contiguous()
89
+ # print(mb.shape)
90
+ mb = torch.view_as_complex(
91
+ mb
92
+ ) # (batch, n_time, in_channel, bandwidth)
93
+ else:
94
+ mb = mb.reshape(batch, n_time, self.in_channel, self.bandwidth)
95
+
96
+ mb = torch.permute(
97
+ mb,
98
+ (0, 2, 3, 1)
99
+ ) # (batch, in_channel, bandwidth, n_time)
100
+
101
+ return mb
102
+
103
+ def forward(self, qb):
104
+ # qb = (batch, n_time, emb_dim)
105
+
106
+ # if torch.any(torch.isnan(qb)):
107
+ # raise ValueError("qb0")
108
+
109
+
110
+ qb = self.norm(qb) # (batch, n_time, emb_dim)
111
+
112
+ # if torch.any(torch.isnan(qb)):
113
+ # raise ValueError("qb1")
114
+
115
+ qb = self.hidden(qb) # (batch, n_time, mlp_dim)
116
+ # if torch.any(torch.isnan(qb)):
117
+ # raise ValueError("qb2")
118
+ mb = self.output(qb) # (batch, n_time, bandwidth * in_channel * reim)
119
+ # if torch.any(torch.isnan(qb)):
120
+ # raise ValueError("mb")
121
+ mb = self.reshape_output(mb) # (batch, in_channel, bandwidth, n_time)
122
+
123
+ return mb
124
+
125
+
126
+ class MultAddNormMLP(NormMLP):
127
+ def __init__(self, emb_dim: int, mlp_dim: int, bandwidth: int, in_channel: "int | None", hidden_activation: str = "Tanh", hidden_activation_kwargs=None, complex_mask: bool = True) -> None:
128
+ super().__init__(emb_dim, mlp_dim, bandwidth, in_channel, hidden_activation, hidden_activation_kwargs, complex_mask)
129
+
130
+ self.output2 = torch.jit.script(
131
+ nn.Sequential(
132
+ nn.Linear(
133
+ in_features=mlp_dim,
134
+ out_features=bandwidth * in_channel * self.reim * 2,
135
+ ),
136
+ nn.GLU(dim=-1),
137
+ )
138
+ )
139
+
140
+ def forward(self, qb):
141
+
142
+ qb = self.norm(qb) # (batch, n_time, emb_dim)
143
+ qb = self.hidden(qb) # (batch, n_time, mlp_dim)
144
+ mmb = self.output(qb) # (batch, n_time, bandwidth * in_channel * reim)
145
+ mmb = self.reshape_output(mmb) # (batch, in_channel, bandwidth, n_time)
146
+ amb = self.output2(qb) # (batch, n_time, bandwidth * in_channel * reim)
147
+ amb = self.reshape_output(amb) # (batch, in_channel, bandwidth, n_time)
148
+
149
+ return mmb, amb
150
+
151
+
152
+ class MaskEstimationModuleSuperBase(nn.Module):
153
+ pass
154
+
155
+
156
+ class MaskEstimationModuleBase(MaskEstimationModuleSuperBase):
157
+ def __init__(
158
+ self,
159
+ band_specs: List[Tuple[float, float]],
160
+ emb_dim: int,
161
+ mlp_dim: int,
162
+ in_channel: Optional[int],
163
+ hidden_activation: str = "Tanh",
164
+ hidden_activation_kwargs: Dict = None,
165
+ complex_mask: bool = True,
166
+ norm_mlp_cls: Type[nn.Module] = NormMLP,
167
+ norm_mlp_kwargs: Dict = None,
168
+ ) -> None:
169
+ super().__init__()
170
+
171
+ self.band_widths = band_widths_from_specs(band_specs)
172
+ self.n_bands = len(band_specs)
173
+
174
+ if hidden_activation_kwargs is None:
175
+ hidden_activation_kwargs = {}
176
+
177
+ if norm_mlp_kwargs is None:
178
+ norm_mlp_kwargs = {}
179
+
180
+ self.norm_mlp = nn.ModuleList(
181
+ [
182
+ (
183
+ norm_mlp_cls(
184
+ bandwidth=self.band_widths[b],
185
+ emb_dim=emb_dim,
186
+ mlp_dim=mlp_dim,
187
+ in_channel=in_channel,
188
+ hidden_activation=hidden_activation,
189
+ hidden_activation_kwargs=hidden_activation_kwargs,
190
+ complex_mask=complex_mask,
191
+ **norm_mlp_kwargs,
192
+ )
193
+ )
194
+ for b in range(self.n_bands)
195
+ ]
196
+ )
197
+
198
+ def compute_masks(self, q):
199
+ batch, n_bands, n_time, emb_dim = q.shape
200
+
201
+ masks = []
202
+
203
+ for b, nmlp in enumerate(self.norm_mlp):
204
+ # print(f"maskestim/{b:02d}")
205
+ qb = q[:, b, :, :]
206
+ mb = nmlp(qb)
207
+ masks.append(mb)
208
+
209
+ return masks
210
+
211
+
212
+
213
+ class OverlappingMaskEstimationModule(MaskEstimationModuleBase):
214
+ def __init__(
215
+ self,
216
+ in_channel: int,
217
+ band_specs: List[Tuple[float, float]],
218
+ freq_weights: List[torch.Tensor],
219
+ n_freq: int,
220
+ emb_dim: int,
221
+ mlp_dim: int,
222
+ cond_dim: int = 0,
223
+ hidden_activation: str = "Tanh",
224
+ hidden_activation_kwargs: Dict = None,
225
+ complex_mask: bool = True,
226
+ norm_mlp_cls: Type[nn.Module] = NormMLP,
227
+ norm_mlp_kwargs: Dict = None,
228
+ use_freq_weights: bool = True,
229
+ ) -> None:
230
+ check_nonzero_bandwidth(band_specs)
231
+ check_no_gap(band_specs)
232
+
233
+ # if cond_dim > 0:
234
+ # raise NotImplementedError
235
+
236
+ super().__init__(
237
+ band_specs=band_specs,
238
+ emb_dim=emb_dim + cond_dim,
239
+ mlp_dim=mlp_dim,
240
+ in_channel=in_channel,
241
+ hidden_activation=hidden_activation,
242
+ hidden_activation_kwargs=hidden_activation_kwargs,
243
+ complex_mask=complex_mask,
244
+ norm_mlp_cls=norm_mlp_cls,
245
+ norm_mlp_kwargs=norm_mlp_kwargs,
246
+ )
247
+
248
+ self.n_freq = n_freq
249
+ self.band_specs = band_specs
250
+ self.in_channel = in_channel
251
+
252
+ if freq_weights is not None:
253
+ for i, fw in enumerate(freq_weights):
254
+ self.register_buffer(f"freq_weights/{i}", fw)
255
+
256
+ self.use_freq_weights = use_freq_weights
257
+ else:
258
+ self.use_freq_weights = False
259
+
260
+ self.cond_dim = cond_dim
261
+
262
+ def forward(self, q, cond=None):
263
+ # q = (batch, n_bands, n_time, emb_dim)
264
+
265
+ batch, n_bands, n_time, emb_dim = q.shape
266
+
267
+ if cond is not None:
268
+ # print(cond)
269
+ if cond.ndim == 2:
270
+ cond = cond[:, None, None, :].expand(-1, n_bands, n_time, -1)
271
+ elif cond.ndim == 3:
272
+ assert cond.shape[1] == n_time
273
+ else:
274
+ raise ValueError(f"Invalid cond shape: {cond.shape}")
275
+
276
+ q = torch.cat([q, cond], dim=-1)
277
+ elif self.cond_dim > 0:
278
+ cond = torch.ones(
279
+ (batch, n_bands, n_time, self.cond_dim),
280
+ device=q.device,
281
+ dtype=q.dtype,
282
+ )
283
+ q = torch.cat([q, cond], dim=-1)
284
+ else:
285
+ pass
286
+
287
+ mask_list = self.compute_masks(
288
+ q
289
+ ) # [n_bands * (batch, in_channel, bandwidth, n_time)]
290
+
291
+ masks = torch.zeros(
292
+ (batch, self.in_channel, self.n_freq, n_time),
293
+ device=q.device,
294
+ dtype=mask_list[0].dtype,
295
+ )
296
+
297
+ for im, mask in enumerate(mask_list):
298
+ fstart, fend = self.band_specs[im]
299
+ if self.use_freq_weights:
300
+ fw = self.get_buffer(f"freq_weights/{im}")[:, None]
301
+ mask = mask * fw
302
+ masks[:, :, fstart:fend, :] += mask
303
+
304
+ return masks
305
+
306
+
307
+ class MaskEstimationModule(OverlappingMaskEstimationModule):
308
+ def __init__(
309
+ self,
310
+ band_specs: List[Tuple[float, float]],
311
+ emb_dim: int,
312
+ mlp_dim: int,
313
+ in_channel: Optional[int],
314
+ hidden_activation: str = "Tanh",
315
+ hidden_activation_kwargs: Dict = None,
316
+ complex_mask: bool = True,
317
+ **kwargs,
318
+ ) -> None:
319
+ check_nonzero_bandwidth(band_specs)
320
+ check_no_gap(band_specs)
321
+ check_no_overlap(band_specs)
322
+ super().__init__(
323
+ in_channel=in_channel,
324
+ band_specs=band_specs,
325
+ freq_weights=None,
326
+ n_freq=None,
327
+ emb_dim=emb_dim,
328
+ mlp_dim=mlp_dim,
329
+ hidden_activation=hidden_activation,
330
+ hidden_activation_kwargs=hidden_activation_kwargs,
331
+ complex_mask=complex_mask,
332
+ )
333
+
334
+ def forward(self, q, cond=None):
335
+ # q = (batch, n_bands, n_time, emb_dim)
336
+
337
+ masks = self.compute_masks(
338
+ q
339
+ ) # [n_bands * (batch, in_channel, bandwidth, n_time)]
340
+
341
+ # TODO: currently this requires band specs to have no gap and no overlap
342
+ masks = torch.concat(
343
+ masks,
344
+ dim=2
345
+ ) # (batch, in_channel, n_freq, n_time)
346
+
347
+ return masks
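A small shape-check sketch for the non-overlapping mask estimator above; the three contiguous bands and all sizes are hypothetical.

import torch
from models.bandit.core.model.bsrnn.maskestim import MaskEstimationModule

band_specs = [(0, 4), (4, 10), (10, 16)]      # contiguous, non-overlapping bin ranges

estim = MaskEstimationModule(
    band_specs=band_specs,
    emb_dim=32,
    mlp_dim=64,
    in_channel=2,
)

q = torch.randn(1, len(band_specs), 50, 32)   # (batch, n_bands, n_time, emb_dim)
m = estim(q)                                  # complex mask, bands concatenated on freq
print(m.shape, m.dtype)                       # torch.Size([1, 2, 16, 50]) torch.complex64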
models/bandit/core/model/bsrnn/tfmodel.py ADDED
@@ -0,0 +1,317 @@
1
+ import warnings
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+ from torch.nn.modules import rnn
7
+
8
+ import torch.backends.cuda
9
+
10
+
11
+ class TimeFrequencyModellingModule(nn.Module):
12
+ def __init__(self) -> None:
13
+ super().__init__()
14
+
15
+
16
+ class ResidualRNN(nn.Module):
17
+ def __init__(
18
+ self,
19
+ emb_dim: int,
20
+ rnn_dim: int,
21
+ bidirectional: bool = True,
22
+ rnn_type: str = "LSTM",
23
+ use_batch_trick: bool = True,
24
+ use_layer_norm: bool = True,
25
+ ) -> None:
26
+ # n_group is the size of the 2nd dim
27
+ super().__init__()
28
+
29
+ self.use_layer_norm = use_layer_norm
30
+ if use_layer_norm:
31
+ self.norm = nn.LayerNorm(emb_dim)
32
+ else:
33
+ self.norm = nn.GroupNorm(num_groups=emb_dim, num_channels=emb_dim)
34
+
35
+ self.rnn = rnn.__dict__[rnn_type](
36
+ input_size=emb_dim,
37
+ hidden_size=rnn_dim,
38
+ num_layers=1,
39
+ batch_first=True,
40
+ bidirectional=bidirectional,
41
+ )
42
+
43
+ self.fc = nn.Linear(
44
+ in_features=rnn_dim * (2 if bidirectional else 1),
45
+ out_features=emb_dim
46
+ )
47
+
48
+ self.use_batch_trick = use_batch_trick
49
+ if not self.use_batch_trick:
50
+ warnings.warn("NOT USING BATCH TRICK IS EXTREMELY SLOW!!")
51
+
52
+ def forward(self, z):
53
+ # z = (batch, n_uncrossed, n_across, emb_dim)
54
+
55
+ z0 = torch.clone(z)
56
+
57
+ # print(z.device)
58
+
59
+ if self.use_layer_norm:
60
+ z = self.norm(z) # (batch, n_uncrossed, n_across, emb_dim)
61
+ else:
62
+ z = torch.permute(
63
+ z, (0, 3, 1, 2)
64
+ ) # (batch, emb_dim, n_uncrossed, n_across)
65
+
66
+ z = self.norm(z) # (batch, emb_dim, n_uncrossed, n_across)
67
+
68
+ z = torch.permute(
69
+ z, (0, 2, 3, 1)
70
+ ) # (batch, n_uncrossed, n_across, emb_dim)
71
+
72
+ batch, n_uncrossed, n_across, emb_dim = z.shape
73
+
74
+ if self.use_batch_trick:
75
+ z = torch.reshape(z, (batch * n_uncrossed, n_across, emb_dim))
76
+
77
+ z = self.rnn(z.contiguous())[0] # (batch * n_uncrossed, n_across, dir_rnn_dim)
78
+
79
+ z = torch.reshape(z, (batch, n_uncrossed, n_across, -1))
80
+ # (batch, n_uncrossed, n_across, dir_rnn_dim)
81
+ else:
82
+ # Note: this is EXTREMELY SLOW
83
+ zlist = []
84
+ for i in range(n_uncrossed):
85
+ zi = self.rnn(z[:, i, :, :])[0] # (batch, n_across, emb_dim)
86
+ zlist.append(zi)
87
+
88
+ z = torch.stack(
89
+ zlist,
90
+ dim=1
91
+ ) # (batch, n_uncrossed, n_across, dir_rnn_dim)
92
+
93
+ z = self.fc(z) # (batch, n_uncrossed, n_across, emb_dim)
94
+
95
+ z = z + z0
96
+
97
+ return z
98
+
99
+
100
+ class SeqBandModellingModule(TimeFrequencyModellingModule):
101
+ def __init__(
102
+ self,
103
+ n_modules: int = 12,
104
+ emb_dim: int = 128,
105
+ rnn_dim: int = 256,
106
+ bidirectional: bool = True,
107
+ rnn_type: str = "LSTM",
108
+ parallel_mode=False,
109
+ ) -> None:
110
+ super().__init__()
111
+ self.seqband = nn.ModuleList([])
112
+
113
+ if parallel_mode:
114
+ for _ in range(n_modules):
115
+ self.seqband.append(
116
+ nn.ModuleList(
117
+ [ResidualRNN(
118
+ emb_dim=emb_dim,
119
+ rnn_dim=rnn_dim,
120
+ bidirectional=bidirectional,
121
+ rnn_type=rnn_type,
122
+ ),
123
+ ResidualRNN(
124
+ emb_dim=emb_dim,
125
+ rnn_dim=rnn_dim,
126
+ bidirectional=bidirectional,
127
+ rnn_type=rnn_type,
128
+ )]
129
+ )
130
+ )
131
+ else:
132
+
133
+ for _ in range(2 * n_modules):
134
+ self.seqband.append(
135
+ ResidualRNN(
136
+ emb_dim=emb_dim,
137
+ rnn_dim=rnn_dim,
138
+ bidirectional=bidirectional,
139
+ rnn_type=rnn_type,
140
+ )
141
+ )
142
+
143
+ self.parallel_mode = parallel_mode
144
+
145
+ def forward(self, z):
146
+ # z = (batch, n_bands, n_time, emb_dim)
147
+
148
+ if self.parallel_mode:
149
+ for sbm_pair in self.seqband:
150
+ # z: (batch, n_bands, n_time, emb_dim)
151
+ sbm_t, sbm_f = sbm_pair[0], sbm_pair[1]
152
+ zt = sbm_t(z) # (batch, n_bands, n_time, emb_dim)
153
+ zf = sbm_f(z.transpose(1, 2)) # (batch, n_time, n_bands, emb_dim)
154
+ z = zt + zf.transpose(1, 2)
155
+ else:
156
+ for sbm in self.seqband:
157
+ z = sbm(z)
158
+ z = z.transpose(1, 2)
159
+
160
+ # (batch, n_bands, n_time, emb_dim)
161
+ # --> (batch, n_time, n_bands, emb_dim)
162
+ # OR
163
+ # (batch, n_time, n_bands, emb_dim)
164
+ # --> (batch, n_bands, n_time, emb_dim)
165
+
166
+ q = z
167
+ return q # (batch, n_bands, n_time, emb_dim)
168
+
169
+
170
+ class ResidualTransformer(nn.Module):
171
+ def __init__(
172
+ self,
173
+ emb_dim: int = 128,
174
+ rnn_dim: int = 256,
175
+ bidirectional: bool = True,
176
+ dropout: float = 0.0,
177
+ ) -> None:
178
+ # n_group is the size of the 2nd dim
179
+ super().__init__()
180
+
181
+ self.tf = nn.TransformerEncoderLayer(
182
+ d_model=emb_dim,
183
+ nhead=4,
184
+ dim_feedforward=rnn_dim,
185
+ batch_first=True
186
+ )
187
+
188
+ self.is_causal = not bidirectional
189
+ self.dropout = dropout
190
+
191
+ def forward(self, z):
192
+ batch, n_uncrossed, n_across, emb_dim = z.shape
193
+ z = torch.reshape(z, (batch * n_uncrossed, n_across, emb_dim))
194
+ z = self.tf(z, is_causal=self.is_causal)  # (batch * n_uncrossed, n_across, emb_dim)
195
+ z = torch.reshape(z, (batch, n_uncrossed, n_across, emb_dim))
196
+
197
+ return z
198
+
199
+
200
+ class TransformerTimeFreqModule(TimeFrequencyModellingModule):
201
+ def __init__(
202
+ self,
203
+ n_modules: int = 12,
204
+ emb_dim: int = 128,
205
+ rnn_dim: int = 256,
206
+ bidirectional: bool = True,
207
+ dropout: float = 0.0,
208
+ ) -> None:
209
+ super().__init__()
210
+ self.norm = nn.LayerNorm(emb_dim)
211
+ self.seqband = nn.ModuleList([])
212
+
213
+ for _ in range(2 * n_modules):
214
+ self.seqband.append(
215
+ ResidualTransformer(
216
+ emb_dim=emb_dim,
217
+ rnn_dim=rnn_dim,
218
+ bidirectional=bidirectional,
219
+ dropout=dropout,
220
+ )
221
+ )
222
+
223
+ def forward(self, z):
224
+ # z = (batch, n_bands, n_time, emb_dim)
225
+ z = self.norm(z) # (batch, n_bands, n_time, emb_dim)
226
+
227
+ for sbm in self.seqband:
228
+ z = sbm(z)
229
+ z = z.transpose(1, 2)
230
+
231
+ # (batch, n_bands, n_time, emb_dim)
232
+ # --> (batch, n_time, n_bands, emb_dim)
233
+ # OR
234
+ # (batch, n_time, n_bands, emb_dim)
235
+ # --> (batch, n_bands, n_time, emb_dim)
236
+
237
+ q = z
238
+ return q # (batch, n_bands, n_time, emb_dim)
239
+
240
+
241
+
242
+ class ResidualConvolution(nn.Module):
243
+ def __init__(
244
+ self,
245
+ emb_dim: int = 128,
246
+ rnn_dim: int = 256,
247
+ bidirectional: bool = True,
248
+ dropout: float = 0.0,
249
+ ) -> None:
250
+ # n_group is the size of the 2nd dim
251
+ super().__init__()
252
+ self.norm = nn.InstanceNorm2d(emb_dim, affine=True)
253
+
254
+ self.conv = nn.Sequential(
255
+ nn.Conv2d(
256
+ in_channels=emb_dim,
257
+ out_channels=rnn_dim,
258
+ kernel_size=(3, 3),
259
+ padding="same",
260
+ stride=(1, 1),
261
+ ),
262
+ nn.Tanhshrink()
263
+ )
264
+
265
+ self.is_causal = not bidirectional
266
+ self.dropout = dropout
267
+
268
+ self.fc = nn.Conv2d(
269
+ in_channels=rnn_dim,
270
+ out_channels=emb_dim,
271
+ kernel_size=(1, 1),
272
+ padding="same",
273
+ stride=(1, 1),
274
+ )
275
+
276
+
277
+ def forward(self, z):
+     # z = (batch, emb_dim, n_uncrossed, n_across); ConvolutionalTimeFreqModule
+     # permutes to channels-first before calling this block
+
+     z0 = torch.clone(z)
+
+     z = self.norm(z)  # (batch, emb_dim, n_uncrossed, n_across)
+     z = self.conv(z)  # (batch, rnn_dim, n_uncrossed, n_across)
+     z = self.fc(z)  # (batch, emb_dim, n_uncrossed, n_across)
+     z = z + z0
286
+
287
+ return z
288
+
289
+
290
+ class ConvolutionalTimeFreqModule(TimeFrequencyModellingModule):
291
+ def __init__(
292
+ self,
293
+ n_modules: int = 12,
294
+ emb_dim: int = 128,
295
+ rnn_dim: int = 256,
296
+ bidirectional: bool = True,
297
+ dropout: float = 0.0,
298
+ ) -> None:
299
+ super().__init__()
300
+ self.seqband = torch.jit.script(nn.Sequential(
301
+ *[ResidualConvolution(
302
+ emb_dim=emb_dim,
303
+ rnn_dim=rnn_dim,
304
+ bidirectional=bidirectional,
305
+ dropout=dropout,
306
+ ) for _ in range(2 * n_modules) ]))
307
+
308
+ def forward(self, z):
309
+ # z = (batch, n_bands, n_time, emb_dim)
310
+
311
+ z = torch.permute(z, (0, 3, 1, 2)) # (batch, emb_dim, n_bands, n_time)
312
+
313
+ z = self.seqband(z) # (batch, emb_dim, n_bands, n_time)
314
+
315
+ z = torch.permute(z, (0, 2, 3, 1)) # (batch, n_bands, n_time, emb_dim)
316
+
317
+ return z
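A quick sketch of the sequential band-modelling stack above: each pair of `ResidualRNN` blocks runs an RNN across time and then across bands, and the transposes cancel so the output keeps the input layout (sizes hypothetical).

import torch
from models.bandit.core.model.bsrnn.tfmodel import SeqBandModellingModule

tf_model = SeqBandModellingModule(n_modules=2, emb_dim=32, rnn_dim=64)

z = torch.randn(1, 8, 50, 32)   # (batch, n_bands, n_time, emb_dim)
q = tf_model(z)
print(q.shape)                  # torch.Size([1, 8, 50, 32])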
models/bandit/core/model/bsrnn/utils.py ADDED
@@ -0,0 +1,583 @@
1
+ import os
2
+ from abc import abstractmethod
3
+ from typing import Any, Callable
4
+
5
+ import numpy as np
6
+ import torch
7
+ from librosa import hz_to_midi, midi_to_hz
8
+ from torch import Tensor
9
+ from torchaudio import functional as taF
10
+ from spafe.fbanks import bark_fbanks
11
+ from spafe.utils.converters import erb2hz, hz2bark, hz2erb
12
+ from torchaudio.functional.functional import _create_triangular_filterbank
13
+
14
+
15
+ def band_widths_from_specs(band_specs):
16
+ return [e - i for i, e in band_specs]
17
+
18
+
19
+ def check_nonzero_bandwidth(band_specs):
20
+ # pprint(band_specs)
21
+ for fstart, fend in band_specs:
22
+ if fend - fstart <= 0:
23
+ raise ValueError("Bands cannot be zero-width")
24
+
25
+
26
+ def check_no_overlap(band_specs):
+     # bands are half-open [fstart, fend) index pairs; sharing an endpoint is not an
+     # overlap (the contiguous specs built below share endpoints by construction)
+     fend_prev = -1
+     for fstart_curr, fend_curr in band_specs:
+         if fstart_curr < fend_prev:
+             raise ValueError("Bands cannot overlap")
+         fend_prev = fend_curr
31
+
32
+
33
+ def check_no_gap(band_specs):
34
+ fstart, _ = band_specs[0]
35
+ assert fstart == 0
36
+
37
+ fend_prev = -1
38
+ for fstart_curr, fend_curr in band_specs:
39
+ if fstart_curr - fend_prev > 1:
40
+ raise ValueError("Bands cannot leave gap")
41
+ fend_prev = fend_curr
42
+
43
+
44
+ class BandsplitSpecification:
45
+ def __init__(self, nfft: int, fs: int) -> None:
46
+ self.fs = fs
47
+ self.nfft = nfft
48
+ self.nyquist = fs / 2
49
+ self.max_index = nfft // 2 + 1
50
+
51
+ self.split500 = self.hertz_to_index(500)
52
+ self.split1k = self.hertz_to_index(1000)
53
+ self.split2k = self.hertz_to_index(2000)
54
+ self.split4k = self.hertz_to_index(4000)
55
+ self.split8k = self.hertz_to_index(8000)
56
+ self.split16k = self.hertz_to_index(16000)
57
+ self.split20k = self.hertz_to_index(20000)
58
+
59
+ self.above20k = [(self.split20k, self.max_index)]
60
+ self.above16k = [(self.split16k, self.split20k)] + self.above20k
61
+
62
+ def index_to_hertz(self, index: int):
63
+ return index * self.fs / self.nfft
64
+
65
+ def hertz_to_index(self, hz: float, round: bool = True):
66
+ index = hz * self.nfft / self.fs
67
+
68
+ if round:
69
+ index = int(np.round(index))
70
+
71
+ return index
72
+
73
+ def get_band_specs_with_bandwidth(
74
+ self,
75
+ start_index,
76
+ end_index,
77
+ bandwidth_hz
78
+ ):
79
+ band_specs = []
80
+ lower = start_index
81
+
82
+ while lower < end_index:
83
+ upper = int(np.floor(lower + self.hertz_to_index(bandwidth_hz)))
84
+ upper = min(upper, end_index)
85
+
86
+ band_specs.append((lower, upper))
87
+ lower = upper
88
+
89
+ return band_specs
90
+
91
+ @abstractmethod
92
+ def get_band_specs(self):
93
+ raise NotImplementedError
94
+
95
+
96
+ class VocalBandsplitSpecification(BandsplitSpecification):
97
+ def __init__(self, nfft: int, fs: int, version: str = "7") -> None:
98
+ super().__init__(nfft=nfft, fs=fs)
99
+
100
+ self.version = version
101
+
102
+ def get_band_specs(self):
103
+ return getattr(self, f"version{self.version}")()
104
+
105
+ def version1(self):
+     # plain method (not a @property) so that get_band_specs() can call
+     # getattr(self, "version1")() the same way as version2 .. version7
+     return self.get_band_specs_with_bandwidth(
+         start_index=0, end_index=self.max_index, bandwidth_hz=1000
+     )
110
+
111
+ def version2(self):
112
+ below16k = self.get_band_specs_with_bandwidth(
113
+ start_index=0, end_index=self.split16k, bandwidth_hz=1000
114
+ )
115
+ below20k = self.get_band_specs_with_bandwidth(
116
+ start_index=self.split16k,
117
+ end_index=self.split20k,
118
+ bandwidth_hz=2000
119
+ )
120
+
121
+ return below16k + below20k + self.above20k
122
+
123
+ def version3(self):
124
+ below8k = self.get_band_specs_with_bandwidth(
125
+ start_index=0, end_index=self.split8k, bandwidth_hz=1000
126
+ )
127
+ below16k = self.get_band_specs_with_bandwidth(
128
+ start_index=self.split8k,
129
+ end_index=self.split16k,
130
+ bandwidth_hz=2000
131
+ )
132
+
133
+ return below8k + below16k + self.above16k
134
+
135
+ def version4(self):
136
+ below1k = self.get_band_specs_with_bandwidth(
137
+ start_index=0, end_index=self.split1k, bandwidth_hz=100
138
+ )
139
+ below8k = self.get_band_specs_with_bandwidth(
140
+ start_index=self.split1k,
141
+ end_index=self.split8k,
142
+ bandwidth_hz=1000
143
+ )
144
+ below16k = self.get_band_specs_with_bandwidth(
145
+ start_index=self.split8k,
146
+ end_index=self.split16k,
147
+ bandwidth_hz=2000
148
+ )
149
+
150
+ return below1k + below8k + below16k + self.above16k
151
+
152
+ def version5(self):
153
+ below1k = self.get_band_specs_with_bandwidth(
154
+ start_index=0, end_index=self.split1k, bandwidth_hz=100
155
+ )
156
+ below16k = self.get_band_specs_with_bandwidth(
157
+ start_index=self.split1k,
158
+ end_index=self.split16k,
159
+ bandwidth_hz=1000
160
+ )
161
+ below20k = self.get_band_specs_with_bandwidth(
162
+ start_index=self.split16k,
163
+ end_index=self.split20k,
164
+ bandwidth_hz=2000
165
+ )
166
+ return below1k + below16k + below20k + self.above20k
167
+
168
+ def version6(self):
169
+ below1k = self.get_band_specs_with_bandwidth(
170
+ start_index=0, end_index=self.split1k, bandwidth_hz=100
171
+ )
172
+ below4k = self.get_band_specs_with_bandwidth(
173
+ start_index=self.split1k,
174
+ end_index=self.split4k,
175
+ bandwidth_hz=500
176
+ )
177
+ below8k = self.get_band_specs_with_bandwidth(
178
+ start_index=self.split4k,
179
+ end_index=self.split8k,
180
+ bandwidth_hz=1000
181
+ )
182
+ below16k = self.get_band_specs_with_bandwidth(
183
+ start_index=self.split8k,
184
+ end_index=self.split16k,
185
+ bandwidth_hz=2000
186
+ )
187
+ return below1k + below4k + below8k + below16k + self.above16k
188
+
189
+ def version7(self):
190
+ below1k = self.get_band_specs_with_bandwidth(
191
+ start_index=0, end_index=self.split1k, bandwidth_hz=100
192
+ )
193
+ below4k = self.get_band_specs_with_bandwidth(
194
+ start_index=self.split1k,
195
+ end_index=self.split4k,
196
+ bandwidth_hz=250
197
+ )
198
+ below8k = self.get_band_specs_with_bandwidth(
199
+ start_index=self.split4k,
200
+ end_index=self.split8k,
201
+ bandwidth_hz=500
202
+ )
203
+ below16k = self.get_band_specs_with_bandwidth(
204
+ start_index=self.split8k,
205
+ end_index=self.split16k,
206
+ bandwidth_hz=1000
207
+ )
208
+ below20k = self.get_band_specs_with_bandwidth(
209
+ start_index=self.split16k,
210
+ end_index=self.split20k,
211
+ bandwidth_hz=2000
212
+ )
213
+ return below1k + below4k + below8k + below16k + below20k + self.above20k
214
+
215
+
216
+ class OtherBandsplitSpecification(VocalBandsplitSpecification):
217
+ def __init__(self, nfft: int, fs: int) -> None:
218
+ super().__init__(nfft=nfft, fs=fs, version="7")
219
+
220
+
221
+ class BassBandsplitSpecification(BandsplitSpecification):
222
+ def __init__(self, nfft: int, fs: int, version: str = "7") -> None:
223
+ super().__init__(nfft=nfft, fs=fs)
224
+
225
+ def get_band_specs(self):
226
+ below500 = self.get_band_specs_with_bandwidth(
227
+ start_index=0, end_index=self.split500, bandwidth_hz=50
228
+ )
229
+ below1k = self.get_band_specs_with_bandwidth(
230
+ start_index=self.split500,
231
+ end_index=self.split1k,
232
+ bandwidth_hz=100
233
+ )
234
+ below4k = self.get_band_specs_with_bandwidth(
235
+ start_index=self.split1k,
236
+ end_index=self.split4k,
237
+ bandwidth_hz=500
238
+ )
239
+ below8k = self.get_band_specs_with_bandwidth(
240
+ start_index=self.split4k,
241
+ end_index=self.split8k,
242
+ bandwidth_hz=1000
243
+ )
244
+ below16k = self.get_band_specs_with_bandwidth(
245
+ start_index=self.split8k,
246
+ end_index=self.split16k,
247
+ bandwidth_hz=2000
248
+ )
249
+ above16k = [(self.split16k, self.max_index)]
250
+
251
+ return below500 + below1k + below4k + below8k + below16k + above16k
252
+
253
+
254
+ class DrumBandsplitSpecification(BandsplitSpecification):
255
+ def __init__(self, nfft: int, fs: int) -> None:
256
+ super().__init__(nfft=nfft, fs=fs)
257
+
258
+ def get_band_specs(self):
259
+ below1k = self.get_band_specs_with_bandwidth(
260
+ start_index=0, end_index=self.split1k, bandwidth_hz=50
261
+ )
262
+ below2k = self.get_band_specs_with_bandwidth(
263
+ start_index=self.split1k,
264
+ end_index=self.split2k,
265
+ bandwidth_hz=100
266
+ )
267
+ below4k = self.get_band_specs_with_bandwidth(
268
+ start_index=self.split2k,
269
+ end_index=self.split4k,
270
+ bandwidth_hz=250
271
+ )
272
+ below8k = self.get_band_specs_with_bandwidth(
273
+ start_index=self.split4k,
274
+ end_index=self.split8k,
275
+ bandwidth_hz=500
276
+ )
277
+ below16k = self.get_band_specs_with_bandwidth(
278
+ start_index=self.split8k,
279
+ end_index=self.split16k,
280
+ bandwidth_hz=1000
281
+ )
282
+ above16k = [(self.split16k, self.max_index)]
283
+
284
+ return below1k + below2k + below4k + below8k + below16k + above16k
285
+
286
+
287
+
288
+
289
+ class PerceptualBandsplitSpecification(BandsplitSpecification):
290
+ def __init__(
291
+ self,
292
+ nfft: int,
293
+ fs: int,
294
+ fbank_fn: Callable[[int, int, float, float, int], torch.Tensor],
295
+ n_bands: int,
296
+ f_min: float = 0.0,
297
+ f_max: float = None
298
+ ) -> None:
299
+ super().__init__(nfft=nfft, fs=fs)
300
+ self.n_bands = n_bands
301
+ if f_max is None:
302
+ f_max = fs / 2
303
+
304
+ self.filterbank = fbank_fn(
305
+ n_bands, fs, f_min, f_max, self.max_index
306
+ )
307
+
308
+ weight_per_bin = torch.sum(
309
+ self.filterbank,
310
+ dim=0,
311
+ keepdim=True
312
+ ) # (1, n_freqs)
313
+ normalized_mel_fb = self.filterbank / weight_per_bin # (n_mels, n_freqs)
314
+
315
+ freq_weights = []
316
+ band_specs = []
317
+ for i in range(self.n_bands):
318
+ active_bins = torch.nonzero(self.filterbank[i, :]).squeeze().tolist()
319
+ if isinstance(active_bins, int):
320
+ active_bins = (active_bins, active_bins)
321
+ if len(active_bins) == 0:
322
+ continue
323
+ start_index = active_bins[0]
324
+ end_index = active_bins[-1] + 1
325
+ band_specs.append((start_index, end_index))
326
+ freq_weights.append(normalized_mel_fb[i, start_index:end_index])
327
+
328
+ self.freq_weights = freq_weights
329
+ self.band_specs = band_specs
330
+
331
+ def get_band_specs(self):
332
+ return self.band_specs
333
+
334
+ def get_freq_weights(self):
335
+ return self.freq_weights
336
+
337
+ def save_to_file(self, dir_path: str) -> None:
338
+
339
+ os.makedirs(dir_path, exist_ok=True)
340
+
341
+ import pickle
342
+
343
+ with open(os.path.join(dir_path, "mel_bandsplit_spec.pkl"), "wb") as f:
344
+ pickle.dump(
345
+ {
346
+ "band_specs": self.band_specs,
347
+ "freq_weights": self.freq_weights,
348
+ "filterbank": self.filterbank,
349
+ },
350
+ f,
351
+ )
352
+
353
+ def mel_filterbank(n_bands, fs, f_min, f_max, n_freqs):
354
+ fb = taF.melscale_fbanks(
355
+ n_mels=n_bands,
356
+ sample_rate=fs,
357
+ f_min=f_min,
358
+ f_max=f_max,
359
+ n_freqs=n_freqs,
360
+ ).T
361
+
362
+ fb[0, 0] = 1.0
363
+
364
+ return fb
365
+
366
+
367
+ class MelBandsplitSpecification(PerceptualBandsplitSpecification):
368
+ def __init__(
369
+ self,
370
+ nfft: int,
371
+ fs: int,
372
+ n_bands: int,
373
+ f_min: float = 0.0,
374
+ f_max: float = None
375
+ ) -> None:
376
+ super().__init__(fbank_fn=mel_filterbank, nfft=nfft, fs=fs, n_bands=n_bands, f_min=f_min, f_max=f_max)
377
+
378
+ def musical_filterbank(n_bands, fs, f_min, f_max, n_freqs,
379
+ scale="constant"):
380
+
381
+ nfft = 2 * (n_freqs - 1)
382
+ df = fs / nfft
383
+ # init freqs
384
+ f_max = f_max or fs / 2
385
+ f_min = f_min or 0
386
+ f_min = fs / nfft
387
+
388
+ n_octaves = np.log2(f_max / f_min)
389
+ n_octaves_per_band = n_octaves / n_bands
390
+ bandwidth_mult = np.power(2.0, n_octaves_per_band)
391
+
392
+ low_midi = max(0, hz_to_midi(f_min))
393
+ high_midi = hz_to_midi(f_max)
394
+ midi_points = np.linspace(low_midi, high_midi, n_bands)
395
+ hz_pts = midi_to_hz(midi_points)
396
+
397
+ low_pts = hz_pts / bandwidth_mult
398
+ high_pts = hz_pts * bandwidth_mult
399
+
400
+ low_bins = np.floor(low_pts / df).astype(int)
401
+ high_bins = np.ceil(high_pts / df).astype(int)
402
+
403
+ fb = np.zeros((n_bands, n_freqs))
404
+
405
+ for i in range(n_bands):
406
+ fb[i, low_bins[i]:high_bins[i]+1] = 1.0
407
+
408
+ fb[0, :low_bins[0]] = 1.0
409
+ fb[-1, high_bins[-1]+1:] = 1.0
410
+
411
+ return torch.as_tensor(fb)
412
+
413
+ class MusicalBandsplitSpecification(PerceptualBandsplitSpecification):
414
+ def __init__(
415
+ self,
416
+ nfft: int,
417
+ fs: int,
418
+ n_bands: int,
419
+ f_min: float = 0.0,
420
+ f_max: float = None
421
+ ) -> None:
422
+ super().__init__(fbank_fn=musical_filterbank, nfft=nfft, fs=fs, n_bands=n_bands, f_min=f_min, f_max=f_max)
423
+
424
+
425
+ def bark_filterbank(
426
+ n_bands, fs, f_min, f_max, n_freqs
427
+ ):
428
+ nfft = 2 * (n_freqs -1)
429
+ fb, _ = bark_fbanks.bark_filter_banks(
430
+ nfilts=n_bands,
431
+ nfft=nfft,
432
+ fs=fs,
433
+ low_freq=f_min,
434
+ high_freq=f_max,
435
+ scale="constant"
436
+ )
437
+
438
+ return torch.as_tensor(fb)
439
+
440
+ class BarkBandsplitSpecification(PerceptualBandsplitSpecification):
441
+ def __init__(
442
+ self,
443
+ nfft: int,
444
+ fs: int,
445
+ n_bands: int,
446
+ f_min: float = 0.0,
447
+ f_max: float = None
448
+ ) -> None:
449
+ super().__init__(fbank_fn=bark_filterbank, nfft=nfft, fs=fs, n_bands=n_bands, f_min=f_min, f_max=f_max)
450
+
451
+
452
+ def triangular_bark_filterbank(
453
+ n_bands, fs, f_min, f_max, n_freqs
454
+ ):
455
+
456
+ all_freqs = torch.linspace(0, fs // 2, n_freqs)
457
+
458
+ # calculate mel freq bins
459
+ m_min = hz2bark(f_min)
460
+ m_max = hz2bark(f_max)
461
+
462
+ m_pts = torch.linspace(m_min, m_max, n_bands + 2)
463
+ f_pts = 600 * torch.sinh(m_pts / 6)
464
+
465
+ # create filterbank
466
+ fb = _create_triangular_filterbank(all_freqs, f_pts)
467
+
468
+ fb = fb.T
469
+
470
+ first_active_band = torch.nonzero(torch.sum(fb, dim=-1))[0, 0]
471
+ first_active_bin = torch.nonzero(fb[first_active_band, :])[0, 0]
472
+
473
+ fb[first_active_band, :first_active_bin] = 1.0
474
+
475
+ return fb
476
+
477
+ class TriangularBarkBandsplitSpecification(PerceptualBandsplitSpecification):
478
+ def __init__(
479
+ self,
480
+ nfft: int,
481
+ fs: int,
482
+ n_bands: int,
483
+ f_min: float = 0.0,
484
+ f_max: float = None
485
+ ) -> None:
486
+ super().__init__(fbank_fn=triangular_bark_filterbank, nfft=nfft, fs=fs, n_bands=n_bands, f_min=f_min, f_max=f_max)
487
+
488
+
489
+
490
+ def minibark_filterbank(
491
+ n_bands, fs, f_min, f_max, n_freqs
492
+ ):
493
+ fb = bark_filterbank(
494
+ n_bands,
495
+ fs,
496
+ f_min,
497
+ f_max,
498
+ n_freqs
499
+ )
500
+
501
+ fb[fb < np.sqrt(0.5)] = 0.0
502
+
503
+ return fb
504
+
505
+ class MiniBarkBandsplitSpecification(PerceptualBandsplitSpecification):
506
+ def __init__(
507
+ self,
508
+ nfft: int,
509
+ fs: int,
510
+ n_bands: int,
511
+ f_min: float = 0.0,
512
+ f_max: float = None
513
+ ) -> None:
514
+ super().__init__(fbank_fn=minibark_filterbank, nfft=nfft, fs=fs, n_bands=n_bands, f_min=f_min, f_max=f_max)
515
+
516
+
517
+
518
+
519
+
520
+ def erb_filterbank(
521
+ n_bands: int,
522
+ fs: int,
523
+ f_min: float,
524
+ f_max: float,
525
+ n_freqs: int,
526
+ ) -> Tensor:
527
+ # freq bins
528
+ A = (1000 * np.log(10)) / (24.7 * 4.37)
529
+ all_freqs = torch.linspace(0, fs // 2, n_freqs)
530
+
531
+ # calculate mel freq bins
532
+ m_min = hz2erb(f_min)
533
+ m_max = hz2erb(f_max)
534
+
535
+ m_pts = torch.linspace(m_min, m_max, n_bands + 2)
536
+ f_pts = (torch.pow(10, (m_pts / A)) - 1)/ 0.00437
537
+
538
+ # create filterbank
539
+ fb = _create_triangular_filterbank(all_freqs, f_pts)
540
+
541
+ fb = fb.T
542
+
543
+
544
+ first_active_band = torch.nonzero(torch.sum(fb, dim=-1))[0, 0]
545
+ first_active_bin = torch.nonzero(fb[first_active_band, :])[0, 0]
546
+
547
+ fb[first_active_band, :first_active_bin] = 1.0
548
+
549
+ return fb
550
+
551
+
552
+
553
+ class EquivalentRectangularBandsplitSpecification(PerceptualBandsplitSpecification):
554
+ def __init__(
555
+ self,
556
+ nfft: int,
557
+ fs: int,
558
+ n_bands: int,
559
+ f_min: float = 0.0,
560
+ f_max: float = None
561
+ ) -> None:
562
+ super().__init__(fbank_fn=erb_filterbank, nfft=nfft, fs=fs, n_bands=n_bands, f_min=f_min, f_max=f_max)
563
+
564
+ if __name__ == "__main__":
565
+ import pandas as pd
566
+
567
+ band_defs = []
568
+
569
+ for bands in [VocalBandsplitSpecification]:
570
+ band_name = bands.__name__.replace("BandsplitSpecification", "")
571
+
572
+ mbs = bands(nfft=2048, fs=44100).get_band_specs()
573
+
574
+ for i, (f_min, f_max) in enumerate(mbs):
575
+ band_defs.append({
576
+ "band": band_name,
577
+ "band_index": i,
578
+ "f_min": f_min,
579
+ "f_max": f_max
580
+ })
581
+
582
+ df = pd.DataFrame(band_defs)
583
+ df.to_csv("vox7bands.csv", index=False)
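A short sketch of how the band-specification helpers above are typically driven; it only uses code visible in this file, with n_fft and fs as example values.

from models.bandit.core.model.bsrnn.utils import (
    VocalBandsplitSpecification,
    check_no_gap,
    check_nonzero_bandwidth,
)

spec = VocalBandsplitSpecification(nfft=2048, fs=44100)   # defaults to "version 7"
band_specs = spec.get_band_specs()   # list of (start_bin, end_bin) STFT index pairs

check_nonzero_bandwidth(band_specs)
check_no_gap(band_specs)             # contiguous cover of bins 0 .. n_fft // 2

print(len(band_specs), band_specs[:3], band_specs[-1])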
models/bandit/core/model/bsrnn/wrapper.py ADDED
@@ -0,0 +1,882 @@
1
+ from pprint import pprint
2
+ from typing import Dict, List, Optional, Tuple, Union
3
+
4
+ import torch
5
+ from torch import nn
6
+
7
+ from models.bandit.core.model._spectral import _SpectralComponent
8
+ from models.bandit.core.model.bsrnn.utils import (
9
+ BarkBandsplitSpecification, BassBandsplitSpecification,
10
+ DrumBandsplitSpecification,
11
+ EquivalentRectangularBandsplitSpecification, MelBandsplitSpecification,
12
+ MusicalBandsplitSpecification, OtherBandsplitSpecification,
13
+ TriangularBarkBandsplitSpecification, VocalBandsplitSpecification,
14
+ )
15
+ from .core import (
16
+ MultiSourceMultiMaskBandSplitCoreConv,
17
+ MultiSourceMultiMaskBandSplitCoreRNN,
18
+ MultiSourceMultiMaskBandSplitCoreTransformer,
19
+ MultiSourceMultiPatchingMaskBandSplitCoreRNN, SingleMaskBandsplitCoreRNN,
20
+ SingleMaskBandsplitCoreTransformer,
21
+ )
22
+
23
+ import pytorch_lightning as pl
24
+
25
+ def get_band_specs(band_specs, n_fft, fs, n_bands=None):
26
+ if band_specs in ["dnr:speech", "dnr:vox7", "musdb:vocals", "musdb:vox7"]:
27
+ bsm = VocalBandsplitSpecification(
28
+ nfft=n_fft, fs=fs
29
+ ).get_band_specs()
30
+ freq_weights = None
31
+ overlapping_band = False
32
+ elif "tribark" in band_specs:
33
+ assert n_bands is not None
34
+ specs = TriangularBarkBandsplitSpecification(
35
+ nfft=n_fft,
36
+ fs=fs,
37
+ n_bands=n_bands
38
+ )
39
+ bsm = specs.get_band_specs()
40
+ freq_weights = specs.get_freq_weights()
41
+ overlapping_band = True
42
+ elif "bark" in band_specs:
43
+ assert n_bands is not None
44
+ specs = BarkBandsplitSpecification(
45
+ nfft=n_fft,
46
+ fs=fs,
47
+ n_bands=n_bands
48
+ )
49
+ bsm = specs.get_band_specs()
50
+ freq_weights = specs.get_freq_weights()
51
+ overlapping_band = True
52
+ elif "erb" in band_specs:
53
+ assert n_bands is not None
54
+ specs = EquivalentRectangularBandsplitSpecification(
55
+ nfft=n_fft,
56
+ fs=fs,
57
+ n_bands=n_bands
58
+ )
59
+ bsm = specs.get_band_specs()
60
+ freq_weights = specs.get_freq_weights()
61
+ overlapping_band = True
62
+ elif "musical" in band_specs:
63
+ assert n_bands is not None
64
+ specs = MusicalBandsplitSpecification(
65
+ nfft=n_fft,
66
+ fs=fs,
67
+ n_bands=n_bands
68
+ )
69
+ bsm = specs.get_band_specs()
70
+ freq_weights = specs.get_freq_weights()
71
+ overlapping_band = True
72
+ elif band_specs == "dnr:mel" or "mel" in band_specs:
73
+ assert n_bands is not None
74
+ specs = MelBandsplitSpecification(
75
+ nfft=n_fft,
76
+ fs=fs,
77
+ n_bands=n_bands
78
+ )
79
+ bsm = specs.get_band_specs()
80
+ freq_weights = specs.get_freq_weights()
81
+ overlapping_band = True
82
+ else:
83
+ raise NameError
84
+
85
+ return bsm, freq_weights, overlapping_band
86
+
87
+
88
+ def get_band_specs_map(band_specs_map, n_fft, fs, n_bands=None):
89
+ if band_specs_map == "musdb:all":
90
+ bsm = {
91
+ "vocals": VocalBandsplitSpecification(
92
+ nfft=n_fft, fs=fs
93
+ ).get_band_specs(),
94
+ "drums": DrumBandsplitSpecification(
95
+ nfft=n_fft, fs=fs
96
+ ).get_band_specs(),
97
+ "bass": BassBandsplitSpecification(
98
+ nfft=n_fft, fs=fs
99
+ ).get_band_specs(),
100
+ "other": OtherBandsplitSpecification(
101
+ nfft=n_fft, fs=fs
102
+ ).get_band_specs(),
103
+ }
104
+ freq_weights = None
105
+ overlapping_band = False
106
+ elif band_specs_map == "dnr:vox7":
107
+ bsm_, freq_weights, overlapping_band = get_band_specs(
108
+ "dnr:speech", n_fft, fs, n_bands
109
+ )
110
+ bsm = {
111
+ "speech": bsm_,
112
+ "music": bsm_,
113
+ "effects": bsm_
114
+ }
115
+ elif "dnr:vox7:" in band_specs_map:
116
+ stem = band_specs_map.split(":")[-1]
117
+ bsm_, freq_weights, overlapping_band = get_band_specs(
118
+ "dnr:speech", n_fft, fs, n_bands
119
+ )
120
+ bsm = {
121
+ stem: bsm_
122
+ }
123
+ else:
124
+ raise NameError(f"unknown band specification map: {band_specs_map}")
125
+
126
+ return bsm, freq_weights, overlapping_band
127
+
128
+
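For orientation, here is a hedged sketch of how these two helpers are meant to be called; the "tribark64" spec string, the parameter values, and the printed comments are illustrative assumptions rather than values taken from a shipped config.

# Overlapping triangular-Bark bands: n_bands must be given explicitly.
band_specs, freq_weights, overlapping = get_band_specs(
    "tribark64", n_fft=2048, fs=44100, n_bands=64
)
print(overlapping)         # True for the overlapping-band specifications
print(len(band_specs))     # number of bands in the split

# Per-stem band specs for MUSDB-style four-stem separation (no overlap, no weights).
bsm, fw, ov = get_band_specs_map("musdb:all", n_fft=2048, fs=44100)
print(sorted(bsm.keys()))  # ['bass', 'drums', 'other', 'vocals']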
129
+ class BandSplitWrapperBase(pl.LightningModule):
130
+ bsrnn: nn.Module
131
+
132
+ def __init__(self, **kwargs):
133
+ super().__init__()
134
+
135
+
136
+ class SingleMaskMultiSourceBandSplitBase(
137
+ BandSplitWrapperBase,
138
+ _SpectralComponent
139
+ ):
140
+ def __init__(
141
+ self,
142
+ band_specs_map: Union[str, Dict[str, List[Tuple[float, float]]]],
143
+ fs: int = 44100,
144
+ n_fft: int = 2048,
145
+ win_length: Optional[int] = 2048,
146
+ hop_length: int = 512,
147
+ window_fn: str = "hann_window",
148
+ wkwargs: Optional[Dict] = None,
149
+ power: Optional[int] = None,
150
+ center: bool = True,
151
+ normalized: bool = True,
152
+ pad_mode: str = "constant",
153
+ onesided: bool = True,
154
+ n_bands: int = None,
155
+ ) -> None:
156
+ super().__init__(
157
+ n_fft=n_fft,
158
+ win_length=win_length,
159
+ hop_length=hop_length,
160
+ window_fn=window_fn,
161
+ wkwargs=wkwargs,
162
+ power=power,
163
+ center=center,
164
+ normalized=normalized,
165
+ pad_mode=pad_mode,
166
+ onesided=onesided,
167
+ )
168
+
169
+ if isinstance(band_specs_map, str):
170
+ self.band_specs_map, self.freq_weights, self.overlapping_band = get_band_specs_map(
171
+ band_specs_map,
172
+ n_fft,
173
+ fs,
174
+ n_bands=n_bands
175
+ )
176
+
177
+ self.stems = list(self.band_specs_map.keys())
178
+
179
+ def forward(self, batch):
180
+ audio = batch["audio"]
181
+
182
+ with torch.no_grad():
183
+ batch["spectrogram"] = {stem: self.stft(audio[stem]) for stem in
184
+ audio}
185
+
186
+ X = batch["spectrogram"]["mixture"]
187
+ length = batch["audio"]["mixture"].shape[-1]
188
+
189
+ output = {"spectrogram": {}, "audio": {}}
190
+
191
+ for stem, bsrnn in self.bsrnn.items():
192
+ S = bsrnn(X)
193
+ s = self.istft(S, length)
194
+ output["spectrogram"][stem] = S
195
+ output["audio"][stem] = s
196
+
197
+ return batch, output
198
+
199
+
200
+ class MultiMaskMultiSourceBandSplitBase(
201
+ BandSplitWrapperBase,
202
+ _SpectralComponent
203
+ ):
204
+ def __init__(
205
+ self,
206
+ stems: List[str],
207
+ band_specs: Union[str, List[Tuple[float, float]]],
208
+ fs: int = 44100,
209
+ n_fft: int = 2048,
210
+ win_length: Optional[int] = 2048,
211
+ hop_length: int = 512,
212
+ window_fn: str = "hann_window",
213
+ wkwargs: Optional[Dict] = None,
214
+ power: Optional[int] = None,
215
+ center: bool = True,
216
+ normalized: bool = True,
217
+ pad_mode: str = "constant",
218
+ onesided: bool = True,
219
+ n_bands: int = None,
220
+ ) -> None:
221
+ super().__init__(
222
+ n_fft=n_fft,
223
+ win_length=win_length,
224
+ hop_length=hop_length,
225
+ window_fn=window_fn,
226
+ wkwargs=wkwargs,
227
+ power=power,
228
+ center=center,
229
+ normalized=normalized,
230
+ pad_mode=pad_mode,
231
+ onesided=onesided,
232
+ )
233
+
234
+ if isinstance(band_specs, str):
235
+ self.band_specs, self.freq_weights, self.overlapping_band = get_band_specs(
236
+ band_specs,
237
+ n_fft,
238
+ fs,
239
+ n_bands
240
+ )
241
+
242
+ self.stems = stems
243
+
244
+ def forward(self, batch):
245
+ # with torch.no_grad():
246
+ audio = batch["audio"]
247
+ cond = batch.get("condition", None)
248
+ with torch.no_grad():
249
+ batch["spectrogram"] = {stem: self.stft(audio[stem]) for stem in
250
+ audio}
251
+
252
+ X = batch["spectrogram"]["mixture"]
253
+ length = batch["audio"]["mixture"].shape[-1]
254
+
255
+ output = self.bsrnn(X, cond=cond)
256
+ output["audio"] = {}
257
+
258
+ for stem, S in output["spectrogram"].items():
259
+ s = self.istft(S, length)
260
+ output["audio"][stem] = s
261
+
262
+ return batch, output
263
+
264
+
265
+ class MultiMaskMultiSourceBandSplitBaseSimple(
266
+ BandSplitWrapperBase,
267
+ _SpectralComponent
268
+ ):
269
+ def __init__(
270
+ self,
271
+ stems: List[str],
272
+ band_specs: Union[str, List[Tuple[float, float]]],
273
+ fs: int = 44100,
274
+ n_fft: int = 2048,
275
+ win_length: Optional[int] = 2048,
276
+ hop_length: int = 512,
277
+ window_fn: str = "hann_window",
278
+ wkwargs: Optional[Dict] = None,
279
+ power: Optional[int] = None,
280
+ center: bool = True,
281
+ normalized: bool = True,
282
+ pad_mode: str = "constant",
283
+ onesided: bool = True,
284
+ n_bands: int = None,
285
+ ) -> None:
286
+ super().__init__(
287
+ n_fft=n_fft,
288
+ win_length=win_length,
289
+ hop_length=hop_length,
290
+ window_fn=window_fn,
291
+ wkwargs=wkwargs,
292
+ power=power,
293
+ center=center,
294
+ normalized=normalized,
295
+ pad_mode=pad_mode,
296
+ onesided=onesided,
297
+ )
298
+
299
+ if isinstance(band_specs, str):
300
+ self.band_specs, self.freq_weights, self.overlapping_band = get_band_specs(
301
+ band_specs,
302
+ n_fft,
303
+ fs,
304
+ n_bands
305
+ )
306
+
307
+ self.stems = stems
308
+
309
+ def forward(self, batch):
310
+ with torch.no_grad():
311
+ X = self.stft(batch)
312
+ length = batch.shape[-1]
313
+ output = self.bsrnn(X, cond=None)
314
+ res = []
315
+ for stem, S in output["spectrogram"].items():
316
+ s = self.istft(S, length)
317
+ res.append(s)
318
+ res = torch.stack(res, dim=1)
319
+ return res
320
+
321
+
322
+ class SingleMaskMultiSourceBandSplitRNN(SingleMaskMultiSourceBandSplitBase):
323
+ def __init__(
324
+ self,
325
+ in_channel: int,
326
+ band_specs_map: Union[str, Dict[str, List[Tuple[float, float]]]],
327
+ fs: int = 44100,
328
+ require_no_overlap: bool = False,
329
+ require_no_gap: bool = True,
330
+ normalize_channel_independently: bool = False,
331
+ treat_channel_as_feature: bool = True,
332
+ n_sqm_modules: int = 12,
333
+ emb_dim: int = 128,
334
+ rnn_dim: int = 256,
335
+ bidirectional: bool = True,
336
+ rnn_type: str = "LSTM",
337
+ mlp_dim: int = 512,
338
+ hidden_activation: str = "Tanh",
339
+ hidden_activation_kwargs: Optional[Dict] = None,
340
+ complex_mask: bool = True,
341
+ n_fft: int = 2048,
342
+ win_length: Optional[int] = 2048,
343
+ hop_length: int = 512,
344
+ window_fn: str = "hann_window",
345
+ wkwargs: Optional[Dict] = None,
346
+ power: Optional[int] = None,
347
+ center: bool = True,
348
+ normalized: bool = True,
349
+ pad_mode: str = "constant",
350
+ onesided: bool = True,
351
+ ) -> None:
352
+ super().__init__(
353
+ band_specs_map=band_specs_map,
354
+ fs=fs,
355
+ n_fft=n_fft,
356
+ win_length=win_length,
357
+ hop_length=hop_length,
358
+ window_fn=window_fn,
359
+ wkwargs=wkwargs,
360
+ power=power,
361
+ center=center,
362
+ normalized=normalized,
363
+ pad_mode=pad_mode,
364
+ onesided=onesided,
365
+ )
366
+
367
+ self.bsrnn = nn.ModuleDict(
368
+ {
369
+ src: SingleMaskBandsplitCoreRNN(
370
+ band_specs=specs,
371
+ in_channel=in_channel,
372
+ require_no_overlap=require_no_overlap,
373
+ require_no_gap=require_no_gap,
374
+ normalize_channel_independently=normalize_channel_independently,
375
+ treat_channel_as_feature=treat_channel_as_feature,
376
+ n_sqm_modules=n_sqm_modules,
377
+ emb_dim=emb_dim,
378
+ rnn_dim=rnn_dim,
379
+ bidirectional=bidirectional,
380
+ rnn_type=rnn_type,
381
+ mlp_dim=mlp_dim,
382
+ hidden_activation=hidden_activation,
383
+ hidden_activation_kwargs=hidden_activation_kwargs,
384
+ complex_mask=complex_mask,
385
+ )
386
+ for src, specs in self.band_specs_map.items()
387
+ }
388
+ )
389
+
390
+
391
+ class SingleMaskMultiSourceBandSplitTransformer(
392
+ SingleMaskMultiSourceBandSplitBase
393
+ ):
394
+ def __init__(
395
+ self,
396
+ in_channel: int,
397
+ band_specs_map: Union[str, Dict[str, List[Tuple[float, float]]]],
398
+ fs: int = 44100,
399
+ require_no_overlap: bool = False,
400
+ require_no_gap: bool = True,
401
+ normalize_channel_independently: bool = False,
402
+ treat_channel_as_feature: bool = True,
403
+ n_sqm_modules: int = 12,
404
+ emb_dim: int = 128,
405
+ rnn_dim: int = 256,
406
+ bidirectional: bool = True,
407
+ tf_dropout: float = 0.0,
408
+ mlp_dim: int = 512,
409
+ hidden_activation: str = "Tanh",
410
+ hidden_activation_kwargs: Optional[Dict] = None,
411
+ complex_mask: bool = True,
412
+ n_fft: int = 2048,
413
+ win_length: Optional[int] = 2048,
414
+ hop_length: int = 512,
415
+ window_fn: str = "hann_window",
416
+ wkwargs: Optional[Dict] = None,
417
+ power: Optional[int] = None,
418
+ center: bool = True,
419
+ normalized: bool = True,
420
+ pad_mode: str = "constant",
421
+ onesided: bool = True,
422
+ ) -> None:
423
+ super().__init__(
424
+ band_specs_map=band_specs_map,
425
+ fs=fs,
426
+ n_fft=n_fft,
427
+ win_length=win_length,
428
+ hop_length=hop_length,
429
+ window_fn=window_fn,
430
+ wkwargs=wkwargs,
431
+ power=power,
432
+ center=center,
433
+ normalized=normalized,
434
+ pad_mode=pad_mode,
435
+ onesided=onesided,
436
+ )
437
+
438
+ self.bsrnn = nn.ModuleDict(
439
+ {
440
+ src: SingleMaskBandsplitCoreTransformer(
441
+ band_specs=specs,
442
+ in_channel=in_channel,
443
+ require_no_overlap=require_no_overlap,
444
+ require_no_gap=require_no_gap,
445
+ normalize_channel_independently=normalize_channel_independently,
446
+ treat_channel_as_feature=treat_channel_as_feature,
447
+ n_sqm_modules=n_sqm_modules,
448
+ emb_dim=emb_dim,
449
+ rnn_dim=rnn_dim,
450
+ bidirectional=bidirectional,
451
+ tf_dropout=tf_dropout,
452
+ mlp_dim=mlp_dim,
453
+ hidden_activation=hidden_activation,
454
+ hidden_activation_kwargs=hidden_activation_kwargs,
455
+ complex_mask=complex_mask,
456
+ )
457
+ for src, specs in self.band_specs_map.items()
458
+ }
459
+ )
460
+
461
+
462
+ class MultiMaskMultiSourceBandSplitRNN(MultiMaskMultiSourceBandSplitBase):
463
+ def __init__(
464
+ self,
465
+ in_channel: int,
466
+ stems: List[str],
467
+ band_specs: Union[str, List[Tuple[float, float]]],
468
+ fs: int = 44100,
469
+ require_no_overlap: bool = False,
470
+ require_no_gap: bool = True,
471
+ normalize_channel_independently: bool = False,
472
+ treat_channel_as_feature: bool = True,
473
+ n_sqm_modules: int = 12,
474
+ emb_dim: int = 128,
475
+ rnn_dim: int = 256,
476
+ cond_dim: int = 0,
477
+ bidirectional: bool = True,
478
+ rnn_type: str = "LSTM",
479
+ mlp_dim: int = 512,
480
+ hidden_activation: str = "Tanh",
481
+ hidden_activation_kwargs: Optional[Dict] = None,
482
+ complex_mask: bool = True,
483
+ n_fft: int = 2048,
484
+ win_length: Optional[int] = 2048,
485
+ hop_length: int = 512,
486
+ window_fn: str = "hann_window",
487
+ wkwargs: Optional[Dict] = None,
488
+ power: Optional[int] = None,
489
+ center: bool = True,
490
+ normalized: bool = True,
491
+ pad_mode: str = "constant",
492
+ onesided: bool = True,
493
+ n_bands: int = None,
494
+ use_freq_weights: bool = True,
495
+ normalize_input: bool = False,
496
+ mult_add_mask: bool = False,
497
+ freeze_encoder: bool = False,
498
+ ) -> None:
499
+ super().__init__(
500
+ stems=stems,
501
+ band_specs=band_specs,
502
+ fs=fs,
503
+ n_fft=n_fft,
504
+ win_length=win_length,
505
+ hop_length=hop_length,
506
+ window_fn=window_fn,
507
+ wkwargs=wkwargs,
508
+ power=power,
509
+ center=center,
510
+ normalized=normalized,
511
+ pad_mode=pad_mode,
512
+ onesided=onesided,
513
+ n_bands=n_bands,
514
+ )
515
+
516
+ self.bsrnn = MultiSourceMultiMaskBandSplitCoreRNN(
517
+ stems=stems,
518
+ band_specs=self.band_specs,
519
+ in_channel=in_channel,
520
+ require_no_overlap=require_no_overlap,
521
+ require_no_gap=require_no_gap,
522
+ normalize_channel_independently=normalize_channel_independently,
523
+ treat_channel_as_feature=treat_channel_as_feature,
524
+ n_sqm_modules=n_sqm_modules,
525
+ emb_dim=emb_dim,
526
+ rnn_dim=rnn_dim,
527
+ bidirectional=bidirectional,
528
+ rnn_type=rnn_type,
529
+ mlp_dim=mlp_dim,
530
+ cond_dim=cond_dim,
531
+ hidden_activation=hidden_activation,
532
+ hidden_activation_kwargs=hidden_activation_kwargs,
533
+ complex_mask=complex_mask,
534
+ overlapping_band=self.overlapping_band,
535
+ freq_weights=self.freq_weights,
536
+ n_freq=n_fft // 2 + 1,
537
+ use_freq_weights=use_freq_weights,
538
+ mult_add_mask=mult_add_mask
539
+ )
540
+
541
+ self.normalize_input = normalize_input
542
+ self.cond_dim = cond_dim
543
+
544
+ if freeze_encoder:
545
+ for param in self.bsrnn.band_split.parameters():
546
+ param.requires_grad = False
547
+
548
+ for param in self.bsrnn.tf_model.parameters():
549
+ param.requires_grad = False
550
+
551
+
552
+ class MultiMaskMultiSourceBandSplitRNNSimple(MultiMaskMultiSourceBandSplitBaseSimple):
553
+ def __init__(
554
+ self,
555
+ in_channel: int,
556
+ stems: List[str],
557
+ band_specs: Union[str, List[Tuple[float, float]]],
558
+ fs: int = 44100,
559
+ require_no_overlap: bool = False,
560
+ require_no_gap: bool = True,
561
+ normalize_channel_independently: bool = False,
562
+ treat_channel_as_feature: bool = True,
563
+ n_sqm_modules: int = 12,
564
+ emb_dim: int = 128,
565
+ rnn_dim: int = 256,
566
+ cond_dim: int = 0,
567
+ bidirectional: bool = True,
568
+ rnn_type: str = "LSTM",
569
+ mlp_dim: int = 512,
570
+ hidden_activation: str = "Tanh",
571
+ hidden_activation_kwargs: Optional[Dict] = None,
572
+ complex_mask: bool = True,
573
+ n_fft: int = 2048,
574
+ win_length: Optional[int] = 2048,
575
+ hop_length: int = 512,
576
+ window_fn: str = "hann_window",
577
+ wkwargs: Optional[Dict] = None,
578
+ power: Optional[int] = None,
579
+ center: bool = True,
580
+ normalized: bool = True,
581
+ pad_mode: str = "constant",
582
+ onesided: bool = True,
583
+ n_bands: int = None,
584
+ use_freq_weights: bool = True,
585
+ normalize_input: bool = False,
586
+ mult_add_mask: bool = False,
587
+ freeze_encoder: bool = False,
588
+ ) -> None:
589
+ super().__init__(
590
+ stems=stems,
591
+ band_specs=band_specs,
592
+ fs=fs,
593
+ n_fft=n_fft,
594
+ win_length=win_length,
595
+ hop_length=hop_length,
596
+ window_fn=window_fn,
597
+ wkwargs=wkwargs,
598
+ power=power,
599
+ center=center,
600
+ normalized=normalized,
601
+ pad_mode=pad_mode,
602
+ onesided=onesided,
603
+ n_bands=n_bands,
604
+ )
605
+
606
+ self.bsrnn = MultiSourceMultiMaskBandSplitCoreRNN(
607
+ stems=stems,
608
+ band_specs=self.band_specs,
609
+ in_channel=in_channel,
610
+ require_no_overlap=require_no_overlap,
611
+ require_no_gap=require_no_gap,
612
+ normalize_channel_independently=normalize_channel_independently,
613
+ treat_channel_as_feature=treat_channel_as_feature,
614
+ n_sqm_modules=n_sqm_modules,
615
+ emb_dim=emb_dim,
616
+ rnn_dim=rnn_dim,
617
+ bidirectional=bidirectional,
618
+ rnn_type=rnn_type,
619
+ mlp_dim=mlp_dim,
620
+ cond_dim=cond_dim,
621
+ hidden_activation=hidden_activation,
622
+ hidden_activation_kwargs=hidden_activation_kwargs,
623
+ complex_mask=complex_mask,
624
+ overlapping_band=self.overlapping_band,
625
+ freq_weights=self.freq_weights,
626
+ n_freq=n_fft // 2 + 1,
627
+ use_freq_weights=use_freq_weights,
628
+ mult_add_mask=mult_add_mask
629
+ )
630
+
631
+ self.normalize_input = normalize_input
632
+ self.cond_dim = cond_dim
633
+
634
+ if freeze_encoder:
635
+ for param in self.bsrnn.band_split.parameters():
636
+ param.requires_grad = False
637
+
638
+ for param in self.bsrnn.tf_model.parameters():
639
+ param.requires_grad = False
640
+
641
+
642
+ class MultiMaskMultiSourceBandSplitTransformer(
643
+ MultiMaskMultiSourceBandSplitBase
644
+ ):
645
+ def __init__(
646
+ self,
647
+ in_channel: int,
648
+ stems: List[str],
649
+ band_specs: Union[str, List[Tuple[float, float]]],
650
+ fs: int = 44100,
651
+ require_no_overlap: bool = False,
652
+ require_no_gap: bool = True,
653
+ normalize_channel_independently: bool = False,
654
+ treat_channel_as_feature: bool = True,
655
+ n_sqm_modules: int = 12,
656
+ emb_dim: int = 128,
657
+ rnn_dim: int = 256,
658
+ cond_dim: int = 0,
659
+ bidirectional: bool = True,
660
+ rnn_type: str = "LSTM",
661
+ mlp_dim: int = 512,
662
+ hidden_activation: str = "Tanh",
663
+ hidden_activation_kwargs: Optional[Dict] = None,
664
+ complex_mask: bool = True,
665
+ n_fft: int = 2048,
666
+ win_length: Optional[int] = 2048,
667
+ hop_length: int = 512,
668
+ window_fn: str = "hann_window",
669
+ wkwargs: Optional[Dict] = None,
670
+ power: Optional[int] = None,
671
+ center: bool = True,
672
+ normalized: bool = True,
673
+ pad_mode: str = "constant",
674
+ onesided: bool = True,
675
+ n_bands: int = None,
676
+ use_freq_weights: bool = True,
677
+ normalize_input: bool = False,
678
+ mult_add_mask: bool = False
679
+ ) -> None:
680
+ super().__init__(
681
+ stems=stems,
682
+ band_specs=band_specs,
683
+ fs=fs,
684
+ n_fft=n_fft,
685
+ win_length=win_length,
686
+ hop_length=hop_length,
687
+ window_fn=window_fn,
688
+ wkwargs=wkwargs,
689
+ power=power,
690
+ center=center,
691
+ normalized=normalized,
692
+ pad_mode=pad_mode,
693
+ onesided=onesided,
694
+ n_bands=n_bands,
695
+ )
696
+
697
+ self.bsrnn = MultiSourceMultiMaskBandSplitCoreTransformer(
698
+ stems=stems,
699
+ band_specs=self.band_specs,
700
+ in_channel=in_channel,
701
+ require_no_overlap=require_no_overlap,
702
+ require_no_gap=require_no_gap,
703
+ normalize_channel_independently=normalize_channel_independently,
704
+ treat_channel_as_feature=treat_channel_as_feature,
705
+ n_sqm_modules=n_sqm_modules,
706
+ emb_dim=emb_dim,
707
+ rnn_dim=rnn_dim,
708
+ bidirectional=bidirectional,
709
+ rnn_type=rnn_type,
710
+ mlp_dim=mlp_dim,
711
+ cond_dim=cond_dim,
712
+ hidden_activation=hidden_activation,
713
+ hidden_activation_kwargs=hidden_activation_kwargs,
714
+ complex_mask=complex_mask,
715
+ overlapping_band=self.overlapping_band,
716
+ freq_weights=self.freq_weights,
717
+ n_freq=n_fft // 2 + 1,
718
+ use_freq_weights=use_freq_weights,
719
+ mult_add_mask=mult_add_mask
720
+ )
721
+
722
+
723
+
724
+ class MultiMaskMultiSourceBandSplitConv(
725
+ MultiMaskMultiSourceBandSplitBase
726
+ ):
727
+ def __init__(
728
+ self,
729
+ in_channel: int,
730
+ stems: List[str],
731
+ band_specs: Union[str, List[Tuple[float, float]]],
732
+ fs: int = 44100,
733
+ require_no_overlap: bool = False,
734
+ require_no_gap: bool = True,
735
+ normalize_channel_independently: bool = False,
736
+ treat_channel_as_feature: bool = True,
737
+ n_sqm_modules: int = 12,
738
+ emb_dim: int = 128,
739
+ rnn_dim: int = 256,
740
+ cond_dim: int = 0,
741
+ bidirectional: bool = True,
742
+ rnn_type: str = "LSTM",
743
+ mlp_dim: int = 512,
744
+ hidden_activation: str = "Tanh",
745
+ hidden_activation_kwargs: Optional[Dict] = None,
746
+ complex_mask: bool = True,
747
+ n_fft: int = 2048,
748
+ win_length: Optional[int] = 2048,
749
+ hop_length: int = 512,
750
+ window_fn: str = "hann_window",
751
+ wkwargs: Optional[Dict] = None,
752
+ power: Optional[int] = None,
753
+ center: bool = True,
754
+ normalized: bool = True,
755
+ pad_mode: str = "constant",
756
+ onesided: bool = True,
757
+ n_bands: int = None,
758
+ use_freq_weights: bool = True,
759
+ normalize_input: bool = False,
760
+ mult_add_mask: bool = False
761
+ ) -> None:
762
+ super().__init__(
763
+ stems=stems,
764
+ band_specs=band_specs,
765
+ fs=fs,
766
+ n_fft=n_fft,
767
+ win_length=win_length,
768
+ hop_length=hop_length,
769
+ window_fn=window_fn,
770
+ wkwargs=wkwargs,
771
+ power=power,
772
+ center=center,
773
+ normalized=normalized,
774
+ pad_mode=pad_mode,
775
+ onesided=onesided,
776
+ n_bands=n_bands,
777
+ )
778
+
779
+ self.bsrnn = MultiSourceMultiMaskBandSplitCoreConv(
780
+ stems=stems,
781
+ band_specs=self.band_specs,
782
+ in_channel=in_channel,
783
+ require_no_overlap=require_no_overlap,
784
+ require_no_gap=require_no_gap,
785
+ normalize_channel_independently=normalize_channel_independently,
786
+ treat_channel_as_feature=treat_channel_as_feature,
787
+ n_sqm_modules=n_sqm_modules,
788
+ emb_dim=emb_dim,
789
+ rnn_dim=rnn_dim,
790
+ bidirectional=bidirectional,
791
+ rnn_type=rnn_type,
792
+ mlp_dim=mlp_dim,
793
+ cond_dim=cond_dim,
794
+ hidden_activation=hidden_activation,
795
+ hidden_activation_kwargs=hidden_activation_kwargs,
796
+ complex_mask=complex_mask,
797
+ overlapping_band=self.overlapping_band,
798
+ freq_weights=self.freq_weights,
799
+ n_freq=n_fft // 2 + 1,
800
+ use_freq_weights=use_freq_weights,
801
+ mult_add_mask=mult_add_mask
802
+ )
803
+
+
+ class PatchingMaskMultiSourceBandSplitRNN(MultiMaskMultiSourceBandSplitBase):
804
+ def __init__(
805
+ self,
806
+ in_channel: int,
807
+ stems: List[str],
808
+ band_specs: Union[str, List[Tuple[float, float]]],
809
+ kernel_norm_mlp_version: int = 1,
810
+ mask_kernel_freq: int = 3,
811
+ mask_kernel_time: int = 3,
812
+ conv_kernel_freq: int = 1,
813
+ conv_kernel_time: int = 1,
814
+ fs: int = 44100,
815
+ require_no_overlap: bool = False,
816
+ require_no_gap: bool = True,
817
+ normalize_channel_independently: bool = False,
818
+ treat_channel_as_feature: bool = True,
819
+ n_sqm_modules: int = 12,
820
+ emb_dim: int = 128,
821
+ rnn_dim: int = 256,
822
+ bidirectional: bool = True,
823
+ rnn_type: str = "LSTM",
824
+ mlp_dim: int = 512,
825
+ hidden_activation: str = "Tanh",
826
+ hidden_activation_kwargs: Optional[Dict] = None,
827
+ complex_mask: bool = True,
828
+ n_fft: int = 2048,
829
+ win_length: Optional[int] = 2048,
830
+ hop_length: int = 512,
831
+ window_fn: str = "hann_window",
832
+ wkwargs: Optional[Dict] = None,
833
+ power: Optional[int] = None,
834
+ center: bool = True,
835
+ normalized: bool = True,
836
+ pad_mode: str = "constant",
837
+ onesided: bool = True,
838
+ n_bands: int = None,
839
+ ) -> None:
840
+ super().__init__(
841
+ stems=stems,
842
+ band_specs=band_specs,
843
+ fs=fs,
844
+ n_fft=n_fft,
845
+ win_length=win_length,
846
+ hop_length=hop_length,
847
+ window_fn=window_fn,
848
+ wkwargs=wkwargs,
849
+ power=power,
850
+ center=center,
851
+ normalized=normalized,
852
+ pad_mode=pad_mode,
853
+ onesided=onesided,
854
+ n_bands=n_bands,
855
+ )
856
+
857
+ self.bsrnn = MultiSourceMultiPatchingMaskBandSplitCoreRNN(
858
+ stems=stems,
859
+ band_specs=self.band_specs,
860
+ in_channel=in_channel,
861
+ require_no_overlap=require_no_overlap,
862
+ require_no_gap=require_no_gap,
863
+ normalize_channel_independently=normalize_channel_independently,
864
+ treat_channel_as_feature=treat_channel_as_feature,
865
+ n_sqm_modules=n_sqm_modules,
866
+ emb_dim=emb_dim,
867
+ rnn_dim=rnn_dim,
868
+ bidirectional=bidirectional,
869
+ rnn_type=rnn_type,
870
+ mlp_dim=mlp_dim,
871
+ hidden_activation=hidden_activation,
872
+ hidden_activation_kwargs=hidden_activation_kwargs,
873
+ complex_mask=complex_mask,
874
+ overlapping_band=self.overlapping_band,
875
+ freq_weights=self.freq_weights,
876
+ n_freq=n_fft // 2 + 1,
877
+ mask_kernel_freq=mask_kernel_freq,
878
+ mask_kernel_time=mask_kernel_time,
879
+ conv_kernel_freq=conv_kernel_freq,
880
+ conv_kernel_time=conv_kernel_time,
881
+ kernel_norm_mlp_version=kernel_norm_mlp_version,
882
+ )
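To show how the "simple" wrapper above is meant to be driven at inference time, here is a hedged sketch; the stem names, hyperparameters, and output shape are assumptions for illustration, not the configuration shipped with this space.

import torch

model = MultiMaskMultiSourceBandSplitRNNSimple(
    in_channel=1,
    stems=["speech", "music", "effects"],
    band_specs="dnr:speech",
    fs=44100,
    n_fft=2048,
    hop_length=512,
)
model.eval()

mixture = torch.randn(2, 1, 3 * 44100)    # (batch, channels, samples)
with torch.no_grad():
    separated = model(mixture)             # roughly (batch, n_stems, channels, samples)
print(separated.shape)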
models/bandit/core/utils/__init__.py ADDED
File without changes
models/bandit/core/utils/audio.py ADDED
@@ -0,0 +1,463 @@
1
+ from collections import defaultdict
2
+
3
+ from tqdm import tqdm
4
+ from typing import Callable, Dict, List, Optional, Tuple
5
+
6
+ import numpy as np
7
+ import torch
8
+ from torch import nn
9
+ from torch.nn import functional as F
10
+
11
+
12
+ @torch.jit.script
13
+ def merge(
14
+ combined: torch.Tensor,
15
+ original_batch_size: int,
16
+ n_channel: int,
17
+ n_chunks: int,
18
+ chunk_size: int, ):
19
+ combined = torch.reshape(
20
+ combined,
21
+ (original_batch_size, n_chunks, n_channel, chunk_size)
22
+ )
23
+ combined = torch.permute(combined, (0, 2, 3, 1)).reshape(
24
+ original_batch_size * n_channel,
25
+ chunk_size,
26
+ n_chunks
27
+ )
28
+
29
+ return combined
30
+
31
+
32
+ @torch.jit.script
33
+ def unfold(
34
+ padded_audio: torch.Tensor,
35
+ original_batch_size: int,
36
+ n_channel: int,
37
+ chunk_size: int,
38
+ hop_size: int
39
+ ) -> torch.Tensor:
40
+
41
+ unfolded_input = F.unfold(
42
+ padded_audio[:, :, None, :],
43
+ kernel_size=(1, chunk_size),
44
+ stride=(1, hop_size)
45
+ )
46
+
47
+ _, _, n_chunks = unfolded_input.shape
48
+ unfolded_input = unfolded_input.view(
49
+ original_batch_size,
50
+ n_channel,
51
+ chunk_size,
52
+ n_chunks
53
+ )
54
+ unfolded_input = torch.permute(
55
+ unfolded_input,
56
+ (0, 3, 1, 2)
57
+ ).reshape(
58
+ original_batch_size * n_chunks,
59
+ n_channel,
60
+ chunk_size
61
+ )
62
+
63
+ return unfolded_input
64
+
65
+
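To make the chunking bookkeeping concrete, here is a small shape check that mirrors what BaseFader.prepare() and forward() do before calling unfold(); the chunk and hop sizes are arbitrary illustrative numbers.

import math

import torch
import torch.nn.functional as F

batch, channels, samples = 2, 2, 10_000
chunk_size, hop_size = 4_096, 1_024

# mirror BaseFader.prepare(): pad so that an integer number of chunks fits
n_chunks = math.ceil((samples - chunk_size) / hop_size) + 1
padded_samples = (n_chunks - 1) * hop_size + chunk_size

audio = torch.randn(batch * channels, 1, samples)   # channels folded into the batch dim
padded = F.pad(audio, (0, padded_samples - samples))

chunks = unfold(padded, batch, channels, chunk_size, hop_size)
print(chunks.shape)  # torch.Size([14, 2, 4096]) -> (batch * n_chunks, channels, chunk_size)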
66
+ @torch.jit.script
67
+ # @torch.compile
68
+ def merge_chunks_all(
69
+ combined: torch.Tensor,
70
+ original_batch_size: int,
71
+ n_channel: int,
72
+ n_samples: int,
73
+ n_padded_samples: int,
74
+ n_chunks: int,
75
+ chunk_size: int,
76
+ hop_size: int,
77
+ edge_frame_pad_sizes: Tuple[int, int],
78
+ standard_window: torch.Tensor,
79
+ first_window: torch.Tensor,
80
+ last_window: torch.Tensor
81
+ ):
82
+ combined = merge(
83
+ combined,
84
+ original_batch_size,
85
+ n_channel,
86
+ n_chunks,
87
+ chunk_size
88
+ )
89
+
90
+ combined = combined * standard_window[:, None].to(combined.device)
91
+
92
+ combined = F.fold(
93
+ combined.to(torch.float32), output_size=(1, n_padded_samples),
94
+ kernel_size=(1, chunk_size),
95
+ stride=(1, hop_size)
96
+ )
97
+
98
+ combined = combined.view(
99
+ original_batch_size,
100
+ n_channel,
101
+ n_padded_samples
102
+ )
103
+
104
+ pad_front, pad_back = edge_frame_pad_sizes
105
+ combined = combined[..., pad_front:-pad_back]
106
+
107
+ combined = combined[..., :n_samples]
108
+
109
+ return combined
110
+
111
+ # @torch.jit.script
112
+
113
+
114
+ def merge_chunks_edge(
115
+ combined: torch.Tensor,
116
+ original_batch_size: int,
117
+ n_channel: int,
118
+ n_samples: int,
119
+ n_padded_samples: int,
120
+ n_chunks: int,
121
+ chunk_size: int,
122
+ hop_size: int,
123
+ edge_frame_pad_sizes: Tuple[int, int],
124
+ standard_window: torch.Tensor,
125
+ first_window: torch.Tensor,
126
+ last_window: torch.Tensor
127
+ ):
128
+ combined = merge(
129
+ combined,
130
+ original_batch_size,
131
+ n_channel,
132
+ n_chunks,
133
+ chunk_size
134
+ )
135
+
136
+ combined[..., 0] = combined[..., 0] * first_window
137
+ combined[..., -1] = combined[..., -1] * last_window
138
+ combined[..., 1:-1] = combined[...,
139
+ 1:-1] * standard_window[:, None]
140
+
141
+ combined = F.fold(
142
+ combined, output_size=(1, n_padded_samples),
143
+ kernel_size=(1, chunk_size),
144
+ stride=(1, hop_size)
145
+ )
146
+
147
+ combined = combined.view(
148
+ original_batch_size,
149
+ n_channel,
150
+ n_padded_samples
151
+ )
152
+
153
+ combined = combined[..., :n_samples]
154
+
155
+ return combined
156
+
157
+
158
+ class BaseFader(nn.Module):
159
+ def __init__(
160
+ self,
161
+ chunk_size_second: float,
162
+ hop_size_second: float,
163
+ fs: int,
164
+ fade_edge_frames: bool,
165
+ batch_size: int,
166
+ ) -> None:
167
+ super().__init__()
168
+
169
+ self.chunk_size = int(chunk_size_second * fs)
170
+ self.hop_size = int(hop_size_second * fs)
171
+ self.overlap_size = self.chunk_size - self.hop_size
172
+ self.fade_edge_frames = fade_edge_frames
173
+ self.batch_size = batch_size
174
+
175
+ # @torch.jit.script
176
+ def prepare(self, audio):
177
+
178
+ if self.fade_edge_frames:
179
+ audio = F.pad(audio, self.edge_frame_pad_sizes, mode="reflect")
180
+
181
+ n_samples = audio.shape[-1]
182
+ n_chunks = int(
183
+ np.ceil((n_samples - self.chunk_size) / self.hop_size) + 1
184
+ )
185
+
186
+ padded_size = (n_chunks - 1) * self.hop_size + self.chunk_size
187
+ pad_size = padded_size - n_samples
188
+
189
+ padded_audio = F.pad(audio, (0, pad_size))
190
+
191
+ return padded_audio, n_chunks
192
+
193
+ def forward(
194
+ self,
195
+ audio: torch.Tensor,
196
+ model_fn: Callable[[torch.Tensor], Dict[str, torch.Tensor]],
197
+ ):
198
+
199
+ original_dtype = audio.dtype
200
+ original_device = audio.device
201
+
202
+ audio = audio.to("cpu")
203
+
204
+ original_batch_size, n_channel, n_samples = audio.shape
205
+ padded_audio, n_chunks = self.prepare(audio)
206
+ del audio
207
+ n_padded_samples = padded_audio.shape[-1]
208
+
209
+ if n_channel > 1:
210
+ padded_audio = padded_audio.view(
211
+ original_batch_size * n_channel, 1, n_padded_samples
212
+ )
213
+
214
+ unfolded_input = unfold(
215
+ padded_audio,
216
+ original_batch_size,
217
+ n_channel,
218
+ self.chunk_size, self.hop_size
219
+ )
220
+
221
+ n_total_chunks, n_channel, chunk_size = unfolded_input.shape
222
+
223
+ n_batch = np.ceil(n_total_chunks / self.batch_size).astype(int)
224
+
225
+ chunks_in = [
226
+ unfolded_input[
227
+ b * self.batch_size:(b + 1) * self.batch_size, ...].clone()
228
+ for b in range(n_batch)
229
+ ]
230
+
231
+ all_chunks_out = defaultdict(
232
+ lambda: torch.zeros_like(
233
+ unfolded_input, device="cpu"
234
+ )
235
+ )
236
+
237
+ # for b, cin in enumerate(tqdm(chunks_in)):
238
+ for b, cin in enumerate(chunks_in):
239
+ if torch.allclose(cin, torch.tensor(0.0)):
240
+ del cin
241
+ continue
242
+
243
+ chunks_out = model_fn(cin.to(original_device))
244
+ del cin
245
+ for s, c in chunks_out.items():
246
+ all_chunks_out[s][b * self.batch_size:(b + 1) * self.batch_size,
247
+ ...] = c.cpu()
248
+ del chunks_out
249
+
250
+ del unfolded_input
251
+ del padded_audio
252
+
253
+ if self.fade_edge_frames:
254
+ fn = merge_chunks_all
255
+ else:
256
+ fn = merge_chunks_edge
257
+ outputs = {}
258
+
259
+ torch.cuda.empty_cache()
260
+
261
+ for s, c in all_chunks_out.items():
262
+ combined: torch.Tensor = fn(
263
+ c,
264
+ original_batch_size,
265
+ n_channel,
266
+ n_samples,
267
+ n_padded_samples,
268
+ n_chunks,
269
+ self.chunk_size,
270
+ self.hop_size,
271
+ self.edge_frame_pad_sizes,
272
+ self.standard_window,
273
+ self.__dict__.get("first_window", self.standard_window),
274
+ self.__dict__.get("last_window", self.standard_window)
275
+ )
276
+
277
+ outputs[s] = combined.to(
278
+ dtype=original_dtype,
279
+ device=original_device
280
+ )
281
+
282
+ return {
283
+ "audio": outputs
284
+ }
285
+ #
286
+ # def old_forward(
287
+ # self,
288
+ # audio: torch.Tensor,
289
+ # model_fn: Callable[[torch.Tensor], Dict[str, torch.Tensor]],
290
+ # ):
291
+ #
292
+ # n_samples = audio.shape[-1]
293
+ # original_batch_size = audio.shape[0]
294
+ #
295
+ # padded_audio, n_chunks = self.prepare(audio)
296
+ #
297
+ # ndim = padded_audio.ndim
298
+ # broadcaster = [1 for _ in range(ndim - 1)] + [self.chunk_size]
299
+ #
300
+ # outputs = defaultdict(
301
+ # lambda: torch.zeros_like(
302
+ # padded_audio, device=audio.device, dtype=torch.float64
303
+ # )
304
+ # )
305
+ #
306
+ # all_chunks_out = []
307
+ # len_chunks_in = []
308
+ #
309
+ # batch_size_ = int(self.batch_size // original_batch_size)
310
+ # for b in range(int(np.ceil(n_chunks / batch_size_))):
311
+ # chunks_in = []
312
+ # for j in range(batch_size_):
313
+ # i = b * batch_size_ + j
314
+ # if i == n_chunks:
315
+ # break
316
+ #
317
+ # start = i * hop_size
318
+ # end = start + self.chunk_size
319
+ # chunk_in = padded_audio[..., start:end]
320
+ # chunks_in.append(chunk_in)
321
+ #
322
+ # chunks_in = torch.concat(chunks_in, dim=0)
323
+ # chunks_out = model_fn(chunks_in)
324
+ # all_chunks_out.append(chunks_out)
325
+ # len_chunks_in.append(len(chunks_in))
326
+ #
327
+ # for b, (chunks_out, lci) in enumerate(
328
+ # zip(all_chunks_out, len_chunks_in)
329
+ # ):
330
+ # for stem in chunks_out:
331
+ # for j in range(lci // original_batch_size):
332
+ # i = b * batch_size_ + j
333
+ #
334
+ # if self.fade_edge_frames:
335
+ # window = self.standard_window
336
+ # else:
337
+ # if i == 0:
338
+ # window = self.first_window
339
+ # elif i == n_chunks - 1:
340
+ # window = self.last_window
341
+ # else:
342
+ # window = self.standard_window
343
+ #
344
+ # start = i * hop_size
345
+ # end = start + self.chunk_size
346
+ #
347
+ # chunk_out = chunks_out[stem][j * original_batch_size: (j + 1) * original_batch_size,
348
+ # ...]
349
+ # contrib = window.view(*broadcaster) * chunk_out
350
+ # outputs[stem][..., start:end] = (
351
+ # outputs[stem][..., start:end] + contrib
352
+ # )
353
+ #
354
+ # if self.fade_edge_frames:
355
+ # pad_front, pad_back = self.edge_frame_pad_sizes
356
+ # outputs = {k: v[..., pad_front:-pad_back] for k, v in
357
+ # outputs.items()}
358
+ #
359
+ # outputs = {k: v[..., :n_samples].to(audio.dtype) for k, v in
360
+ # outputs.items()}
361
+ #
362
+ # return {
363
+ # "audio": outputs
364
+ # }
365
+
366
+
367
+ class LinearFader(BaseFader):
368
+ def __init__(
369
+ self,
370
+ chunk_size_second: float,
371
+ hop_size_second: float,
372
+ fs: int,
373
+ fade_edge_frames: bool = False,
374
+ batch_size: int = 1,
375
+ ) -> None:
376
+
377
+ assert hop_size_second >= chunk_size_second / 2
378
+
379
+ super().__init__(
380
+ chunk_size_second=chunk_size_second,
381
+ hop_size_second=hop_size_second,
382
+ fs=fs,
383
+ fade_edge_frames=fade_edge_frames,
384
+ batch_size=batch_size,
385
+ )
386
+
387
+ in_fade = torch.linspace(0.0, 1.0, self.overlap_size + 1)[:-1]
388
+ out_fade = torch.linspace(1.0, 0.0, self.overlap_size + 1)[1:]
389
+ center_ones = torch.ones(self.chunk_size - 2 * self.overlap_size)
390
+ inout_ones = torch.ones(self.overlap_size)
391
+
392
+ # registering the window as a buffer lets Lightning/PyTorch move it to the right device for us
393
+ self.register_buffer(
394
+ "standard_window",
395
+ torch.concat([in_fade, center_ones, out_fade])
396
+ )
397
+
398
+ self.fade_edge_frames = fade_edge_frames
399
+ self.edge_frame_pad_sizes = (self.overlap_size, self.overlap_size)
400
+
401
+ if not self.fade_edge_frames:
402
+ self.first_window = nn.Parameter(
403
+ torch.concat([inout_ones, center_ones, out_fade]),
404
+ requires_grad=False
405
+ )
406
+ self.last_window = nn.Parameter(
407
+ torch.concat([in_fade, center_ones, inout_ones]),
408
+ requires_grad=False
409
+ )
410
+
411
+
412
+ class OverlapAddFader(BaseFader):
413
+ def __init__(
414
+ self,
415
+ window_type: str,
416
+ chunk_size_second: float,
417
+ hop_size_second: float,
418
+ fs: int,
419
+ batch_size: int = 1,
420
+ ) -> None:
421
+ assert (chunk_size_second / hop_size_second) % 2 == 0
422
+ assert int(chunk_size_second * fs) % 2 == 0
423
+
424
+ super().__init__(
425
+ chunk_size_second=chunk_size_second,
426
+ hop_size_second=hop_size_second,
427
+ fs=fs,
428
+ fade_edge_frames=True,
429
+ batch_size=batch_size,
430
+ )
431
+
432
+ self.hop_multiplier = self.chunk_size / (2 * self.hop_size)
433
+ # print(f"hop multiplier: {self.hop_multiplier}")
434
+
435
+ self.edge_frame_pad_sizes = (
436
+ 2 * self.overlap_size,
437
+ 2 * self.overlap_size
438
+ )
439
+
440
+ self.register_buffer(
441
+ "standard_window", torch.signal.windows.__dict__[window_type](
442
+ self.chunk_size, sym=False, # dtype=torch.float64
443
+ ) / self.hop_multiplier
444
+ )
445
+
446
+
447
+ if __name__ == "__main__":
448
+ import torchaudio as ta
449
+ fs = 44100
450
+ ola = OverlapAddFader(
451
+ "hann",
452
+ 6.0,
453
+ 1.0,
454
+ fs,
455
+ batch_size=16
456
+ )
457
+ audio_, _ = ta.load(
458
+ "$DATA_ROOT/MUSDB18/HQ/canonical/test/BKS - Too "
459
+ "Much/vocals.wav"
460
+ )
461
+ audio_ = audio_[None, ...]
462
+ out = ola(audio_, lambda x: {"stem": x})["audio"]["stem"]
463
+ print(torch.allclose(out, audio_))
models/bandit/model_from_config.py ADDED
@@ -0,0 +1,31 @@
1
+ import sys
2
+ import os.path
3
+ import torch
4
+
5
+ code_path = os.path.dirname(os.path.abspath(__file__)) + '/'
6
+ sys.path.append(code_path)
7
+
8
+ import yaml
9
+ from ml_collections import ConfigDict
10
+
11
+ torch.set_float32_matmul_precision("medium")
12
+
13
+
14
+ def get_model(
15
+ config_path,
16
+ weights_path,
17
+ device,
18
+ ):
19
+ from models.bandit.core.model import MultiMaskMultiSourceBandSplitRNNSimple
20
+
21
+ f = open(config_path)
22
+ config = ConfigDict(yaml.load(f, Loader=yaml.FullLoader))
23
+ f.close()
24
+
25
+ model = MultiMaskMultiSourceBandSplitRNNSimple(
26
+ **config.model
27
+ )
28
+ d = torch.load(code_path + 'model_bandit_plus_dnr_sdr_11.47.chpt')
29
+ model.load_state_dict(d)
30
+ model.to(device)
31
+ return model, config
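A hedged example of calling get_model; the YAML filename is hypothetical, and note that weights_path is currently ignored because the checkpoint filename is hard-coded above.

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model, config = get_model(
    config_path="config_dnr_bandit_bsrnn_multi_mus64.yaml",  # hypothetical config file
    weights_path=None,   # unused by get_model(); the .chpt name is hard-coded
    device=device,
)
model.eval()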
models/bs_roformer/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from models.bs_roformer.bs_roformer import BSRoformer
2
+ from models.bs_roformer.mel_band_roformer import MelBandRoformer
models/bs_roformer/attend.py ADDED
@@ -0,0 +1,120 @@
1
+ from functools import wraps
2
+ from packaging import version
3
+ from collections import namedtuple
4
+
5
+ import torch
6
+ from torch import nn, einsum
7
+ import torch.nn.functional as F
8
+
9
+ from einops import rearrange, reduce
10
+
11
+ # constants
12
+
13
+ FlashAttentionConfig = namedtuple('FlashAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])
14
+
15
+ # helpers
16
+
17
+ def exists(val):
18
+ return val is not None
19
+
20
+ def default(v, d):
21
+ return v if exists(v) else d
22
+
23
+ def once(fn):
24
+ called = False
25
+ @wraps(fn)
26
+ def inner(x):
27
+ nonlocal called
28
+ if called:
29
+ return
30
+ called = True
31
+ return fn(x)
32
+ return inner
33
+
34
+ print_once = once(print)
35
+
36
+ # main class
37
+
38
+ class Attend(nn.Module):
39
+ def __init__(
40
+ self,
41
+ dropout = 0.,
42
+ flash = False,
43
+ scale = None
44
+ ):
45
+ super().__init__()
46
+ self.scale = scale
47
+ self.dropout = dropout
48
+ self.attn_dropout = nn.Dropout(dropout)
49
+
50
+ self.flash = flash
51
+ assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'
52
+
53
+ # determine efficient attention configs for cuda and cpu
54
+
55
+ self.cpu_config = FlashAttentionConfig(True, True, True)
56
+ self.cuda_config = None
57
+
58
+ if not torch.cuda.is_available() or not flash:
59
+ return
60
+
61
+ device_properties = torch.cuda.get_device_properties(torch.device('cuda'))
62
+
63
+ if device_properties.major == 8 and device_properties.minor == 0:
64
+ print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
65
+ self.cuda_config = FlashAttentionConfig(True, False, False)
66
+ else:
67
+ print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
68
+ self.cuda_config = FlashAttentionConfig(False, True, True)
69
+
70
+ def flash_attn(self, q, k, v):
71
+ _, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device
72
+
73
+ if exists(self.scale):
74
+ default_scale = q.shape[-1] ** -0.5
75
+ q = q * (self.scale / default_scale)
76
+
77
+ # Check if there is a compatible device for flash attention
78
+
79
+ config = self.cuda_config if is_cuda else self.cpu_config
80
+
81
+ # pytorch 2.0 flash attn: q, k, v, mask, dropout, softmax_scale
82
+
83
+ with torch.backends.cuda.sdp_kernel(**config._asdict()):
84
+ out = F.scaled_dot_product_attention(
85
+ q, k, v,
86
+ dropout_p = self.dropout if self.training else 0.
87
+ )
88
+
89
+ return out
90
+
91
+ def forward(self, q, k, v):
92
+ """
93
+ einstein notation
94
+ b - batch
95
+ h - heads
96
+ n, i, j - sequence length (base sequence length, source, target)
97
+ d - feature dimension
98
+ """
99
+
100
+ q_len, k_len, device = q.shape[-2], k.shape[-2], q.device
101
+
102
+ scale = default(self.scale, q.shape[-1] ** -0.5)
103
+
104
+ if self.flash:
105
+ return self.flash_attn(q, k, v)
106
+
107
+ # similarity
108
+
109
+ sim = einsum(f"b h i d, b h j d -> b h i j", q, k) * scale
110
+
111
+ # attention
112
+
113
+ attn = sim.softmax(dim=-1)
114
+ attn = self.attn_dropout(attn)
115
+
116
+ # aggregate values
117
+
118
+ out = einsum(f"b h i j, b h j d -> b h i d", attn, v)
119
+
120
+ return out
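A quick shape sanity check for Attend using the non-flash code path; the tensor sizes are arbitrary.

import torch

attend = Attend(dropout=0.0, flash=False)   # math-attention path, CPU-friendly
q = torch.randn(1, 8, 128, 64)              # (batch, heads, seq_len, dim_head)
k = torch.randn(1, 8, 128, 64)
v = torch.randn(1, 8, 128, 64)
out = attend(q, k, v)
print(out.shape)                            # torch.Size([1, 8, 128, 64])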
models/bs_roformer/bs_roformer.py ADDED
@@ -0,0 +1,577 @@
1
+ from functools import partial
2
+
3
+ import torch
4
+ from torch import nn, einsum, Tensor
5
+ from torch.nn import Module, ModuleList
6
+ import torch.nn.functional as F
7
+
8
+ from models.bs_roformer.attend import Attend
9
+
10
+ from beartype.typing import Tuple, Optional, List, Callable
11
+ from beartype import beartype
12
+
13
+ from rotary_embedding_torch import RotaryEmbedding
14
+
15
+ from einops import rearrange, pack, unpack
16
+ from einops.layers.torch import Rearrange
17
+
18
+ # helper functions
19
+
20
+ def exists(val):
21
+ return val is not None
22
+
23
+
24
+ def default(v, d):
25
+ return v if exists(v) else d
26
+
27
+
28
+ def pack_one(t, pattern):
29
+ return pack([t], pattern)
30
+
31
+
32
+ def unpack_one(t, ps, pattern):
33
+ return unpack(t, ps, pattern)[0]
34
+
35
+
36
+ # norm
37
+
38
+ def l2norm(t):
39
+ return F.normalize(t, dim = -1, p = 2)
40
+
41
+
42
+ class RMSNorm(Module):
43
+ def __init__(self, dim):
44
+ super().__init__()
45
+ self.scale = dim ** 0.5
46
+ self.gamma = nn.Parameter(torch.ones(dim))
47
+
48
+ def forward(self, x):
49
+ return F.normalize(x, dim=-1) * self.scale * self.gamma
50
+
51
+
52
+ # attention
53
+
54
+ class FeedForward(Module):
55
+ def __init__(
56
+ self,
57
+ dim,
58
+ mult=4,
59
+ dropout=0.
60
+ ):
61
+ super().__init__()
62
+ dim_inner = int(dim * mult)
63
+ self.net = nn.Sequential(
64
+ RMSNorm(dim),
65
+ nn.Linear(dim, dim_inner),
66
+ nn.GELU(),
67
+ nn.Dropout(dropout),
68
+ nn.Linear(dim_inner, dim),
69
+ nn.Dropout(dropout)
70
+ )
71
+
72
+ def forward(self, x):
73
+ return self.net(x)
74
+
75
+
76
+ class Attention(Module):
77
+ def __init__(
78
+ self,
79
+ dim,
80
+ heads=8,
81
+ dim_head=64,
82
+ dropout=0.,
83
+ rotary_embed=None,
84
+ flash=True
85
+ ):
86
+ super().__init__()
87
+ self.heads = heads
88
+ self.scale = dim_head ** -0.5
89
+ dim_inner = heads * dim_head
90
+
91
+ self.rotary_embed = rotary_embed
92
+
93
+ self.attend = Attend(flash=flash, dropout=dropout)
94
+
95
+ self.norm = RMSNorm(dim)
96
+ self.to_qkv = nn.Linear(dim, dim_inner * 3, bias=False)
97
+
98
+ self.to_gates = nn.Linear(dim, heads)
99
+
100
+ self.to_out = nn.Sequential(
101
+ nn.Linear(dim_inner, dim, bias=False),
102
+ nn.Dropout(dropout)
103
+ )
104
+
105
+ def forward(self, x):
106
+ x = self.norm(x)
107
+
108
+ q, k, v = rearrange(self.to_qkv(x), 'b n (qkv h d) -> qkv b h n d', qkv=3, h=self.heads)
109
+
110
+ if exists(self.rotary_embed):
111
+ q = self.rotary_embed.rotate_queries_or_keys(q)
112
+ k = self.rotary_embed.rotate_queries_or_keys(k)
113
+
114
+ out = self.attend(q, k, v)
115
+
116
+ gates = self.to_gates(x)
117
+ out = out * rearrange(gates, 'b n h -> b h n 1').sigmoid()
118
+
119
+ out = rearrange(out, 'b h n d -> b n (h d)')
120
+ return self.to_out(out)
121
+
122
+
123
+ class LinearAttention(Module):
124
+ """
125
+ This flavor of linear attention was proposed in https://arxiv.org/abs/2106.09681 by El-Nouby et al.
126
+ """
127
+
128
+ @beartype
129
+ def __init__(
130
+ self,
131
+ *,
132
+ dim,
133
+ dim_head=32,
134
+ heads=8,
135
+ scale=8,
136
+ flash=False,
137
+ dropout=0.
138
+ ):
139
+ super().__init__()
140
+ dim_inner = dim_head * heads
141
+ self.norm = RMSNorm(dim)
142
+
143
+ self.to_qkv = nn.Sequential(
144
+ nn.Linear(dim, dim_inner * 3, bias=False),
145
+ Rearrange('b n (qkv h d) -> qkv b h d n', qkv=3, h=heads)
146
+ )
147
+
148
+ self.temperature = nn.Parameter(torch.ones(heads, 1, 1))
149
+
150
+ self.attend = Attend(
151
+ scale=scale,
152
+ dropout=dropout,
153
+ flash=flash
154
+ )
155
+
156
+ self.to_out = nn.Sequential(
157
+ Rearrange('b h d n -> b n (h d)'),
158
+ nn.Linear(dim_inner, dim, bias=False)
159
+ )
160
+
161
+ def forward(
162
+ self,
163
+ x
164
+ ):
165
+ x = self.norm(x)
166
+
167
+ q, k, v = self.to_qkv(x)
168
+
169
+ q, k = map(l2norm, (q, k))
170
+ q = q * self.temperature.exp()
171
+
172
+ out = self.attend(q, k, v)
173
+
174
+ return self.to_out(out)
175
+
176
+
177
+ class Transformer(Module):
178
+ def __init__(
179
+ self,
180
+ *,
181
+ dim,
182
+ depth,
183
+ dim_head=64,
184
+ heads=8,
185
+ attn_dropout=0.,
186
+ ff_dropout=0.,
187
+ ff_mult=4,
188
+ norm_output=True,
189
+ rotary_embed=None,
190
+ flash_attn=True,
191
+ linear_attn=False
192
+ ):
193
+ super().__init__()
194
+ self.layers = ModuleList([])
195
+
196
+ for _ in range(depth):
197
+ if linear_attn:
198
+ attn = LinearAttention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout, flash=flash_attn)
199
+ else:
200
+ attn = Attention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout,
201
+ rotary_embed=rotary_embed, flash=flash_attn)
202
+
203
+ self.layers.append(ModuleList([
204
+ attn,
205
+ FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)
206
+ ]))
207
+
208
+ self.norm = RMSNorm(dim) if norm_output else nn.Identity()
209
+
210
+ def forward(self, x):
211
+
212
+ for attn, ff in self.layers:
213
+ x = attn(x) + x
214
+ x = ff(x) + x
215
+
216
+ return self.norm(x)
217
+
218
+
219
+ # bandsplit module
220
+
221
+ class BandSplit(Module):
222
+ @beartype
223
+ def __init__(
224
+ self,
225
+ dim,
226
+ dim_inputs: Tuple[int, ...]
227
+ ):
228
+ super().__init__()
229
+ self.dim_inputs = dim_inputs
230
+ self.to_features = ModuleList([])
231
+
232
+ for dim_in in dim_inputs:
233
+ net = nn.Sequential(
234
+ RMSNorm(dim_in),
235
+ nn.Linear(dim_in, dim)
236
+ )
237
+
238
+ self.to_features.append(net)
239
+
240
+ def forward(self, x):
241
+ x = x.split(self.dim_inputs, dim=-1)
242
+
243
+ outs = []
244
+ for split_input, to_feature in zip(x, self.to_features):
245
+ split_output = to_feature(split_input)
246
+ outs.append(split_output)
247
+
248
+ return torch.stack(outs, dim=-2)
249
+
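A small shape sketch for BandSplit; the band widths below are made up, whereas in BSRoformer they come from freqs_per_bands_with_complex.

import torch

band_split = BandSplit(dim=32, dim_inputs=(8, 16, 24))
x = torch.randn(2, 100, 8 + 16 + 24)   # (batch, time, sum of band widths)
features = band_split(x)
print(features.shape)                  # torch.Size([2, 100, 3, 32]): one dim-32 vector per band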
250
+
251
+ def MLP(
252
+ dim_in,
253
+ dim_out,
254
+ dim_hidden=None,
255
+ depth=1,
256
+ activation=nn.Tanh
257
+ ):
258
+ dim_hidden = default(dim_hidden, dim_in)
259
+
260
+ net = []
261
+ dims = (dim_in, *((dim_hidden,) * (depth - 1)), dim_out)
262
+
263
+ for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])):
264
+ is_last = ind == (len(dims) - 2)
265
+
266
+ net.append(nn.Linear(layer_dim_in, layer_dim_out))
267
+
268
+ if is_last:
269
+ continue
270
+
271
+ net.append(activation())
272
+
273
+ return nn.Sequential(*net)
274
+
275
+
276
+ class MaskEstimator(Module):
277
+ @beartype
278
+ def __init__(
279
+ self,
280
+ dim,
281
+ dim_inputs: Tuple[int, ...],
282
+ depth,
283
+ mlp_expansion_factor=4
284
+ ):
285
+ super().__init__()
286
+ self.dim_inputs = dim_inputs
287
+ self.to_freqs = ModuleList([])
288
+ dim_hidden = dim * mlp_expansion_factor
289
+
290
+ for dim_in in dim_inputs:
291
+ net = []
292
+
293
+ mlp = nn.Sequential(
294
+ MLP(dim, dim_in * 2, dim_hidden=dim_hidden, depth=depth),
295
+ nn.GLU(dim=-1)
296
+ )
297
+
298
+ self.to_freqs.append(mlp)
299
+
300
+ def forward(self, x):
301
+ x = x.unbind(dim=-2)
302
+
303
+ outs = []
304
+
305
+ for band_features, mlp in zip(x, self.to_freqs):
306
+ freq_out = mlp(band_features)
307
+ outs.append(freq_out)
308
+
309
+ return torch.cat(outs, dim=-1)
310
+
311
+
312
+ # main class
313
+
314
+ DEFAULT_FREQS_PER_BANDS = (
315
+ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
316
+ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
317
+ 2, 2, 2, 2,
318
+ 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
319
+ 12, 12, 12, 12, 12, 12, 12, 12,
320
+ 24, 24, 24, 24, 24, 24, 24, 24,
321
+ 48, 48, 48, 48, 48, 48, 48, 48,
322
+ 128, 129,
323
+ )
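These band widths sum to 24*2 + 12*4 + 8*12 + 8*24 + 8*48 + 128 + 129 = 1025, i.e. stft_n_fft // 2 + 1 frequency bins for the default n_fft of 2048, which is exactly what the assertion in BSRoformer.__init__ checks:

assert sum(DEFAULT_FREQS_PER_BANDS) == 2048 // 2 + 1  # 1025 bins for the default STFT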
324
+
325
+
326
+ class BSRoformer(Module):
327
+
328
+ @beartype
329
+ def __init__(
330
+ self,
331
+ dim,
332
+ *,
333
+ depth,
334
+ stereo=False,
335
+ num_stems=1,
336
+ time_transformer_depth=2,
337
+ freq_transformer_depth=2,
338
+ linear_transformer_depth=0,
339
+ freqs_per_bands: Tuple[int, ...] = DEFAULT_FREQS_PER_BANDS,
340
+ # in the paper, they divide into ~60 bands, test with 1 for starters
341
+ dim_head=64,
342
+ heads=8,
343
+ attn_dropout=0.,
344
+ ff_dropout=0.,
345
+ flash_attn=True,
346
+ dim_freqs_in=1025,
347
+ stft_n_fft=2048,
348
+ stft_hop_length=512,
349
+ # 10ms at 44100Hz, from sections 4.1, 4.4 in the paper - @faroit recommends // 2 or // 4 for better reconstruction
350
+ stft_win_length=2048,
351
+ stft_normalized=False,
352
+ stft_window_fn: Optional[Callable] = None,
353
+ mask_estimator_depth=2,
354
+ multi_stft_resolution_loss_weight=1.,
355
+ multi_stft_resolutions_window_sizes: Tuple[int, ...] = (4096, 2048, 1024, 512, 256),
356
+ multi_stft_hop_size=147,
357
+ multi_stft_normalized=False,
358
+ multi_stft_window_fn: Callable = torch.hann_window
359
+ ):
360
+ super().__init__()
361
+
362
+ self.stereo = stereo
363
+ self.audio_channels = 2 if stereo else 1
364
+ self.num_stems = num_stems
365
+
366
+ self.layers = ModuleList([])
367
+
368
+ transformer_kwargs = dict(
369
+ dim=dim,
370
+ heads=heads,
371
+ dim_head=dim_head,
372
+ attn_dropout=attn_dropout,
373
+ ff_dropout=ff_dropout,
374
+ flash_attn=flash_attn,
375
+ norm_output=False
376
+ )
377
+
378
+ time_rotary_embed = RotaryEmbedding(dim=dim_head)
379
+ freq_rotary_embed = RotaryEmbedding(dim=dim_head)
380
+
381
+ for _ in range(depth):
382
+ tran_modules = []
383
+ if linear_transformer_depth > 0:
384
+ tran_modules.append(Transformer(depth=linear_transformer_depth, linear_attn=True, **transformer_kwargs))
385
+ tran_modules.append(
386
+ Transformer(depth=time_transformer_depth, rotary_embed=time_rotary_embed, **transformer_kwargs)
387
+ )
388
+ tran_modules.append(
389
+ Transformer(depth=freq_transformer_depth, rotary_embed=freq_rotary_embed, **transformer_kwargs)
390
+ )
391
+ self.layers.append(nn.ModuleList(tran_modules))
392
+
393
+ self.final_norm = RMSNorm(dim)
394
+
395
+ self.stft_kwargs = dict(
396
+ n_fft=stft_n_fft,
397
+ hop_length=stft_hop_length,
398
+ win_length=stft_win_length,
399
+ normalized=stft_normalized
400
+ )
401
+
402
+ self.stft_window_fn = partial(default(stft_window_fn, torch.hann_window), stft_win_length)
403
+
404
+ freqs = torch.stft(torch.randn(1, 4096), **self.stft_kwargs, return_complex=True).shape[1]
405
+
406
+ assert len(freqs_per_bands) > 1
407
+ assert sum(
408
+ freqs_per_bands) == freqs, f'the number of freqs in the bands must equal {freqs} based on the STFT settings, but got {sum(freqs_per_bands)}'
409
+
410
+ freqs_per_bands_with_complex = tuple(2 * f * self.audio_channels for f in freqs_per_bands)
411
+
412
+ self.band_split = BandSplit(
413
+ dim=dim,
414
+ dim_inputs=freqs_per_bands_with_complex
415
+ )
416
+
417
+ self.mask_estimators = nn.ModuleList([])
418
+
419
+ for _ in range(num_stems):
420
+ mask_estimator = MaskEstimator(
421
+ dim=dim,
422
+ dim_inputs=freqs_per_bands_with_complex,
423
+ depth=mask_estimator_depth
424
+ )
425
+
426
+ self.mask_estimators.append(mask_estimator)
427
+
428
+ # for the multi-resolution stft loss
429
+
430
+ self.multi_stft_resolution_loss_weight = multi_stft_resolution_loss_weight
431
+ self.multi_stft_resolutions_window_sizes = multi_stft_resolutions_window_sizes
432
+ self.multi_stft_n_fft = stft_n_fft
433
+ self.multi_stft_window_fn = multi_stft_window_fn
434
+
435
+ self.multi_stft_kwargs = dict(
436
+ hop_length=multi_stft_hop_size,
437
+ normalized=multi_stft_normalized
438
+ )
439
+
440
+ def forward(
441
+ self,
442
+ raw_audio,
443
+ target=None,
444
+ return_loss_breakdown=False
445
+ ):
446
+ """
447
+ einops
448
+
449
+ b - batch
450
+ f - freq
451
+ t - time
452
+ s - audio channel (1 for mono, 2 for stereo)
453
+ n - number of 'stems'
454
+ c - complex (2)
455
+ d - feature dimension
456
+ """
457
+
458
+ device = raw_audio.device
459
+
460
+ if raw_audio.ndim == 2:
461
+ raw_audio = rearrange(raw_audio, 'b t -> b 1 t')
462
+
463
+ channels = raw_audio.shape[1]
464
+ assert (not self.stereo and channels == 1) or (
465
+ self.stereo and channels == 2), 'stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). also need to be False if mono (channel dimension of 1)'
466
+
467
+ # to stft
468
+
469
+ raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, '* t')
470
+
471
+ stft_window = self.stft_window_fn(device=device)
472
+
473
+ stft_repr = torch.stft(raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True)
474
+ stft_repr = torch.view_as_real(stft_repr)
475
+
476
+ stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, '* f t c')
477
+ stft_repr = rearrange(stft_repr,
478
+ 'b s f t c -> b (f s) t c') # merge stereo / mono into the frequency, with frequency leading dimension, for band splitting
479
+
480
+ x = rearrange(stft_repr, 'b f t c -> b t (f c)')
481
+
482
+ x = self.band_split(x)
483
+
484
+ # axial / hierarchical attention
485
+
486
+ for transformer_block in self.layers:
487
+
488
+ if len(transformer_block) == 3:
489
+ linear_transformer, time_transformer, freq_transformer = transformer_block
490
+
491
+ x, ft_ps = pack([x], 'b * d')
492
+ x = linear_transformer(x)
493
+ x, = unpack(x, ft_ps, 'b * d')
494
+ else:
495
+ time_transformer, freq_transformer = transformer_block
496
+
497
+ x = rearrange(x, 'b t f d -> b f t d')
498
+ x, ps = pack([x], '* t d')
499
+
500
+ x = time_transformer(x)
501
+
502
+ x, = unpack(x, ps, '* t d')
503
+ x = rearrange(x, 'b f t d -> b t f d')
504
+ x, ps = pack([x], '* f d')
505
+
506
+ x = freq_transformer(x)
507
+
508
+ x, = unpack(x, ps, '* f d')
509
+
510
+ x = self.final_norm(x)
511
+
512
+ num_stems = len(self.mask_estimators)
513
+
514
+ mask = torch.stack([fn(x) for fn in self.mask_estimators], dim=1)
515
+ mask = rearrange(mask, 'b n t (f c) -> b n f t c', c=2)
516
+
517
+ # modulate frequency representation
518
+
519
+ stft_repr = rearrange(stft_repr, 'b f t c -> b 1 f t c')
520
+
521
+ # complex number multiplication
522
+
523
+ stft_repr = torch.view_as_complex(stft_repr)
524
+ mask = torch.view_as_complex(mask)
525
+
526
+ stft_repr = stft_repr * mask
527
+
528
+ # istft
529
+
530
+ stft_repr = rearrange(stft_repr, 'b n (f s) t -> (b n s) f t', s=self.audio_channels)
531
+
532
+ recon_audio = torch.istft(stft_repr, **self.stft_kwargs, window=stft_window, return_complex=False)
533
+
534
+ recon_audio = rearrange(recon_audio, '(b n s) t -> b n s t', s=self.audio_channels, n=num_stems)
535
+
536
+ if num_stems == 1:
537
+ recon_audio = rearrange(recon_audio, 'b 1 s t -> b s t')
538
+
539
+ # if a target is passed in, calculate loss for learning
540
+
541
+ if not exists(target):
542
+ return recon_audio
543
+
544
+ if self.num_stems > 1:
545
+ assert target.ndim == 4 and target.shape[1] == self.num_stems
546
+
547
+ if target.ndim == 2:
548
+ target = rearrange(target, '... t -> ... 1 t')
549
+
550
+ target = target[..., :recon_audio.shape[-1]] # protect against lost length on istft
551
+
552
+ loss = F.l1_loss(recon_audio, target)
553
+
554
+ multi_stft_resolution_loss = 0.
555
+
556
+ for window_size in self.multi_stft_resolutions_window_sizes:
557
+ res_stft_kwargs = dict(
558
+ n_fft=max(window_size, self.multi_stft_n_fft), # not sure what n_fft is across multi resolution stft
559
+ win_length=window_size,
560
+ return_complex=True,
561
+ window=self.multi_stft_window_fn(window_size, device=device),
562
+ **self.multi_stft_kwargs,
563
+ )
564
+
565
+ recon_Y = torch.stft(rearrange(recon_audio, '... s t -> (... s) t'), **res_stft_kwargs)
566
+ target_Y = torch.stft(rearrange(target, '... s t -> (... s) t'), **res_stft_kwargs)
567
+
568
+ multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(recon_Y, target_Y)
569
+
570
+ weighted_multi_resolution_loss = multi_stft_resolution_loss * self.multi_stft_resolution_loss_weight
571
+
572
+ total_loss = loss + weighted_multi_resolution_loss
573
+
574
+ if not return_loss_breakdown:
575
+ return total_loss
576
+
577
+ return total_loss, (loss, multi_stft_resolution_loss)
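A minimal, self-contained sketch of the multi-resolution STFT term computed in the training branch above. It mirrors the window sizes, hop size and n_fft floor used in this file, but the `multi_res_stft_l1` helper and the dummy tensors are illustrative assumptions only; it also compares real/imaginary parts explicitly rather than calling F.l1_loss on complex tensors directly as the model does.

import torch
import torch.nn.functional as F

def multi_res_stft_l1(recon, target,
                      window_sizes=(4096, 2048, 1024, 512, 256),
                      n_fft_floor=2048, hop=147):
    # Sum of L1 distances between STFTs at several resolutions,
    # following the loop in the forward pass above.
    loss = 0.
    for win in window_sizes:
        window = torch.hann_window(win)
        kw = dict(n_fft=max(win, n_fft_floor), win_length=win,
                  hop_length=hop, window=window, return_complex=True)
        loss = loss + F.l1_loss(torch.view_as_real(torch.stft(recon, **kw)),
                                torch.view_as_real(torch.stft(target, **kw)))
    return loss

recon = torch.randn(2, 44100)   # (batch, time) dummy reconstruction
target = torch.randn(2, 44100)  # (batch, time) dummy target
print(multi_res_stft_l1(recon, target))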
models/bs_roformer/mel_band_roformer.py ADDED
@@ -0,0 +1,637 @@
1
+ from functools import partial
2
+
3
+ import torch
4
+ from torch import nn, einsum, Tensor
5
+ from torch.nn import Module, ModuleList
6
+ import torch.nn.functional as F
7
+
8
+ from models.bs_roformer.attend import Attend
9
+
10
+ from beartype.typing import Tuple, Optional, List, Callable
11
+ from beartype import beartype
12
+
13
+ from rotary_embedding_torch import RotaryEmbedding
14
+
15
+ from einops import rearrange, pack, unpack, reduce, repeat
16
+ from einops.layers.torch import Rearrange
17
+
18
+ from librosa import filters
19
+
20
+
21
+ # helper functions
22
+
23
+ def exists(val):
24
+ return val is not None
25
+
26
+
27
+ def default(v, d):
28
+ return v if exists(v) else d
29
+
30
+
31
+ def pack_one(t, pattern):
32
+ return pack([t], pattern)
33
+
34
+
35
+ def unpack_one(t, ps, pattern):
36
+ return unpack(t, ps, pattern)[0]
37
+
38
+
39
+ def pad_at_dim(t, pad, dim=-1, value=0.):
40
+ dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
41
+ zeros = ((0, 0) * dims_from_right)
42
+ return F.pad(t, (*zeros, *pad), value=value)
43
+
44
+
45
+ def l2norm(t):
46
+ return F.normalize(t, dim=-1, p=2)
47
+
48
+
49
+ # norm
50
+
51
+ class RMSNorm(Module):
52
+ def __init__(self, dim):
53
+ super().__init__()
54
+ self.scale = dim ** 0.5
55
+ self.gamma = nn.Parameter(torch.ones(dim))
56
+
57
+ def forward(self, x):
58
+ return F.normalize(x, dim=-1) * self.scale * self.gamma
59
+
60
+
61
+ # attention
62
+
63
+ class FeedForward(Module):
64
+ def __init__(
65
+ self,
66
+ dim,
67
+ mult=4,
68
+ dropout=0.
69
+ ):
70
+ super().__init__()
71
+ dim_inner = int(dim * mult)
72
+ self.net = nn.Sequential(
73
+ RMSNorm(dim),
74
+ nn.Linear(dim, dim_inner),
75
+ nn.GELU(),
76
+ nn.Dropout(dropout),
77
+ nn.Linear(dim_inner, dim),
78
+ nn.Dropout(dropout)
79
+ )
80
+
81
+ def forward(self, x):
82
+ return self.net(x)
83
+
84
+
85
+ class Attention(Module):
86
+ def __init__(
87
+ self,
88
+ dim,
89
+ heads=8,
90
+ dim_head=64,
91
+ dropout=0.,
92
+ rotary_embed=None,
93
+ flash=True
94
+ ):
95
+ super().__init__()
96
+ self.heads = heads
97
+ self.scale = dim_head ** -0.5
98
+ dim_inner = heads * dim_head
99
+
100
+ self.rotary_embed = rotary_embed
101
+
102
+ self.attend = Attend(flash=flash, dropout=dropout)
103
+
104
+ self.norm = RMSNorm(dim)
105
+ self.to_qkv = nn.Linear(dim, dim_inner * 3, bias=False)
106
+
107
+ self.to_gates = nn.Linear(dim, heads)
108
+
109
+ self.to_out = nn.Sequential(
110
+ nn.Linear(dim_inner, dim, bias=False),
111
+ nn.Dropout(dropout)
112
+ )
113
+
114
+ def forward(self, x):
115
+ x = self.norm(x)
116
+
117
+ q, k, v = rearrange(self.to_qkv(x), 'b n (qkv h d) -> qkv b h n d', qkv=3, h=self.heads)
118
+
119
+ if exists(self.rotary_embed):
120
+ q = self.rotary_embed.rotate_queries_or_keys(q)
121
+ k = self.rotary_embed.rotate_queries_or_keys(k)
122
+
123
+ out = self.attend(q, k, v)
124
+
125
+ gates = self.to_gates(x)
126
+ out = out * rearrange(gates, 'b n h -> b h n 1').sigmoid()
127
+
128
+ out = rearrange(out, 'b h n d -> b n (h d)')
129
+ return self.to_out(out)
130
+
131
+
132
+ class LinearAttention(Module):
133
+ """
134
+ this flavor of linear attention proposed in https://arxiv.org/abs/2106.09681 by El-Nouby et al.
135
+ """
136
+
137
+ @beartype
138
+ def __init__(
139
+ self,
140
+ *,
141
+ dim,
142
+ dim_head=32,
143
+ heads=8,
144
+ scale=8,
145
+ flash=False,
146
+ dropout=0.
147
+ ):
148
+ super().__init__()
149
+ dim_inner = dim_head * heads
150
+ self.norm = RMSNorm(dim)
151
+
152
+ self.to_qkv = nn.Sequential(
153
+ nn.Linear(dim, dim_inner * 3, bias=False),
154
+ Rearrange('b n (qkv h d) -> qkv b h d n', qkv=3, h=heads)
155
+ )
156
+
157
+ self.temperature = nn.Parameter(torch.ones(heads, 1, 1))
158
+
159
+ self.attend = Attend(
160
+ scale=scale,
161
+ dropout=dropout,
162
+ flash=flash
163
+ )
164
+
165
+ self.to_out = nn.Sequential(
166
+ Rearrange('b h d n -> b n (h d)'),
167
+ nn.Linear(dim_inner, dim, bias=False)
168
+ )
169
+
170
+ def forward(
171
+ self,
172
+ x
173
+ ):
174
+ x = self.norm(x)
175
+
176
+ q, k, v = self.to_qkv(x)
177
+
178
+ q, k = map(l2norm, (q, k))
179
+ q = q * self.temperature.exp()
180
+
181
+ out = self.attend(q, k, v)
182
+
183
+ return self.to_out(out)
184
+
185
+
186
+ class Transformer(Module):
187
+ def __init__(
188
+ self,
189
+ *,
190
+ dim,
191
+ depth,
192
+ dim_head=64,
193
+ heads=8,
194
+ attn_dropout=0.,
195
+ ff_dropout=0.,
196
+ ff_mult=4,
197
+ norm_output=True,
198
+ rotary_embed=None,
199
+ flash_attn=True,
200
+ linear_attn=False
201
+ ):
202
+ super().__init__()
203
+ self.layers = ModuleList([])
204
+
205
+ for _ in range(depth):
206
+ if linear_attn:
207
+ attn = LinearAttention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout, flash=flash_attn)
208
+ else:
209
+ attn = Attention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout,
210
+ rotary_embed=rotary_embed, flash=flash_attn)
211
+
212
+ self.layers.append(ModuleList([
213
+ attn,
214
+ FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)
215
+ ]))
216
+
217
+ self.norm = RMSNorm(dim) if norm_output else nn.Identity()
218
+
219
+ def forward(self, x):
220
+
221
+ for attn, ff in self.layers:
222
+ x = attn(x) + x
223
+ x = ff(x) + x
224
+
225
+ return self.norm(x)
226
+
227
+
228
+ # bandsplit module
229
+
230
+ class BandSplit(Module):
231
+ @beartype
232
+ def __init__(
233
+ self,
234
+ dim,
235
+ dim_inputs: Tuple[int, ...]
236
+ ):
237
+ super().__init__()
238
+ self.dim_inputs = dim_inputs
239
+ self.to_features = ModuleList([])
240
+
241
+ for dim_in in dim_inputs:
242
+ net = nn.Sequential(
243
+ RMSNorm(dim_in),
244
+ nn.Linear(dim_in, dim)
245
+ )
246
+
247
+ self.to_features.append(net)
248
+
249
+ def forward(self, x):
250
+ x = x.split(self.dim_inputs, dim=-1)
251
+
252
+ outs = []
253
+ for split_input, to_feature in zip(x, self.to_features):
254
+ split_output = to_feature(split_input)
255
+ outs.append(split_output)
256
+
257
+ return torch.stack(outs, dim=-2)
258
+
259
+
260
+ def MLP(
261
+ dim_in,
262
+ dim_out,
263
+ dim_hidden=None,
264
+ depth=1,
265
+ activation=nn.Tanh
266
+ ):
267
+ dim_hidden = default(dim_hidden, dim_in)
268
+
269
+ net = []
270
+ dims = (dim_in, *((dim_hidden,) * depth), dim_out)
271
+
272
+ for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])):
273
+ is_last = ind == (len(dims) - 2)
274
+
275
+ net.append(nn.Linear(layer_dim_in, layer_dim_out))
276
+
277
+ if is_last:
278
+ continue
279
+
280
+ net.append(activation())
281
+
282
+ return nn.Sequential(*net)
283
+
284
+
285
+ class MaskEstimator(Module):
286
+ @beartype
287
+ def __init__(
288
+ self,
289
+ dim,
290
+ dim_inputs: Tuple[int, ...],
291
+ depth,
292
+ mlp_expansion_factor=4
293
+ ):
294
+ super().__init__()
295
+ self.dim_inputs = dim_inputs
296
+ self.to_freqs = ModuleList([])
297
+ dim_hidden = dim * mlp_expansion_factor
298
+
299
+ for dim_in in dim_inputs:
300
+ net = []
301
+
302
+ mlp = nn.Sequential(
303
+ MLP(dim, dim_in * 2, dim_hidden=dim_hidden, depth=depth),
304
+ nn.GLU(dim=-1)
305
+ )
306
+
307
+ self.to_freqs.append(mlp)
308
+
309
+ def forward(self, x):
310
+ x = x.unbind(dim=-2)
311
+
312
+ outs = []
313
+
314
+ for band_features, mlp in zip(x, self.to_freqs):
315
+ freq_out = mlp(band_features)
316
+ outs.append(freq_out)
317
+
318
+ return torch.cat(outs, dim=-1)
319
+
320
+
321
+ # main class
322
+
323
+ class MelBandRoformer(Module):
324
+
325
+ @beartype
326
+ def __init__(
327
+ self,
328
+ dim,
329
+ *,
330
+ depth,
331
+ stereo=False,
332
+ num_stems=1,
333
+ time_transformer_depth=2,
334
+ freq_transformer_depth=2,
335
+ linear_transformer_depth=0,
336
+ num_bands=60,
337
+ dim_head=64,
338
+ heads=8,
339
+ attn_dropout=0.1,
340
+ ff_dropout=0.1,
341
+ flash_attn=True,
342
+ dim_freqs_in=1025,
343
+ sample_rate=44100, # needed for mel filter bank from librosa
344
+ stft_n_fft=2048,
345
+ stft_hop_length=512,
346
+ # 10ms at 44100Hz, from sections 4.1, 4.4 in the paper - @faroit recommends // 2 or // 4 for better reconstruction
347
+ stft_win_length=2048,
348
+ stft_normalized=False,
349
+ stft_window_fn: Optional[Callable] = None,
350
+ mask_estimator_depth=1,
351
+ multi_stft_resolution_loss_weight=1.,
352
+ multi_stft_resolutions_window_sizes: Tuple[int, ...] = (4096, 2048, 1024, 512, 256),
353
+ multi_stft_hop_size=147,
354
+ multi_stft_normalized=False,
355
+ multi_stft_window_fn: Callable = torch.hann_window,
356
+ match_input_audio_length=False, # if True, pad output tensor to match length of input tensor
357
+ ):
358
+ super().__init__()
359
+
360
+ self.stereo = stereo
361
+ self.audio_channels = 2 if stereo else 1
362
+ self.num_stems = num_stems
363
+
364
+ self.layers = ModuleList([])
365
+
366
+ transformer_kwargs = dict(
367
+ dim=dim,
368
+ heads=heads,
369
+ dim_head=dim_head,
370
+ attn_dropout=attn_dropout,
371
+ ff_dropout=ff_dropout,
372
+ flash_attn=flash_attn
373
+ )
374
+
375
+ time_rotary_embed = RotaryEmbedding(dim=dim_head)
376
+ freq_rotary_embed = RotaryEmbedding(dim=dim_head)
377
+
378
+ for _ in range(depth):
379
+ tran_modules = []
380
+ if linear_transformer_depth > 0:
381
+ tran_modules.append(Transformer(depth=linear_transformer_depth, linear_attn=True, **transformer_kwargs))
382
+ tran_modules.append(
383
+ Transformer(depth=time_transformer_depth, rotary_embed=time_rotary_embed, **transformer_kwargs)
384
+ )
385
+ tran_modules.append(
386
+ Transformer(depth=freq_transformer_depth, rotary_embed=freq_rotary_embed, **transformer_kwargs)
387
+ )
388
+ self.layers.append(nn.ModuleList(tran_modules))
389
+
390
+ self.stft_window_fn = partial(default(stft_window_fn, torch.hann_window), stft_win_length)
391
+
392
+ self.stft_kwargs = dict(
393
+ n_fft=stft_n_fft,
394
+ hop_length=stft_hop_length,
395
+ win_length=stft_win_length,
396
+ normalized=stft_normalized
397
+ )
398
+
399
+ freqs = torch.stft(torch.randn(1, 4096), **self.stft_kwargs, return_complex=True).shape[1]
400
+
401
+ # create mel filter bank
402
+ # with librosa.filters.mel as in section 2 of paper
403
+
404
+ mel_filter_bank_numpy = filters.mel(sr=sample_rate, n_fft=stft_n_fft, n_mels=num_bands)
405
+
406
+ mel_filter_bank = torch.from_numpy(mel_filter_bank_numpy)
407
+
408
+ # for some reason, it doesn't include the first freq? just force a value for now
409
+
410
+ mel_filter_bank[0][0] = 1.
411
+
412
+ # In some systems/envs we get 0.0 instead of ~1.9e-18 in the last position,
413
+ # so let's force a positive value
414
+
415
+ mel_filter_bank[-1, -1] = 1.
416
+
417
+ # binary as in paper (then estimated masks are averaged for overlapping regions)
418
+
419
+ freqs_per_band = mel_filter_bank > 0
420
+ assert freqs_per_band.any(dim=0).all(), 'all frequencies need to be covered by all bands for now'
421
+
422
+ repeated_freq_indices = repeat(torch.arange(freqs), 'f -> b f', b=num_bands)
423
+ freq_indices = repeated_freq_indices[freqs_per_band]
424
+
425
+ if stereo:
426
+ freq_indices = repeat(freq_indices, 'f -> f s', s=2)
427
+ freq_indices = freq_indices * 2 + torch.arange(2)
428
+ freq_indices = rearrange(freq_indices, 'f s -> (f s)')
429
+
430
+ self.register_buffer('freq_indices', freq_indices, persistent=False)
431
+ self.register_buffer('freqs_per_band', freqs_per_band, persistent=False)
432
+
433
+ num_freqs_per_band = reduce(freqs_per_band, 'b f -> b', 'sum')
434
+ num_bands_per_freq = reduce(freqs_per_band, 'b f -> f', 'sum')
435
+
436
+ self.register_buffer('num_freqs_per_band', num_freqs_per_band, persistent=False)
437
+ self.register_buffer('num_bands_per_freq', num_bands_per_freq, persistent=False)
438
+
439
+ # band split and mask estimator
440
+
441
+ freqs_per_bands_with_complex = tuple(2 * f * self.audio_channels for f in num_freqs_per_band.tolist())
442
+
443
+ self.band_split = BandSplit(
444
+ dim=dim,
445
+ dim_inputs=freqs_per_bands_with_complex
446
+ )
447
+
448
+ self.mask_estimators = nn.ModuleList([])
449
+
450
+ for _ in range(num_stems):
451
+ mask_estimator = MaskEstimator(
452
+ dim=dim,
453
+ dim_inputs=freqs_per_bands_with_complex,
454
+ depth=mask_estimator_depth
455
+ )
456
+
457
+ self.mask_estimators.append(mask_estimator)
458
+
459
+ # for the multi-resolution stft loss
460
+
461
+ self.multi_stft_resolution_loss_weight = multi_stft_resolution_loss_weight
462
+ self.multi_stft_resolutions_window_sizes = multi_stft_resolutions_window_sizes
463
+ self.multi_stft_n_fft = stft_n_fft
464
+ self.multi_stft_window_fn = multi_stft_window_fn
465
+
466
+ self.multi_stft_kwargs = dict(
467
+ hop_length=multi_stft_hop_size,
468
+ normalized=multi_stft_normalized
469
+ )
470
+
471
+ self.match_input_audio_length = match_input_audio_length
472
+
473
+ def forward(
474
+ self,
475
+ raw_audio,
476
+ target=None,
477
+ return_loss_breakdown=False
478
+ ):
479
+ """
480
+ einops
481
+
482
+ b - batch
483
+ f - freq
484
+ t - time
485
+ s - audio channel (1 for mono, 2 for stereo)
486
+ n - number of 'stems'
487
+ c - complex (2)
488
+ d - feature dimension
489
+ """
490
+
491
+ device = raw_audio.device
492
+
493
+ if raw_audio.ndim == 2:
494
+ raw_audio = rearrange(raw_audio, 'b t -> b 1 t')
495
+
496
+ batch, channels, raw_audio_length = raw_audio.shape
497
+
498
+ istft_length = raw_audio_length if self.match_input_audio_length else None
499
+
500
+ assert (not self.stereo and channels == 1) or (
501
+ self.stereo and channels == 2), 'stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). also need to be False if mono (channel dimension of 1)'
502
+
503
+ # to stft
504
+
505
+ raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, '* t')
506
+
507
+ stft_window = self.stft_window_fn(device=device)
508
+
509
+ stft_repr = torch.stft(raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True)
510
+ stft_repr = torch.view_as_real(stft_repr)
511
+
512
+ stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, '* f t c')
513
+ stft_repr = rearrange(stft_repr,
514
+ 'b s f t c -> b (f s) t c') # merge stereo / mono into the frequency, with frequency leading dimension, for band splitting
515
+
516
+ # index out all frequencies for all frequency ranges across bands ascending in one go
517
+
518
+ batch_arange = torch.arange(batch, device=device)[..., None]
519
+
520
+ # account for stereo
521
+
522
+ x = stft_repr[batch_arange, self.freq_indices]
523
+
524
+ # fold the complex (real and imag) into the frequencies dimension
525
+
526
+ x = rearrange(x, 'b f t c -> b t (f c)')
527
+
528
+ x = self.band_split(x)
529
+
530
+ # axial / hierarchical attention
531
+
532
+ for transformer_block in self.layers:
533
+
534
+ if len(transformer_block) == 3:
535
+ linear_transformer, time_transformer, freq_transformer = transformer_block
536
+
537
+ x, ft_ps = pack([x], 'b * d')
538
+ x = linear_transformer(x)
539
+ x, = unpack(x, ft_ps, 'b * d')
540
+ else:
541
+ time_transformer, freq_transformer = transformer_block
542
+
543
+ x = rearrange(x, 'b t f d -> b f t d')
544
+ x, ps = pack([x], '* t d')
545
+
546
+ x = time_transformer(x)
547
+
548
+ x, = unpack(x, ps, '* t d')
549
+ x = rearrange(x, 'b f t d -> b t f d')
550
+ x, ps = pack([x], '* f d')
551
+
552
+ x = freq_transformer(x)
553
+
554
+ x, = unpack(x, ps, '* f d')
555
+
556
+ num_stems = len(self.mask_estimators)
557
+
558
+ masks = torch.stack([fn(x) for fn in self.mask_estimators], dim=1)
559
+ masks = rearrange(masks, 'b n t (f c) -> b n f t c', c=2)
560
+
561
+ # modulate frequency representation
562
+
563
+ stft_repr = rearrange(stft_repr, 'b f t c -> b 1 f t c')
564
+
565
+ # complex number multiplication
566
+
567
+ stft_repr = torch.view_as_complex(stft_repr)
568
+ masks = torch.view_as_complex(masks)
569
+
570
+ masks = masks.type(stft_repr.dtype)
571
+
572
+ # need to average the estimated mask for the overlapped frequencies
573
+
574
+ scatter_indices = repeat(self.freq_indices, 'f -> b n f t', b=batch, n=num_stems, t=stft_repr.shape[-1])
575
+
576
+ stft_repr_expanded_stems = repeat(stft_repr, 'b 1 ... -> b n ...', n=num_stems)
577
+ masks_summed = torch.zeros_like(stft_repr_expanded_stems).scatter_add_(2, scatter_indices, masks)
578
+
579
+ denom = repeat(self.num_bands_per_freq, 'f -> (f r) 1', r=channels)
580
+
581
+ masks_averaged = masks_summed / denom.clamp(min=1e-8)
582
+
583
+ # modulate stft repr with estimated mask
584
+
585
+ stft_repr = stft_repr * masks_averaged
586
+
587
+ # istft
588
+
589
+ stft_repr = rearrange(stft_repr, 'b n (f s) t -> (b n s) f t', s=self.audio_channels)
590
+
591
+ recon_audio = torch.istft(stft_repr, **self.stft_kwargs, window=stft_window, return_complex=False,
592
+ length=istft_length)
593
+
594
+ recon_audio = rearrange(recon_audio, '(b n s) t -> b n s t', b=batch, s=self.audio_channels, n=num_stems)
595
+
596
+ if num_stems == 1:
597
+ recon_audio = rearrange(recon_audio, 'b 1 s t -> b s t')
598
+
599
+ # if a target is passed in, calculate loss for learning
600
+
601
+ if not exists(target):
602
+ return recon_audio
603
+
604
+ if self.num_stems > 1:
605
+ assert target.ndim == 4 and target.shape[1] == self.num_stems
606
+
607
+ if target.ndim == 2:
608
+ target = rearrange(target, '... t -> ... 1 t')
609
+
610
+ target = target[..., :recon_audio.shape[-1]] # protect against lost length on istft
611
+
612
+ loss = F.l1_loss(recon_audio, target)
613
+
614
+ multi_stft_resolution_loss = 0.
615
+
616
+ for window_size in self.multi_stft_resolutions_window_sizes:
617
+ res_stft_kwargs = dict(
618
+ n_fft=max(window_size, self.multi_stft_n_fft), # not sure what n_fft is across multi resolution stft
619
+ win_length=window_size,
620
+ return_complex=True,
621
+ window=self.multi_stft_window_fn(window_size, device=device),
622
+ **self.multi_stft_kwargs,
623
+ )
624
+
625
+ recon_Y = torch.stft(rearrange(recon_audio, '... s t -> (... s) t'), **res_stft_kwargs)
626
+ target_Y = torch.stft(rearrange(target, '... s t -> (... s) t'), **res_stft_kwargs)
627
+
628
+ multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(recon_Y, target_Y)
629
+
630
+ weighted_multi_resolution_loss = multi_stft_resolution_loss * self.multi_stft_resolution_loss_weight
631
+
632
+ total_loss = loss + weighted_multi_resolution_loss
633
+
634
+ if not return_loss_breakdown:
635
+ return total_loss
636
+
637
+ return total_loss, (loss, multi_stft_resolution_loss)
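Before moving on, a hedged sketch of the mel-band bookkeeping done in `MelBandRoformer.__init__` above: it reproduces, outside the model, how the librosa mel filter bank is thresholded into per-band STFT bin indices. The constants are this file's defaults; the variable names are illustrative and not part of the repository.

import torch
from librosa import filters
from einops import repeat

n_fft, sample_rate, num_bands = 2048, 44100, 60
num_freqs = n_fft // 2 + 1  # STFT bins of a real-valued signal

# (num_bands, num_freqs) triangular mel weights, as built in __init__ above
mel_fb = torch.from_numpy(filters.mel(sr=sample_rate, n_fft=n_fft, n_mels=num_bands))
mel_fb[0, 0] = 1.     # force the first bin into the first band
mel_fb[-1, -1] = 1.   # force the last bin into the last band

freqs_per_band = mel_fb > 0                       # binary band membership
bin_ids = repeat(torch.arange(num_freqs), 'f -> b f', b=num_bands)
freq_indices = bin_ids[freqs_per_band]            # flat list of bins, band by band

print(freqs_per_band.sum(dim=-1)[:5])             # bins per band (first five bands)
print(freq_indices.shape)                         # total (band, bin) pairs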
models/demucs4ht.py ADDED
@@ -0,0 +1,713 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from functools import partial
5
+
6
+ import numpy as np
7
+ import torch
8
+ import json
9
+ from omegaconf import OmegaConf
10
+ from demucs.demucs import Demucs
11
+ from demucs.hdemucs import HDemucs
12
+
13
+ import math
14
+ from openunmix.filtering import wiener
15
+ from torch import nn
16
+ from torch.nn import functional as F
17
+ from fractions import Fraction
18
+ from einops import rearrange
19
+
20
+ from demucs.transformer import CrossTransformerEncoder
21
+
22
+ from demucs.demucs import rescale_module
23
+ from demucs.states import capture_init
24
+ from demucs.spec import spectro, ispectro
25
+ from demucs.hdemucs import pad1d, ScaledEmbedding, HEncLayer, MultiWrap, HDecLayer
26
+
27
+
28
+ class HTDemucs(nn.Module):
29
+ """
30
+ Spectrogram and hybrid Demucs model.
31
+ The spectrogram model has the same structure as Demucs, except the first few layers are over the
32
+ frequency axis, until there is only 1 frequency, and then it moves to time convolutions.
33
+ Frequency layers can still access information across time steps thanks to the DConv residual.
34
+
35
+ Hybrid models have a parallel time branch. At some layer, the time branch has the same stride
36
+ as the frequency branch and then the two are combined. The opposite happens in the decoder.
37
+
38
+ Models can either use naive iSTFT from masking, Wiener filtering ([Uhlich et al. 2017]),
39
+ or complex as channels (CaC) [Choi et al. 2020]. Wiener filtering is based on
40
+ Open Unmix implementation [Stoter et al. 2019].
41
+
42
+ The loss is always in the temporal domain, obtained by backpropagating through the above
43
+ output methods and the iSTFT. This allows hybrid models to be defined nicely. However, it somewhat
44
+ breaks Wiener filtering, as doing more iterations at test time will change the spectrogram
45
+ contribution without changing the one from the waveform, which will lead to worse performance.
46
+ I tried using the residual option in the Open Unmix Wiener implementation, but it did not improve results.
47
+ CaC on the other hand provides similar performance for hybrid, and works naturally with
48
+ hybrid models.
49
+
50
+ This model also uses frequency embeddings to improve the efficiency of convolutions
51
+ over the freq. axis, following [Isik et al. 2020] (https://arxiv.org/pdf/2008.04470.pdf).
52
+
53
+ Unlike classic Demucs, there is no resampling here, and normalization is always applied.
54
+ """
55
+
56
+ @capture_init
57
+ def __init__(
58
+ self,
59
+ sources,
60
+ # Channels
61
+ audio_channels=2,
62
+ channels=48,
63
+ channels_time=None,
64
+ growth=2,
65
+ # STFT
66
+ nfft=4096,
67
+ num_subbands=1,
68
+ wiener_iters=0,
69
+ end_iters=0,
70
+ wiener_residual=False,
71
+ cac=True,
72
+ # Main structure
73
+ depth=4,
74
+ rewrite=True,
75
+ # Frequency branch
76
+ multi_freqs=None,
77
+ multi_freqs_depth=3,
78
+ freq_emb=0.2,
79
+ emb_scale=10,
80
+ emb_smooth=True,
81
+ # Convolutions
82
+ kernel_size=8,
83
+ time_stride=2,
84
+ stride=4,
85
+ context=1,
86
+ context_enc=0,
87
+ # Normalization
88
+ norm_starts=4,
89
+ norm_groups=4,
90
+ # DConv residual branch
91
+ dconv_mode=1,
92
+ dconv_depth=2,
93
+ dconv_comp=8,
94
+ dconv_init=1e-3,
95
+ # Before the Transformer
96
+ bottom_channels=0,
97
+ # Transformer
98
+ t_layers=5,
99
+ t_emb="sin",
100
+ t_hidden_scale=4.0,
101
+ t_heads=8,
102
+ t_dropout=0.0,
103
+ t_max_positions=10000,
104
+ t_norm_in=True,
105
+ t_norm_in_group=False,
106
+ t_group_norm=False,
107
+ t_norm_first=True,
108
+ t_norm_out=True,
109
+ t_max_period=10000.0,
110
+ t_weight_decay=0.0,
111
+ t_lr=None,
112
+ t_layer_scale=True,
113
+ t_gelu=True,
114
+ t_weight_pos_embed=1.0,
115
+ t_sin_random_shift=0,
116
+ t_cape_mean_normalize=True,
117
+ t_cape_augment=True,
118
+ t_cape_glob_loc_scale=[5000.0, 1.0, 1.4],
119
+ t_sparse_self_attn=False,
120
+ t_sparse_cross_attn=False,
121
+ t_mask_type="diag",
122
+ t_mask_random_seed=42,
123
+ t_sparse_attn_window=500,
124
+ t_global_window=100,
125
+ t_sparsity=0.95,
126
+ t_auto_sparsity=False,
127
+ # ------ Particuliar parameters
128
+ t_cross_first=False,
129
+ # Weight init
130
+ rescale=0.1,
131
+ # Metadata
132
+ samplerate=44100,
133
+ segment=10,
134
+ use_train_segment=False,
135
+ ):
136
+ """
137
+ Args:
138
+ sources (list[str]): list of source names.
139
+ audio_channels (int): input/output audio channels.
140
+ channels (int): initial number of hidden channels.
141
+ channels_time: if not None, use a different `channels` value for the time branch.
142
+ growth: increase the number of hidden channels by this factor at each layer.
143
+ nfft: number of fft bins. Note that changing this requires careful computation of
144
+ various shape parameters and will not work out of the box for hybrid models.
145
+ wiener_iters: when using Wiener filtering, number of iterations at test time.
146
+ end_iters: same but at train time. For a hybrid model, must be equal to `wiener_iters`.
147
+ wiener_residual: add residual source before wiener filtering.
148
+ cac: uses complex as channels, i.e. complex numbers are 2 channels each
149
+ in input and output. no further processing is done before ISTFT.
150
+ depth (int): number of layers in the encoder and in the decoder.
151
+ rewrite (bool): add 1x1 convolution to each layer.
152
+ multi_freqs: list of frequency ratios for splitting frequency bands with `MultiWrap`.
153
+ multi_freqs_depth: how many layers to wrap with `MultiWrap`. Only the outermost
154
+ layers will be wrapped.
155
+ freq_emb: add frequency embedding after the first frequency layer if > 0,
156
+ the actual value controls the weight of the embedding.
157
+ emb_scale: equivalent to scaling the embedding learning rate
158
+ emb_smooth: initialize the embedding with a smooth one (with respect to frequencies).
159
+ kernel_size: kernel_size for encoder and decoder layers.
160
+ stride: stride for encoder and decoder layers.
161
+ time_stride: stride for the final time layer, after the merge.
162
+ context: context for 1x1 conv in the decoder.
163
+ context_enc: context for 1x1 conv in the encoder.
164
+ norm_starts: layer at which group norm starts being used.
165
+ decoder layers are numbered in reverse order.
166
+ norm_groups: number of groups for group norm.
167
+ dconv_mode: if 1: dconv in encoder only, 2: decoder only, 3: both.
168
+ dconv_depth: depth of residual DConv branch.
169
+ dconv_comp: compression of DConv branch.
170
+ dconv_attn: adds attention layers in DConv branch starting at this layer.
171
+ dconv_lstm: adds a LSTM layer in DConv branch starting at this layer.
172
+ dconv_init: initial scale for the DConv branch LayerScale.
173
+ bottom_channels: if >0 it adds a linear layer (1x1 Conv) before and after the
174
+ transformer in order to change the number of channels
175
+ t_layers: number of layers in each branch (waveform and spec) of the transformer
176
+ t_emb: "sin", "cape" or "scaled"
177
+ t_hidden_scale: the hidden scale of the Feedforward parts of the transformer
178
+ for instance if C = 384 (the number of channels in the transformer) and
179
+ t_hidden_scale = 4.0 then the intermediate layer of the FFN has dimension
180
+ 384 * 4 = 1536
181
+ t_heads: number of heads for the transformer
182
+ t_dropout: dropout in the transformer
183
+ t_max_positions: max_positions for the "scaled" positional embedding, only
184
+ useful if t_emb="scaled"
185
+ t_norm_in: (bool) norm before adding the positional embedding and getting into the
186
+ transformer layers
187
+ t_norm_in_group: (bool) if True while t_norm_in=True, the norm is on all the
188
+ timesteps (GroupNorm with group=1)
189
+ t_group_norm: (bool) if True, the norms of the Encoder Layers are on all the
190
+ timesteps (GroupNorm with group=1)
191
+ t_norm_first: (bool) if True the norm is before the attention and before the FFN
192
+ t_norm_out: (bool) if True, there is a GroupNorm (group=1) at the end of each layer
193
+ t_max_period: (float) denominator in the sinusoidal embedding expression
194
+ t_weight_decay: (float) weight decay for the transformer
195
+ t_lr: (float) specific learning rate for the transformer
196
+ t_layer_scale: (bool) Layer Scale for the transformer
197
+ t_gelu: (bool) activations of the transformer are GeLU if True, ReLU else
198
+ t_weight_pos_embed: (float) weighting of the positional embedding
199
+ t_cape_mean_normalize: (bool) if t_emb="cape", normalisation of positional embeddings
200
+ see: https://arxiv.org/abs/2106.03143
201
+ t_cape_augment: (bool) if t_emb="cape", must be True during training and False
202
+ during inference, see: https://arxiv.org/abs/2106.03143
203
+ t_cape_glob_loc_scale: (list of 3 floats) if t_emb="cape", CAPE parameters
204
+ see: https://arxiv.org/abs/2106.03143
205
+ t_sparse_self_attn: (bool) if True, the self attentions are sparse
206
+ t_sparse_cross_attn: (bool) if True, the cross-attentions are sparse (don't use it
207
+ unless you designed really specific masks)
208
+ t_mask_type: (str) can be "diag", "jmask", "random", "global" or any combination
209
+ with '_' between: i.e. "diag_jmask_random" (note that this is permutation
210
+ invariant i.e. "diag_jmask_random" is equivalent to "jmask_random_diag")
211
+ t_mask_random_seed: (int) if "random" is in t_mask_type, controls the seed
212
+ that generated the random part of the mask
213
+ t_sparse_attn_window: (int) if "diag" is in t_mask_type, for a query (i), and
214
+ a key (j), the mask is True if |i-j| <= t_sparse_attn_window
215
+ t_global_window: (int) if "global" is in t_mask_type, mask[:t_global_window, :]
216
+ and mask[:, :t_global_window] will be True
217
+ t_sparsity: (float) if "random" is in t_mask_type, t_sparsity is the sparsity
218
+ level of the random part of the mask.
219
+ t_cross_first: (bool) if True cross attention is the first layer of the
220
+ transformer (False seems to be better)
221
+ rescale: weight rescaling trick
222
+ use_train_segment: (bool) if True, the actual size that is used during the
223
+ training is used during inference.
224
+ """
225
+ super().__init__()
226
+ self.num_subbands = num_subbands
227
+ self.cac = cac
228
+ self.wiener_residual = wiener_residual
229
+ self.audio_channels = audio_channels
230
+ self.sources = sources
231
+ self.kernel_size = kernel_size
232
+ self.context = context
233
+ self.stride = stride
234
+ self.depth = depth
235
+ self.bottom_channels = bottom_channels
236
+ self.channels = channels
237
+ self.samplerate = samplerate
238
+ self.segment = segment
239
+ self.use_train_segment = use_train_segment
240
+ self.nfft = nfft
241
+ self.hop_length = nfft // 4
242
+ self.wiener_iters = wiener_iters
243
+ self.end_iters = end_iters
244
+ self.freq_emb = None
245
+ assert wiener_iters == end_iters
246
+
247
+ self.encoder = nn.ModuleList()
248
+ self.decoder = nn.ModuleList()
249
+
250
+ self.tencoder = nn.ModuleList()
251
+ self.tdecoder = nn.ModuleList()
252
+
253
+ chin = audio_channels
254
+ chin_z = chin # number of channels for the freq branch
255
+ if self.cac:
256
+ chin_z *= 2
257
+ if self.num_subbands > 1:
258
+ chin_z *= self.num_subbands
259
+ chout = channels_time or channels
260
+ chout_z = channels
261
+ freqs = nfft // 2
262
+
263
+ for index in range(depth):
264
+ norm = index >= norm_starts
265
+ freq = freqs > 1
266
+ stri = stride
267
+ ker = kernel_size
268
+ if not freq:
269
+ assert freqs == 1
270
+ ker = time_stride * 2
271
+ stri = time_stride
272
+
273
+ pad = True
274
+ last_freq = False
275
+ if freq and freqs <= kernel_size:
276
+ ker = freqs
277
+ pad = False
278
+ last_freq = True
279
+
280
+ kw = {
281
+ "kernel_size": ker,
282
+ "stride": stri,
283
+ "freq": freq,
284
+ "pad": pad,
285
+ "norm": norm,
286
+ "rewrite": rewrite,
287
+ "norm_groups": norm_groups,
288
+ "dconv_kw": {
289
+ "depth": dconv_depth,
290
+ "compress": dconv_comp,
291
+ "init": dconv_init,
292
+ "gelu": True,
293
+ },
294
+ }
295
+ kwt = dict(kw)
296
+ kwt["freq"] = 0
297
+ kwt["kernel_size"] = kernel_size
298
+ kwt["stride"] = stride
299
+ kwt["pad"] = True
300
+ kw_dec = dict(kw)
301
+ multi = False
302
+ if multi_freqs and index < multi_freqs_depth:
303
+ multi = True
304
+ kw_dec["context_freq"] = False
305
+
306
+ if last_freq:
307
+ chout_z = max(chout, chout_z)
308
+ chout = chout_z
309
+
310
+ enc = HEncLayer(
311
+ chin_z, chout_z, dconv=dconv_mode & 1, context=context_enc, **kw
312
+ )
313
+ if freq:
314
+ tenc = HEncLayer(
315
+ chin,
316
+ chout,
317
+ dconv=dconv_mode & 1,
318
+ context=context_enc,
319
+ empty=last_freq,
320
+ **kwt
321
+ )
322
+ self.tencoder.append(tenc)
323
+
324
+ if multi:
325
+ enc = MultiWrap(enc, multi_freqs)
326
+ self.encoder.append(enc)
327
+ if index == 0:
328
+ chin = self.audio_channels * len(self.sources)
329
+ chin_z = chin
330
+ if self.cac:
331
+ chin_z *= 2
332
+ if self.num_subbands > 1:
333
+ chin_z *= self.num_subbands
334
+ dec = HDecLayer(
335
+ chout_z,
336
+ chin_z,
337
+ dconv=dconv_mode & 2,
338
+ last=index == 0,
339
+ context=context,
340
+ **kw_dec
341
+ )
342
+ if multi:
343
+ dec = MultiWrap(dec, multi_freqs)
344
+ if freq:
345
+ tdec = HDecLayer(
346
+ chout,
347
+ chin,
348
+ dconv=dconv_mode & 2,
349
+ empty=last_freq,
350
+ last=index == 0,
351
+ context=context,
352
+ **kwt
353
+ )
354
+ self.tdecoder.insert(0, tdec)
355
+ self.decoder.insert(0, dec)
356
+
357
+ chin = chout
358
+ chin_z = chout_z
359
+ chout = int(growth * chout)
360
+ chout_z = int(growth * chout_z)
361
+ if freq:
362
+ if freqs <= kernel_size:
363
+ freqs = 1
364
+ else:
365
+ freqs //= stride
366
+ if index == 0 and freq_emb:
367
+ self.freq_emb = ScaledEmbedding(
368
+ freqs, chin_z, smooth=emb_smooth, scale=emb_scale
369
+ )
370
+ self.freq_emb_scale = freq_emb
371
+
372
+ if rescale:
373
+ rescale_module(self, reference=rescale)
374
+
375
+ transformer_channels = channels * growth ** (depth - 1)
376
+ if bottom_channels:
377
+ self.channel_upsampler = nn.Conv1d(transformer_channels, bottom_channels, 1)
378
+ self.channel_downsampler = nn.Conv1d(
379
+ bottom_channels, transformer_channels, 1
380
+ )
381
+ self.channel_upsampler_t = nn.Conv1d(
382
+ transformer_channels, bottom_channels, 1
383
+ )
384
+ self.channel_downsampler_t = nn.Conv1d(
385
+ bottom_channels, transformer_channels, 1
386
+ )
387
+
388
+ transformer_channels = bottom_channels
389
+
390
+ if t_layers > 0:
391
+ self.crosstransformer = CrossTransformerEncoder(
392
+ dim=transformer_channels,
393
+ emb=t_emb,
394
+ hidden_scale=t_hidden_scale,
395
+ num_heads=t_heads,
396
+ num_layers=t_layers,
397
+ cross_first=t_cross_first,
398
+ dropout=t_dropout,
399
+ max_positions=t_max_positions,
400
+ norm_in=t_norm_in,
401
+ norm_in_group=t_norm_in_group,
402
+ group_norm=t_group_norm,
403
+ norm_first=t_norm_first,
404
+ norm_out=t_norm_out,
405
+ max_period=t_max_period,
406
+ weight_decay=t_weight_decay,
407
+ lr=t_lr,
408
+ layer_scale=t_layer_scale,
409
+ gelu=t_gelu,
410
+ sin_random_shift=t_sin_random_shift,
411
+ weight_pos_embed=t_weight_pos_embed,
412
+ cape_mean_normalize=t_cape_mean_normalize,
413
+ cape_augment=t_cape_augment,
414
+ cape_glob_loc_scale=t_cape_glob_loc_scale,
415
+ sparse_self_attn=t_sparse_self_attn,
416
+ sparse_cross_attn=t_sparse_cross_attn,
417
+ mask_type=t_mask_type,
418
+ mask_random_seed=t_mask_random_seed,
419
+ sparse_attn_window=t_sparse_attn_window,
420
+ global_window=t_global_window,
421
+ sparsity=t_sparsity,
422
+ auto_sparsity=t_auto_sparsity,
423
+ )
424
+ else:
425
+ self.crosstransformer = None
426
+
427
+ def _spec(self, x):
428
+ hl = self.hop_length
429
+ nfft = self.nfft
430
+ x0 = x # noqa
431
+
432
+ # We re-pad the signal in order to keep the property
433
+ # that the size of the output is exactly the size of the input
434
+ # divided by the stride (here hop_length), when divisible.
435
+ # This is achieved by padding by 1/4th of the kernel size (here nfft).
436
+ # which is not supported by torch.stft.
437
+ # Having all convolution operations follow this convention allow to easily
438
+ # align the time and frequency branches later on.
439
+ assert hl == nfft // 4
440
+ le = int(math.ceil(x.shape[-1] / hl))
441
+ pad = hl // 2 * 3
442
+ x = pad1d(x, (pad, pad + le * hl - x.shape[-1]), mode="reflect")
443
+
444
+ z = spectro(x, nfft, hl)[..., :-1, :]
445
+ assert z.shape[-1] == le + 4, (z.shape, x.shape, le)
446
+ z = z[..., 2: 2 + le]
447
+ return z
448
+
449
+ def _ispec(self, z, length=None, scale=0):
450
+ hl = self.hop_length // (4**scale)
451
+ z = F.pad(z, (0, 0, 0, 1))
452
+ z = F.pad(z, (2, 2))
453
+ pad = hl // 2 * 3
454
+ le = hl * int(math.ceil(length / hl)) + 2 * pad
455
+ x = ispectro(z, hl, length=le)
456
+ x = x[..., pad: pad + length]
457
+ return x
458
+
459
+ def _magnitude(self, z):
460
+ # return the magnitude of the spectrogram, except when cac is True,
461
+ # in which case we just move the complex dimension to the channel one.
462
+ if self.cac:
463
+ B, C, Fr, T = z.shape
464
+ m = torch.view_as_real(z).permute(0, 1, 4, 2, 3)
465
+ m = m.reshape(B, C * 2, Fr, T)
466
+ else:
467
+ m = z.abs()
468
+ return m
469
+
470
+ def _mask(self, z, m):
471
+ # Apply masking given the mixture spectrogram `z` and the estimated mask `m`.
472
+ # If `cac` is True, `m` is actually a full spectrogram and `z` is ignored.
473
+ niters = self.wiener_iters
474
+ if self.cac:
475
+ B, S, C, Fr, T = m.shape
476
+ out = m.view(B, S, -1, 2, Fr, T).permute(0, 1, 2, 4, 5, 3)
477
+ out = torch.view_as_complex(out.contiguous())
478
+ return out
479
+ if self.training:
480
+ niters = self.end_iters
481
+ if niters < 0:
482
+ z = z[:, None]
483
+ return z / (1e-8 + z.abs()) * m
484
+ else:
485
+ return self._wiener(m, z, niters)
486
+
487
+ def _wiener(self, mag_out, mix_stft, niters):
488
+ # apply wiener filtering from OpenUnmix.
489
+ init = mix_stft.dtype
490
+ wiener_win_len = 300
491
+ residual = self.wiener_residual
492
+
493
+ B, S, C, Fq, T = mag_out.shape
494
+ mag_out = mag_out.permute(0, 4, 3, 2, 1)
495
+ mix_stft = torch.view_as_real(mix_stft.permute(0, 3, 2, 1))
496
+
497
+ outs = []
498
+ for sample in range(B):
499
+ pos = 0
500
+ out = []
501
+ for pos in range(0, T, wiener_win_len):
502
+ frame = slice(pos, pos + wiener_win_len)
503
+ z_out = wiener(
504
+ mag_out[sample, frame],
505
+ mix_stft[sample, frame],
506
+ niters,
507
+ residual=residual,
508
+ )
509
+ out.append(z_out.transpose(-1, -2))
510
+ outs.append(torch.cat(out, dim=0))
511
+ out = torch.view_as_complex(torch.stack(outs, 0))
512
+ out = out.permute(0, 4, 3, 2, 1).contiguous()
513
+ if residual:
514
+ out = out[:, :-1]
515
+ assert list(out.shape) == [B, S, C, Fq, T]
516
+ return out.to(init)
517
+
518
+ def valid_length(self, length: int):
519
+ """
520
+ Return a length that is appropriate for evaluation.
521
+ In our case, always return the training length, unless
522
+ it is smaller than the given length, in which case this
523
+ raises an error.
524
+ """
525
+ if not self.use_train_segment:
526
+ return length
527
+ training_length = int(self.segment * self.samplerate)
528
+ if training_length < length:
529
+ raise ValueError(
530
+ f"Given length {length} is longer than "
531
+ f"training length {training_length}")
532
+ return training_length
533
+
534
+ def cac2cws(self, x):
535
+ k = self.num_subbands
536
+ b, c, f, t = x.shape
537
+ x = x.reshape(b, c, k, f // k, t)
538
+ x = x.reshape(b, c * k, f // k, t)
539
+ return x
540
+
541
+ def cws2cac(self, x):
542
+ k = self.num_subbands
543
+ b, c, f, t = x.shape
544
+ x = x.reshape(b, c // k, k, f, t)
545
+ x = x.reshape(b, c // k, f * k, t)
546
+ return x
547
+
548
+ def forward(self, mix):
549
+ length = mix.shape[-1]
550
+ length_pre_pad = None
551
+ if self.use_train_segment:
552
+ if self.training:
553
+ self.segment = Fraction(mix.shape[-1], self.samplerate)
554
+ else:
555
+ training_length = int(self.segment * self.samplerate)
556
+ # print('Training length: {} Segment: {} Sample rate: {}'.format(training_length, self.segment, self.samplerate))
557
+ if mix.shape[-1] < training_length:
558
+ length_pre_pad = mix.shape[-1]
559
+ mix = F.pad(mix, (0, training_length - length_pre_pad))
560
+ # print("Mix: {}".format(mix.shape))
561
+ # print("Length: {}".format(length))
562
+ z = self._spec(mix)
563
+ # print("Z: {} Type: {}".format(z.shape, z.dtype))
564
+ mag = self._magnitude(z)
565
+ x = mag
566
+ # print("MAG: {} Type: {}".format(x.shape, x.dtype))
567
+
568
+ if self.num_subbands > 1:
569
+ x = self.cac2cws(x)
570
+ # print("After SUBBANDS: {} Type: {}".format(x.shape, x.dtype))
571
+
572
+ B, C, Fq, T = x.shape
573
+
574
+ # unlike previous Demucs, we always normalize because it is easier.
575
+ mean = x.mean(dim=(1, 2, 3), keepdim=True)
576
+ std = x.std(dim=(1, 2, 3), keepdim=True)
577
+ x = (x - mean) / (1e-5 + std)
578
+ # x will be the freq. branch input.
579
+
580
+ # Prepare the time branch input.
581
+ xt = mix
582
+ meant = xt.mean(dim=(1, 2), keepdim=True)
583
+ stdt = xt.std(dim=(1, 2), keepdim=True)
584
+ xt = (xt - meant) / (1e-5 + stdt)
585
+
586
+ # print("XT: {}".format(xt.shape))
587
+
588
+ # okay, this is a giant mess I know...
589
+ saved = [] # skip connections, freq.
590
+ saved_t = [] # skip connections, time.
591
+ lengths = [] # saved lengths to properly remove padding, freq branch.
592
+ lengths_t = [] # saved lengths for time branch.
593
+ for idx, encode in enumerate(self.encoder):
594
+ lengths.append(x.shape[-1])
595
+ inject = None
596
+ if idx < len(self.tencoder):
597
+ # we have not yet merged branches.
598
+ lengths_t.append(xt.shape[-1])
599
+ tenc = self.tencoder[idx]
600
+ xt = tenc(xt)
601
+ # print("Encode XT {}: {}".format(idx, xt.shape))
602
+ if not tenc.empty:
603
+ # save for skip connection
604
+ saved_t.append(xt)
605
+ else:
606
+ # tenc contains just the first conv., so that now time and freq.
607
+ # branches have the same shape and can be merged.
608
+ inject = xt
609
+ x = encode(x, inject)
610
+ # print("Encode X {}: {}".format(idx, x.shape))
611
+ if idx == 0 and self.freq_emb is not None:
612
+ # add frequency embedding to allow for non equivariant convolutions
613
+ # over the frequency axis.
614
+ frs = torch.arange(x.shape[-2], device=x.device)
615
+ emb = self.freq_emb(frs).t()[None, :, :, None].expand_as(x)
616
+ x = x + self.freq_emb_scale * emb
617
+
618
+ saved.append(x)
619
+ if self.crosstransformer:
620
+ if self.bottom_channels:
621
+ b, c, f, t = x.shape
622
+ x = rearrange(x, "b c f t-> b c (f t)")
623
+ x = self.channel_upsampler(x)
624
+ x = rearrange(x, "b c (f t)-> b c f t", f=f)
625
+ xt = self.channel_upsampler_t(xt)
626
+
627
+ x, xt = self.crosstransformer(x, xt)
628
+ # print("Cross Tran X {}, XT: {}".format(x.shape, xt.shape))
629
+
630
+ if self.bottom_channels:
631
+ x = rearrange(x, "b c f t-> b c (f t)")
632
+ x = self.channel_downsampler(x)
633
+ x = rearrange(x, "b c (f t)-> b c f t", f=f)
634
+ xt = self.channel_downsampler_t(xt)
635
+
636
+ for idx, decode in enumerate(self.decoder):
637
+ skip = saved.pop(-1)
638
+ x, pre = decode(x, skip, lengths.pop(-1))
639
+ # print('Decode {} X: {}'.format(idx, x.shape))
640
+ # `pre` contains the output just before final transposed convolution,
641
+ # which is used when the freq. and time branch separate.
642
+
643
+ offset = self.depth - len(self.tdecoder)
644
+ if idx >= offset:
645
+ tdec = self.tdecoder[idx - offset]
646
+ length_t = lengths_t.pop(-1)
647
+ if tdec.empty:
648
+ assert pre.shape[2] == 1, pre.shape
649
+ pre = pre[:, :, 0]
650
+ xt, _ = tdec(pre, None, length_t)
651
+ else:
652
+ skip = saved_t.pop(-1)
653
+ xt, _ = tdec(xt, skip, length_t)
654
+ # print('Decode {} XT: {}'.format(idx, xt.shape))
655
+
656
+ # Let's make sure we used all stored skip connections.
657
+ assert len(saved) == 0
658
+ assert len(lengths_t) == 0
659
+ assert len(saved_t) == 0
660
+
661
+ S = len(self.sources)
662
+
663
+ if self.num_subbands > 1:
664
+ x = x.view(B, -1, Fq, T)
665
+ # print("X view 1: {}".format(x.shape))
666
+ x = self.cws2cac(x)
667
+ # print("X view 2: {}".format(x.shape))
668
+
669
+ x = x.view(B, S, -1, Fq * self.num_subbands, T)
670
+ x = x * std[:, None] + mean[:, None]
671
+ # print("X returned: {}".format(x.shape))
672
+
673
+ zout = self._mask(z, x)
674
+ if self.use_train_segment:
675
+ if self.training:
676
+ x = self._ispec(zout, length)
677
+ else:
678
+ x = self._ispec(zout, training_length)
679
+ else:
680
+ x = self._ispec(zout, length)
681
+
682
+ if self.use_train_segment:
683
+ if self.training:
684
+ xt = xt.view(B, S, -1, length)
685
+ else:
686
+ xt = xt.view(B, S, -1, training_length)
687
+ else:
688
+ xt = xt.view(B, S, -1, length)
689
+ xt = xt * stdt[:, None] + meant[:, None]
690
+ x = xt + x
691
+ if length_pre_pad:
692
+ x = x[..., :length_pre_pad]
693
+ return x
694
+
695
+
696
+ def get_model(args):
697
+ extra = {
698
+ 'sources': list(args.training.instruments),
699
+ 'audio_channels': args.training.channels,
700
+ 'samplerate': args.training.samplerate,
701
+ # 'segment': args.model_segment or 4 * args.dset.segment,
702
+ 'segment': args.training.segment,
703
+ }
704
+ klass = {
705
+ 'demucs': Demucs,
706
+ 'hdemucs': HDemucs,
707
+ 'htdemucs': HTDemucs,
708
+ }[args.model]
709
+ kw = OmegaConf.to_container(getattr(args, args.model), resolve=True)
710
+ model = klass(**extra, **kw)
711
+ return model
712
+
713
+
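The `get_model` helper above expects an OmegaConf namespace. The following hedged sketch shows the minimal shape of such a config for the `htdemucs` branch: `args.model` selects the class, `args.training.*` fills the shared kwargs, and the section named after the model is unpacked into `HTDemucs(**kw)`. The concrete values are placeholders, not the settings shipped with this repository.

from omegaconf import OmegaConf

# Hypothetical minimal config mirroring the attributes read by get_model().
args = OmegaConf.create({
    'model': 'htdemucs',
    'training': {
        'instruments': ['vocals', 'drums', 'bass', 'other'],
        'channels': 2,
        'samplerate': 44100,
        'segment': 10,
    },
    'htdemucs': {
        'channels': 48,
        'depth': 4,
        'nfft': 4096,
        't_layers': 5,
    },
})

# model = get_model(args)  # would build HTDemucs(sources=[...], audio_channels=2, ...)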
models/mdx23c_tfc_tdf_v3.py ADDED
@@ -0,0 +1,242 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from functools import partial
5
+
6
+
7
+ class STFT:
8
+ def __init__(self, config):
9
+ self.n_fft = config.n_fft
10
+ self.hop_length = config.hop_length
11
+ self.window = torch.hann_window(window_length=self.n_fft, periodic=True)
12
+ self.dim_f = config.dim_f
13
+
14
+ def __call__(self, x):
15
+ window = self.window.to(x.device)
16
+ batch_dims = x.shape[:-2]
17
+ c, t = x.shape[-2:]
18
+ x = x.reshape([-1, t])
19
+ x = torch.stft(
20
+ x,
21
+ n_fft=self.n_fft,
22
+ hop_length=self.hop_length,
23
+ window=window,
24
+ center=True,
25
+ return_complex=True
26
+ )
27
+ x = torch.view_as_real(x)
28
+ x = x.permute([0, 3, 1, 2])
29
+ x = x.reshape([*batch_dims, c, 2, -1, x.shape[-1]]).reshape([*batch_dims, c * 2, -1, x.shape[-1]])
30
+ return x[..., :self.dim_f, :]
31
+
32
+ def inverse(self, x):
33
+ window = self.window.to(x.device)
34
+ batch_dims = x.shape[:-3]
35
+ c, f, t = x.shape[-3:]
36
+ n = self.n_fft // 2 + 1
37
+ f_pad = torch.zeros([*batch_dims, c, n - f, t]).to(x.device)
38
+ x = torch.cat([x, f_pad], -2)
39
+ x = x.reshape([*batch_dims, c // 2, 2, n, t]).reshape([-1, 2, n, t])
40
+ x = x.permute([0, 2, 3, 1])
41
+ x = x[..., 0] + x[..., 1] * 1.j
42
+ x = torch.istft(x, n_fft=self.n_fft, hop_length=self.hop_length, window=window, center=True)
43
+ x = x.reshape([*batch_dims, 2, -1])
44
+ return x
45
+
46
+
47
+ def get_norm(norm_type):
48
+ def norm(c, norm_type):
49
+ if norm_type == 'BatchNorm':
50
+ return nn.BatchNorm2d(c)
51
+ elif norm_type == 'InstanceNorm':
52
+ return nn.InstanceNorm2d(c, affine=True)
53
+ elif 'GroupNorm' in norm_type:
54
+ g = int(norm_type.replace('GroupNorm', ''))
55
+ return nn.GroupNorm(num_groups=g, num_channels=c)
56
+ else:
57
+ return nn.Identity()
58
+
59
+ return partial(norm, norm_type=norm_type)
60
+
61
+
62
+ def get_act(act_type):
63
+ if act_type == 'gelu':
64
+ return nn.GELU()
65
+ elif act_type == 'relu':
66
+ return nn.ReLU()
67
+ elif act_type[:3] == 'elu':
68
+ alpha = float(act_type.replace('elu', ''))
69
+ return nn.ELU(alpha)
70
+ else:
71
+ raise Exception
72
+
73
+
74
+ class Upscale(nn.Module):
75
+ def __init__(self, in_c, out_c, scale, norm, act):
76
+ super().__init__()
77
+ self.conv = nn.Sequential(
78
+ norm(in_c),
79
+ act,
80
+ nn.ConvTranspose2d(in_channels=in_c, out_channels=out_c, kernel_size=scale, stride=scale, bias=False)
81
+ )
82
+
83
+ def forward(self, x):
84
+ return self.conv(x)
85
+
86
+
87
+ class Downscale(nn.Module):
88
+ def __init__(self, in_c, out_c, scale, norm, act):
89
+ super().__init__()
90
+ self.conv = nn.Sequential(
91
+ norm(in_c),
92
+ act,
93
+ nn.Conv2d(in_channels=in_c, out_channels=out_c, kernel_size=scale, stride=scale, bias=False)
94
+ )
95
+
96
+ def forward(self, x):
97
+ return self.conv(x)
98
+
99
+
100
+ class TFC_TDF(nn.Module):
101
+ def __init__(self, in_c, c, l, f, bn, norm, act):
102
+ super().__init__()
103
+
104
+ self.blocks = nn.ModuleList()
105
+ for i in range(l):
106
+ block = nn.Module()
107
+
108
+ block.tfc1 = nn.Sequential(
109
+ norm(in_c),
110
+ act,
111
+ nn.Conv2d(in_c, c, 3, 1, 1, bias=False),
112
+ )
113
+ block.tdf = nn.Sequential(
114
+ norm(c),
115
+ act,
116
+ nn.Linear(f, f // bn, bias=False),
117
+ norm(c),
118
+ act,
119
+ nn.Linear(f // bn, f, bias=False),
120
+ )
121
+ block.tfc2 = nn.Sequential(
122
+ norm(c),
123
+ act,
124
+ nn.Conv2d(c, c, 3, 1, 1, bias=False),
125
+ )
126
+ block.shortcut = nn.Conv2d(in_c, c, 1, 1, 0, bias=False)
127
+
128
+ self.blocks.append(block)
129
+ in_c = c
130
+
131
+ def forward(self, x):
132
+ for block in self.blocks:
133
+ s = block.shortcut(x)
134
+ x = block.tfc1(x)
135
+ x = x + block.tdf(x)
136
+ x = block.tfc2(x)
137
+ x = x + s
138
+ return x
139
+
140
+
141
+ class TFC_TDF_net(nn.Module):
142
+ def __init__(self, config):
143
+ super().__init__()
144
+ self.config = config
145
+
146
+ norm = get_norm(norm_type=config.model.norm)
147
+ act = get_act(act_type=config.model.act)
148
+
149
+ self.num_target_instruments = 1 if config.training.target_instrument else len(config.training.instruments)
150
+ self.num_subbands = config.model.num_subbands
151
+
152
+ dim_c = self.num_subbands * config.audio.num_channels * 2
153
+ n = config.model.num_scales
154
+ scale = config.model.scale
155
+ l = config.model.num_blocks_per_scale
156
+ c = config.model.num_channels
157
+ g = config.model.growth
158
+ bn = config.model.bottleneck_factor
159
+ f = config.audio.dim_f // self.num_subbands
160
+
161
+ self.first_conv = nn.Conv2d(dim_c, c, 1, 1, 0, bias=False)
162
+
163
+ self.encoder_blocks = nn.ModuleList()
164
+ for i in range(n):
165
+ block = nn.Module()
166
+ block.tfc_tdf = TFC_TDF(c, c, l, f, bn, norm, act)
167
+ block.downscale = Downscale(c, c + g, scale, norm, act)
168
+ f = f // scale[1]
169
+ c += g
170
+ self.encoder_blocks.append(block)
171
+
172
+ self.bottleneck_block = TFC_TDF(c, c, l, f, bn, norm, act)
173
+
174
+ self.decoder_blocks = nn.ModuleList()
175
+ for i in range(n):
176
+ block = nn.Module()
177
+ block.upscale = Upscale(c, c - g, scale, norm, act)
178
+ f = f * scale[1]
179
+ c -= g
180
+ block.tfc_tdf = TFC_TDF(2 * c, c, l, f, bn, norm, act)
181
+ self.decoder_blocks.append(block)
182
+
183
+ self.final_conv = nn.Sequential(
184
+ nn.Conv2d(c + dim_c, c, 1, 1, 0, bias=False),
185
+ act,
186
+ nn.Conv2d(c, self.num_target_instruments * dim_c, 1, 1, 0, bias=False)
187
+ )
188
+
189
+ self.stft = STFT(config.audio)
190
+
191
+ def cac2cws(self, x):
192
+ k = self.num_subbands
193
+ b, c, f, t = x.shape
194
+ x = x.reshape(b, c, k, f // k, t)
195
+ x = x.reshape(b, c * k, f // k, t)
196
+ return x
197
+
198
+ def cws2cac(self, x):
199
+ k = self.num_subbands
200
+ b, c, f, t = x.shape
201
+ x = x.reshape(b, c // k, k, f, t)
202
+ x = x.reshape(b, c // k, f * k, t)
203
+ return x
204
+
205
+ def forward(self, x):
206
+
207
+ x = self.stft(x)
208
+
209
+ mix = x = self.cac2cws(x)
210
+
211
+ first_conv_out = x = self.first_conv(x)
212
+
213
+ x = x.transpose(-1, -2)
214
+
215
+ encoder_outputs = []
216
+ for block in self.encoder_blocks:
217
+ x = block.tfc_tdf(x)
218
+ encoder_outputs.append(x)
219
+ x = block.downscale(x)
220
+
221
+ x = self.bottleneck_block(x)
222
+
223
+ for block in self.decoder_blocks:
224
+ x = block.upscale(x)
225
+ x = torch.cat([x, encoder_outputs.pop()], 1)
226
+ x = block.tfc_tdf(x)
227
+
228
+ x = x.transpose(-1, -2)
229
+
230
+ x = x * first_conv_out # reduce artifacts
231
+
232
+ x = self.final_conv(torch.cat([mix, x], 1))
233
+
234
+ x = self.cws2cac(x)
235
+
236
+ if self.num_target_instruments > 1:
237
+ b, c, f, t = x.shape
238
+ x = x.reshape(b, self.num_target_instruments, -1, f, t)
239
+
240
+ x = self.stft.inverse(x)
241
+
242
+ return x
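To make the shape handling of the `STFT` helper above concrete, here is a hedged round-trip sketch on a dummy stereo chunk, assuming the `STFT` class defined in this file is importable. `SimpleNamespace` stands in for the repository's audio config, and the exact numbers are illustrative only.

import torch
from types import SimpleNamespace

# Hypothetical audio config: only the fields the STFT helper reads.
audio_cfg = SimpleNamespace(n_fft=2048, hop_length=512, dim_f=1024)

stft = STFT(audio_cfg)

x = torch.randn(4, 2, 44100)   # (batch, channels, time) dummy stereo audio
spec = stft(x)                 # -> (batch, channels * 2, dim_f, frames); real/imag folded into channels
print(spec.shape)

wave = stft.inverse(spec)      # back to (batch, 2, time'); length snapped to the STFT hop grid
print(wave.shape)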
models/scnet/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .scnet import SCNet
models/scnet/scnet.py ADDED
@@ -0,0 +1,373 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from collections import deque
5
+ from .separation import SeparationNet
6
+ import typing as tp
7
+ import math
8
+
9
+ class Swish(nn.Module):
10
+ def forward(self, x):
11
+ return x * x.sigmoid()
12
+
13
+
14
+ class ConvolutionModule(nn.Module):
15
+ """
16
+ Convolution Module in SD block.
17
+
18
+ Args:
19
+ channels (int): input/output channels.
20
+ depth (int): number of layers in the residual branch. Each layer has its own
21
+ compress (float): amount of channel compression.
22
+ kernel (int): kernel size for the convolutions.
23
+ """
24
+ def __init__(self, channels, depth=2, compress=4, kernel=3):
25
+ super().__init__()
26
+ assert kernel % 2 == 1
27
+ self.depth = abs(depth)
28
+ hidden_size = int(channels / compress)
29
+ norm = lambda d: nn.GroupNorm(1, d)
30
+ self.layers = nn.ModuleList([])
31
+ for _ in range(self.depth):
32
+ padding = (kernel // 2)
33
+ mods = [
34
+ norm(channels),
35
+ nn.Conv1d(channels, hidden_size*2, kernel, padding = padding),
36
+ nn.GLU(1),
37
+ nn.Conv1d(hidden_size, hidden_size, kernel, padding = padding, groups = hidden_size),
38
+ norm(hidden_size),
39
+ Swish(),
40
+ nn.Conv1d(hidden_size, channels, 1),
41
+ ]
42
+ layer = nn.Sequential(*mods)
43
+ self.layers.append(layer)
44
+
45
+ def forward(self, x):
46
+ for layer in self.layers:
47
+ x = x + layer(x)
48
+ return x
49
+
50
+
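
Each entry in self.layers maps (B, C, T) to (B, C, T) and is added back onto its input, so the module is a stack of residual convolution branches. A small shape check (assuming PyTorch; batch size, channel count and length are arbitrary), run next to the class above:

import torch

conv = ConvolutionModule(channels=64, depth=2, compress=4, kernel=3)
x = torch.randn(8, 64, 100)   # (batch, channels, time); SDblock feeds it flattened band rows
y = conv(x)
print(y.shape)                # torch.Size([8, 64, 100]) -- shape preserved by the residual layers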
51
+ class FusionLayer(nn.Module):
52
+ """
53
+ A FusionLayer within the decoder.
54
+
55
+ Args:
56
+ - channels (int): Number of input channels.
57
+ - kernel_size (int, optional): Kernel size for the convolutional layer, defaults to 3.
58
+ - stride (int, optional): Stride for the convolutional layer, defaults to 1.
59
+ - padding (int, optional): Padding for the convolutional layer, defaults to 1.
60
+ """
61
+
62
+ def __init__(self, channels, kernel_size=3, stride=1, padding=1):
63
+ super(FusionLayer, self).__init__()
64
+ self.conv = nn.Conv2d(channels * 2, channels * 2, kernel_size, stride=stride, padding=padding)
65
+
66
+ def forward(self, x, skip=None):
67
+ if skip is not None:
68
+ x += skip
69
+ x = x.repeat(1, 2, 1, 1)
70
+ x = self.conv(x)
71
+ x = F.glu(x, dim=1)
72
+ return x
73
+
74
+
75
+ class SDlayer(nn.Module):
76
+ """
77
+ Implements a Sparse Down-sample Layer for processing different frequency bands separately.
78
+
79
+ Args:
80
+ - channels_in (int): Input channel count.
81
+ - channels_out (int): Output channel count.
82
+ - band_configs (dict): A dictionary containing configuration for each frequency band.
83
+ Keys are 'low', 'mid', 'high' for each band, and values are
84
+ dictionaries with keys 'SR', 'stride', and 'kernel' for proportion,
85
+ stride, and kernel size, respectively.
86
+ """
87
+ def __init__(self, channels_in, channels_out, band_configs):
88
+ super(SDlayer, self).__init__()
89
+
90
+ # Initializing convolutional layers for each band
91
+ self.convs = nn.ModuleList()
92
+ self.strides = []
93
+ self.kernels = []
94
+ for config in band_configs.values():
95
+ self.convs.append(nn.Conv2d(channels_in, channels_out, (config['kernel'], 1), (config['stride'], 1), (0, 0)))
96
+ self.strides.append(config['stride'])
97
+ self.kernels.append(config['kernel'])
98
+
99
+ # Saving rate proportions for determining splits
100
+ self.SR_low = band_configs['low']['SR']
101
+ self.SR_mid = band_configs['mid']['SR']
102
+
103
+ def forward(self, x):
104
+ B, C, Fr, T = x.shape
105
+ # Define splitting points based on sampling rates
106
+ splits = [
107
+ (0, math.ceil(Fr * self.SR_low)),
108
+ (math.ceil(Fr * self.SR_low), math.ceil(Fr * (self.SR_low + self.SR_mid))),
109
+ (math.ceil(Fr * (self.SR_low + self.SR_mid)), Fr)
110
+ ]
111
+
112
+ # Processing each band with the corresponding convolution
113
+ outputs = []
114
+ original_lengths=[]
115
+ for conv, stride, kernel, (start, end) in zip(self.convs, self.strides, self.kernels, splits):
116
+ extracted = x[:, :, start:end, :]
117
+ original_lengths.append(end-start)
118
+ current_length = extracted.shape[2]
119
+
120
+ # padding
121
+ if stride == 1:
122
+ total_padding = kernel - stride
123
+ else:
124
+ total_padding = (stride - current_length % stride) % stride
125
+ pad_left = total_padding // 2
126
+ pad_right = total_padding - pad_left
127
+
128
+ padded = F.pad(extracted, (0, 0, pad_left, pad_right))
129
+
130
+ output = conv(padded)
131
+ outputs.append(output)
132
+
133
+ return outputs, original_lengths
134
+
135
+
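
To make the split and padding arithmetic above concrete, here is a worked example assuming the default band_configs of SCNet below and nfft = 4096, i.e. Fr = 2049 frequency bins:

import math

Fr, SR_low, SR_mid = 2049, 0.175, 0.392
splits = [
    (0, math.ceil(Fr * SR_low)),                                  # low:  (0, 359)
    (math.ceil(Fr * SR_low), math.ceil(Fr * (SR_low + SR_mid))),  # mid:  (359, 1162)
    (math.ceil(Fr * (SR_low + SR_mid)), Fr),                      # high: (1162, 2049)
]
for (start, end), stride, kernel in zip(splits, [1, 4, 16], [3, 4, 16]):
    length = end - start
    pad = kernel - stride if stride == 1 else (stride - length % stride) % stride
    print(length, pad)   # low: 359 bins, pad 2; mid: 803 bins, pad 1; high: 887 bins, pad 9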
136
+ class SUlayer(nn.Module):
137
+ """
138
+ Implements a Sparse Up-sample Layer in decoder.
139
+
140
+ Args:
141
+ - channels_in: The number of input channels.
142
+ - channels_out: The number of output channels.
143
+ - band_configs: Dictionary with the transposed-convolution settings ('kernel', 'stride') for each band.
144
+ """
145
+ def __init__(self, channels_in, channels_out, band_configs):
146
+ super(SUlayer, self).__init__()
147
+
148
+ # Initializing convolutional layers for each band
149
+ self.convtrs = nn.ModuleList([
150
+ nn.ConvTranspose2d(channels_in, channels_out, [config['kernel'], 1], [config['stride'], 1])
151
+ for _, config in band_configs.items()
152
+ ])
153
+
154
+ def forward(self, x, lengths, origin_lengths):
155
+ B, C, Fr, T = x.shape
156
+ # Define splitting points based on input lengths
157
+ splits = [
158
+ (0, lengths[0]),
159
+ (lengths[0], lengths[0] + lengths[1]),
160
+ (lengths[0] + lengths[1], None)
161
+ ]
162
+ # Processing each band with the corresponding convolution
163
+ outputs = []
164
+ for idx, (convtr, (start, end)) in enumerate(zip(self.convtrs, splits)):
165
+ out = convtr(x[:, :, start:end, :])
166
+ # Calculate the distance to trim the output symmetrically to original length
167
+ current_Fr_length = out.shape[2]
168
+ dist = abs(origin_lengths[idx] - current_Fr_length) // 2
169
+
170
+ # Trim the output to the original length symmetrically
171
+ trimmed_out = out[:, :, dist:dist + origin_lengths[idx], :]
172
+
173
+ outputs.append(trimmed_out)
174
+
175
+ # Concatenate trimmed outputs along the frequency dimension to return the final tensor
176
+ x = torch.cat(outputs, dim=2)
177
+
178
+ return x
179
+
180
+
181
+ class SDblock(nn.Module):
182
+ """
183
+ Implements a simplified Sparse Down-sample block in encoder.
184
+
185
+ Args:
186
+ - channels_in (int): Number of input channels.
187
+ - channels_out (int): Number of output channels.
188
+ - band_config (dict): Configuration for the SDlayer specifying band splits and convolutions.
189
+ - conv_config (dict): Configuration for convolution modules applied to each band.
190
+ - depths (list of int): List specifying the convolution depths for low, mid, and high frequency bands.
191
+ """
192
+ def __init__(self, channels_in, channels_out, band_configs={}, conv_config={}, depths=[3, 2, 1], kernel_size=3):
193
+ super(SDblock, self).__init__()
194
+ self.SDlayer = SDlayer(channels_in, channels_out, band_configs)
195
+
196
+ # Dynamically create convolution modules for each band based on depths
197
+ self.conv_modules = nn.ModuleList([
198
+ ConvolutionModule(channels_out, depth, **conv_config) for depth in depths
199
+ ])
200
+ # kernel_size must be odd so that (kernel_size - 1) // 2 yields 'same' padding for the global conv below.
201
+ self.globalconv = nn.Conv2d(channels_out, channels_out, kernel_size, 1, (kernel_size - 1) // 2)
202
+
203
+ def forward(self, x):
204
+ bands, original_lengths = self.SDlayer(x)
205
+ # B, C, f, T = band.shape
206
+ bands = [
207
+ F.gelu(
208
+ conv(band.permute(0, 2, 1, 3).reshape(-1, band.shape[1], band.shape[3]))
209
+ .view(band.shape[0], band.shape[2], band.shape[1], band.shape[3])
210
+ .permute(0, 2, 1, 3)
211
+ )
212
+ for conv, band in zip(self.conv_modules, bands)
214
+ ]
215
+ lengths = [band.size(-2) for band in bands]
216
+ full_band = torch.cat(bands, dim=2)
217
+ skip = full_band
218
+
219
+ output = self.globalconv(full_band)
220
+
221
+ return output, skip, lengths, original_lengths
222
+
223
+
224
+ class SCNet(nn.Module):
225
+ """
226
+ The implementation of SCNet: Sparse Compression Network for Music Source Separation. Paper: https://arxiv.org/abs/2401.13276
227
+
228
+ Args:
229
+ - sources (List[str]): List of sources to be separated.
230
+ - audio_channels (int): Number of audio channels.
231
+ - nfft (int): FFT size of the STFT; determines the frequency dimension of the input.
232
+ - hop_size (int): Hop size for the STFT.
233
+ - win_size (int): Window size for STFT.
234
+ - normalized (bool): Whether to normalize the STFT.
235
+ - dims (List[int]): List of channel dimensions for each block.
236
+ - band_configs (Dict[str, Dict[str, int]]): Configuration for each frequency band, including how to divide the frequency bands,
237
+ and the settings for the upsampling/downsampling convolutional layers.
238
+ - conv_depths (List[int]): List specifying the number of convolution modules in each SD block.
239
+ - compress (int): Compression factor for convolution module.
240
+ - conv_kernel (int): Kernel size for convolution layer in convolution module.
241
+ - num_dplayer (int): Number of dual-path layers.
242
+ - expand (int): Expansion factor in the dual-path RNN, default is 1.
243
+
244
+ """
245
+ def __init__(self,
246
+ sources = ['drums', 'bass', 'other', 'vocals'],
247
+ audio_channels = 2,
248
+ # Main structure
249
+ dims = [4, 32, 64, 128], # dims = [4, 64, 128, 256] in SCNet-large
250
+ # STFT
251
+ nfft = 4096,
252
+ hop_size = 1024,
253
+ win_size = 4096,
254
+ normalized = True,
255
+ # SD/SU layer
256
+ band_configs = {
257
+ 'low': { 'SR': .175, 'stride': 1, 'kernel': 3 },
258
+ 'mid': { 'SR': .392, 'stride': 4, 'kernel': 4 },
259
+ 'high': {'SR': .433, 'stride': 16, 'kernel': 16 }
260
+ },
261
+ # Convolution Module
262
+ conv_depths = [3,2,1],
263
+ compress = 4,
264
+ conv_kernel = 3,
265
+ # Dual-path RNN
266
+ num_dplayer = 6,
267
+ expand = 1,
268
+ # mamba
269
+ use_mamba = False,
270
+ mamba_config = {
271
+ 'd_stat': 16,
272
+ 'd_conv': 4,
273
+ 'd_expand': 2
274
+ }):
275
+ super().__init__()
276
+ self.sources = sources
277
+ self.audio_channels = audio_channels
278
+ self.dims = dims
279
+ self.band_configs = band_configs
280
+ self.hop_length = hop_size
281
+ self.conv_config = {
282
+ 'compress': compress,
283
+ 'kernel': conv_kernel,
284
+ }
285
+
286
+ self.stft_config = {
287
+ 'n_fft': nfft,
288
+ 'hop_length': hop_size,
289
+ 'win_length': win_size,
290
+ 'center': True,
291
+ 'normalized': normalized
292
+ }
293
+
294
+ self.encoder = nn.ModuleList()
295
+ self.decoder = nn.ModuleList()
296
+
297
+ for index in range(len(dims)-1):
298
+ enc = SDblock(
299
+ channels_in = dims[index],
300
+ channels_out = dims[index+1],
301
+ band_configs = self.band_configs,
302
+ conv_config = self.conv_config,
303
+ depths = conv_depths
304
+ )
305
+ self.encoder.append(enc)
306
+
307
+ dec = nn.Sequential(
308
+ FusionLayer(channels = dims[index+1]),
309
+ SUlayer(
310
+ channels_in = dims[index+1],
311
+ channels_out = dims[index] if index != 0 else dims[index] * len(sources),
312
+ band_configs = self.band_configs,
313
+ )
314
+ )
315
+ self.decoder.insert(0, dec)
316
+
317
+ self.separation_net = SeparationNet(
318
+ channels = dims[-1],
319
+ expand = expand,
320
+ num_layers = num_dplayer,
321
+ use_mamba = use_mamba,
322
+ **mamba_config
323
+ )
324
+
325
+
326
+ def forward(self, x):
327
+ # B, C, L = x.shape
328
+ B = x.shape[0]
329
+ # In the initial padding, ensure that the number of frames after the STFT (the length of the T dimension) is even,
330
+ # so that the RFFT operation can be used in the separation network.
331
+ padding = self.hop_length - x.shape[-1] % self.hop_length
332
+ if (x.shape[-1] + padding) // self.hop_length % 2 == 0:
333
+ padding += self.hop_length
334
+ x = F.pad(x, (0, padding))
335
+
336
+ # STFT
337
+ L = x.shape[-1]
338
+ x = x.reshape(-1, L)
339
+ x = torch.stft(x, **self.stft_config, return_complex=True)
340
+ x = torch.view_as_real(x)
341
+ x = x.permute(0, 3, 1, 2).reshape(x.shape[0]//self.audio_channels, x.shape[3]*self.audio_channels, x.shape[1], x.shape[2])
342
+
343
+ B, C, Fr, T = x.shape
344
+
345
+ save_skip = deque()
346
+ save_lengths = deque()
347
+ save_original_lengths = deque()
348
+ # encoder
349
+ for sd_layer in self.encoder:
350
+ x, skip, lengths, original_lengths = sd_layer(x)
351
+ save_skip.append(skip)
352
+ save_lengths.append(lengths)
353
+ save_original_lengths.append(original_lengths)
354
+
355
+ #separation
356
+ x = self.separation_net(x)
357
+
358
+ #decoder
359
+ for fusion_layer, su_layer in self.decoder:
360
+ x = fusion_layer(x, save_skip.pop())
361
+ x = su_layer(x, save_lengths.pop(), save_original_lengths.pop())
362
+
363
+ #output
364
+ n = self.dims[0]
365
+ x = x.view(B, n, -1, Fr, T)
366
+ x = x.reshape(-1, 2, Fr, T).permute(0, 2, 3, 1)
367
+ x = torch.view_as_complex(x.contiguous())
368
+ x = torch.istft(x, **self.stft_config)
369
+ x = x.reshape(B, len(self.sources), self.audio_channels, -1)
370
+
371
+ x = x[:, :, :, :-padding]
372
+
373
+ return x
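
The initial padding logic deserves a numeric example. With hop_size = 1024 and an arbitrary input length of 44100 samples (one second at 44.1 kHz), the sketch below (just the arithmetic, no model needed) shows how the pad is extended so the STFT yields an even number of frames:

hop, L = 1024, 44100
padding = hop - L % hop                    # 956 -> padded length 45056 = 44 * hop
if (L + padding) // hop % 2 == 0:          # 44 is even, so T = 44 + 1 = 45 frames would be odd ...
    padding += hop                         # ... pad one extra hop: padded length 46080 = 45 * hop
frames = (L + padding) // hop + 1          # with center=True the STFT yields 46 frames
print(padding, frames)                     # 1980 46 -> even T, so rfft/irfft over T round-trips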
models/scnet/separation.py ADDED
@@ -0,0 +1,178 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.nn.modules.rnn import LSTM
4
+ import torch.nn.functional as Func
5
+ try:
6
+ from mamba_ssm.modules.mamba_simple import Mamba
7
+ except Exception as e:
8
+ print('No mamba found. Please install mamba_ssm')
9
+
10
+ class RMSNorm(nn.Module):
11
+ def __init__(self, dim):
12
+ super().__init__()
13
+ self.scale = dim ** 0.5
14
+ self.gamma = nn.Parameter(torch.ones(dim))
15
+
16
+ def forward(self, x):
17
+ return Func.normalize(x, dim=-1) * self.scale * self.gamma
18
+
19
+
20
+ class MambaModule(nn.Module):
21
+ def __init__(self, d_model, d_state, d_conv, d_expand):
22
+ super().__init__()
23
+ self.norm = RMSNorm(dim=d_model)
24
+ self.mamba = Mamba(
25
+ d_model=d_model,
26
+ d_state=d_state,
27
+ d_conv=d_conv,
28
+ expand=d_expand
29
+ )
30
+
31
+ def forward(self, x):
32
+ x = x + self.mamba(self.norm(x))
33
+ return x
34
+
35
+
36
+ class FeatureConversion(nn.Module):
37
+ """
38
+ Converts features between adjacent dual-path layers by applying an rFFT (forward) or an inverse rFFT along the time axis.
39
+
40
+ Args:
41
+ channels (int): Number of input channels.
42
+ inverse (bool): If True, uses ifft; otherwise, uses rfft.
43
+ """
44
+ def __init__(self, channels, inverse):
45
+ super().__init__()
46
+ self.inverse = inverse
47
+ self.channels= channels
48
+
49
+ def forward(self, x):
50
+ # B, C, F, T = x.shape
51
+ if self.inverse:
52
+ x = x.float()
53
+ x_r = x[:, :self.channels//2, :, :]
54
+ x_i = x[:, self.channels//2:, :, :]
55
+ x = torch.complex(x_r, x_i)
56
+ x = torch.fft.irfft(x, dim=3, norm="ortho")
57
+ else:
58
+ x = x.float()
59
+ x = torch.fft.rfft(x, dim=3, norm="ortho")
60
+ x_real = x.real
61
+ x_imag = x.imag
62
+ x = torch.cat([x_real, x_imag], dim=1)
63
+ return x
64
+
65
+
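
Shape-wise, the forward direction shortens the time axis to T // 2 + 1 rFFT bins and doubles the channel count by stacking real and imaginary parts, while the inverse direction undoes both. A quick sketch against the class above (PyTorch assumed; shapes arbitrary but with an even T):

import torch

fwd = FeatureConversion(channels=256, inverse=False)
inv = FeatureConversion(channels=256, inverse=True)

x = torch.randn(2, 128, 8, 100)   # (B, C, F, T)
y = fwd(x)
print(y.shape)                    # torch.Size([2, 256, 8, 51]) -- real/imag stacked on channels
z = inv(y)
print(z.shape)                    # torch.Size([2, 128, 8, 100]) -- back to the original layout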
66
+ class DualPathRNN(nn.Module):
67
+ """
68
+ Dual-Path RNN in Separation Network.
69
+
70
+ Args:
71
+ d_model (int): The number of expected features in the input (input_size).
72
+ expand (int): Expansion factor used to calculate the hidden_size of LSTM.
73
+ bidirectional (bool): If True, becomes a bidirectional LSTM.
74
+ """
75
+ def __init__(self, d_model, expand, bidirectional=True):
76
+ super(DualPathRNN, self).__init__()
77
+
78
+ self.d_model = d_model
79
+ self.hidden_size = d_model * expand
80
+ self.bidirectional = bidirectional
81
+ # Initialize LSTM layers and normalization layers
82
+ self.lstm_layers = nn.ModuleList([self._init_lstm_layer(self.d_model, self.hidden_size) for _ in range(2)])
83
+ self.linear_layers = nn.ModuleList([nn.Linear(self.hidden_size*2, self.d_model) for _ in range(2)])
84
+ self.norm_layers = nn.ModuleList([nn.GroupNorm(1, d_model) for _ in range(2)])
85
+
86
+ def _init_lstm_layer(self, d_model, hidden_size):
87
+ return LSTM(d_model, hidden_size, num_layers=1, bidirectional=self.bidirectional, batch_first=True)
88
+
89
+ def forward(self, x):
90
+ B, C, F, T = x.shape
91
+
92
+ # Process dual-path rnn
93
+
94
+ original_x = x
95
+ # Frequency-path
96
+ x = self.norm_layers[0](x)
97
+ x = x.transpose(1, 3).contiguous().view(B * T, F, C)
98
+ x, _ = self.lstm_layers[0](x)
99
+ x = self.linear_layers[0](x)
100
+ x = x.view(B, T, F, C).transpose(1, 3)
101
+ x = x + original_x
102
+
103
+ original_x = x
104
+ # Time-path
105
+ x = self.norm_layers[1](x)
106
+ x = x.transpose(1, 2).contiguous().view(B * F, C, T).transpose(1, 2)
107
+ x, _ = self.lstm_layers[1](x)
108
+ x = self.linear_layers[1](x)
109
+ x = x.transpose(1, 2).contiguous().view(B, F, C, T).transpose(1, 2)
110
+ x = x + original_x
111
+
112
+ return x
113
+
114
+
115
+ class DualPathMamba(nn.Module):
116
+ """
117
+ Dual-Path Mamba.
118
+
119
+ """
120
+ def __init__(self, d_model, d_stat, d_conv, d_expand):
121
+ super(DualPathMamba, self).__init__()
122
+ # Initialize mamba layers
123
+ self.mamba_layers = nn.ModuleList([MambaModule(d_model, d_stat, d_conv, d_expand) for _ in range(2)])
124
+
125
+ def forward(self, x):
126
+ B, C, F, T = x.shape
127
+
128
+ # Process dual-path mamba
129
+
130
+ # Frequency-path
131
+ x = x.transpose(1, 3).contiguous().view(B * T, F, C)
132
+ x = self.mamba_layers[0](x)
133
+ x = x.view(B, T, F, C).transpose(1, 3)
134
+
135
+ # Time-path
136
+ x = x.transpose(1, 2).contiguous().view(B * F, C, T).transpose(1, 2)
137
+ x = self.mamba_layers[1](x)
138
+ x = x.transpose(1, 2).contiguous().view(B, F, C, T).transpose(1, 2)
139
+
140
+ return x
141
+
142
+
143
+ class SeparationNet(nn.Module):
144
+ """
145
+ Implements the separation network: a stack of dual-path modules with feature conversion layers between them.
146
+
147
+ Args:
148
+ - channels (int): Number of input channels.
149
+ - expand (int): Expansion factor used to calculate the hidden_size of LSTM.
150
+ - num_layers (int): Number of dual-path layers.
151
+ - use_mamba (bool): If true, use the Mamba module to replace the RNN.
152
+ - d_stat (int), d_conv (int), d_expand (int): These are built-in parameters of the Mamba model.
153
+ """
154
+ def __init__(self, channels, expand=1, num_layers=6, use_mamba=True, d_stat=16, d_conv=4, d_expand=2):
155
+ super(SeparationNet, self).__init__()
156
+
157
+ self.num_layers = num_layers
158
+ if use_mamba:
159
+ self.dp_modules = nn.ModuleList([
160
+ DualPathMamba(channels * (2 if i % 2 == 1 else 1), d_stat, d_conv, d_expand * (2 if i % 2 == 1 else 1)) for i in range(num_layers)
161
+ ])
162
+ else:
163
+ self.dp_modules = nn.ModuleList([
164
+ DualPathRNN(channels * (2 if i % 2 == 1 else 1), expand) for i in range(num_layers)
165
+ ])
166
+
167
+ self.feature_conversion = nn.ModuleList([
168
+ FeatureConversion(channels * 2 , inverse = False if i % 2 == 0 else True) for i in range(num_layers)
169
+ ])
170
+ def forward(self, x):
171
+ for i in range(self.num_layers):
172
+ x = self.dp_modules[i](x)
173
+ x = self.feature_conversion[i](x)
174
+ return x
175
+
176
+
177
+
178
+
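
Because every odd-indexed dual-path module runs between a forward and an inverse FeatureConversion, it sees twice the channel count and roughly half the frames; the alternating channels * (2 if i % 2 == 1 else 1) widths above encode exactly that. A usage sketch with the LSTM path (use_mamba=False), assuming PyTorch and an even frame count as guaranteed by SCNet's padding:

import torch

net = SeparationNet(channels=128, expand=1, num_layers=6, use_mamba=False)
x = torch.randn(1, 128, 8, 100)   # (B, C, F, T) as produced by the SCNet encoder
y = net(x)
print(y.shape)                    # torch.Size([1, 128, 8, 100]); odd layers internally work on
                                  # (1, 256, 8, 51) thanks to the rFFT-based feature conversion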
models/scnet_unofficial/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from models.scnet_unofficial.scnet import SCNet
models/scnet_unofficial/modules/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from models.scnet_unofficial.modules.dualpath_rnn import DualPathRNN
2
+ from models.scnet_unofficial.modules.sd_encoder import SDBlock
3
+ from models.scnet_unofficial.modules.su_decoder import SUBlock
models/scnet_unofficial/modules/dualpath_rnn.py ADDED
@@ -0,0 +1,228 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as Func
+ # mamba_ssm is optional; it is only required when the Mamba path is used
+ try:
+ from mamba_ssm.modules.mamba_simple import Mamba
+ except Exception:
+ print('No mamba found. Please install mamba_ssm')
4
+
5
+ class RMSNorm(nn.Module):
6
+ def __init__(self, dim):
7
+ super().__init__()
8
+ self.scale = dim ** 0.5
9
+ self.gamma = nn.Parameter(torch.ones(dim))
10
+
11
+ def forward(self, x):
12
+ return Func.normalize(x, dim=-1) * self.scale * self.gamma
13
+
14
+
15
+ class MambaModule(nn.Module):
16
+ def __init__(self, d_model, d_state, d_conv, d_expand):
17
+ super().__init__()
18
+ self.norm = RMSNorm(dim=d_model)
19
+ self.mamba = Mamba(
20
+ d_model=d_model,
21
+ d_state=d_state,
22
+ d_conv=d_conv,
23
+ expand=d_expand
24
+ )
25
+
26
+ def forward(self, x):
27
+ x = x + self.mamba(self.norm(x))
28
+ return x
29
+
30
+
31
+ class RNNModule(nn.Module):
32
+ """
33
+ RNNModule class implements a recurrent neural network module with LSTM cells.
34
+
35
+ Args:
36
+ - input_dim (int): Dimensionality of the input features.
37
+ - hidden_dim (int): Dimensionality of the hidden state of the LSTM.
38
+ - bidirectional (bool, optional): If True, uses bidirectional LSTM. Defaults to True.
39
+
40
+ Shapes:
41
+ - Input: (B, T, D) where
42
+ B is batch size,
43
+ T is sequence length,
44
+ D is input dimensionality.
45
+ - Output: (B, T, D) where
46
+ B is batch size,
47
+ T is sequence length,
48
+ D is input dimensionality.
49
+ """
50
+
51
+ def __init__(self, input_dim: int, hidden_dim: int, bidirectional: bool = True):
52
+ """
53
+ Initializes RNNModule with input dimension, hidden dimension, and bidirectional flag.
54
+ """
55
+ super().__init__()
56
+ self.groupnorm = nn.GroupNorm(num_groups=1, num_channels=input_dim)
57
+ self.rnn = nn.LSTM(
58
+ input_dim, hidden_dim, batch_first=True, bidirectional=bidirectional
59
+ )
60
+ self.fc = nn.Linear(hidden_dim * 2 if bidirectional else hidden_dim, input_dim)
61
+
62
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
63
+ """
64
+ Performs forward pass through the RNNModule.
65
+
66
+ Args:
67
+ - x (torch.Tensor): Input tensor of shape (B, T, D).
68
+
69
+ Returns:
70
+ - torch.Tensor: Output tensor of shape (B, T, D).
71
+ """
72
+ x = x.transpose(1, 2)
73
+ x = self.groupnorm(x)
74
+ x = x.transpose(1, 2)
75
+
76
+ x, (hidden, _) = self.rnn(x)
77
+ x = self.fc(x)
78
+ return x
79
+
80
+
81
+ class RFFTModule(nn.Module):
82
+ """
83
+ RFFTModule class implements a module for performing real-valued Fast Fourier Transform (FFT)
84
+ or its inverse on input tensors.
85
+
86
+ Args:
87
+ - inverse (bool, optional): If False, performs forward FFT. If True, performs inverse FFT. Defaults to False.
88
+
89
+ Shapes:
90
+ - Input: (B, F, T, D) where
91
+ B is batch size,
92
+ F is the number of features,
93
+ T is sequence length,
94
+ D is input dimensionality.
95
+ - Output: (B, F, T // 2 + 1, D * 2) if performing forward FFT.
96
+ (B, F, T, D // 2, 2) if performing inverse FFT.
97
+ """
98
+
99
+ def __init__(self, inverse: bool = False):
100
+ """
101
+ Initializes RFFTModule with inverse flag.
102
+ """
103
+ super().__init__()
104
+ self.inverse = inverse
105
+
106
+ def forward(self, x: torch.Tensor, time_dim: int) -> torch.Tensor:
107
+ """
108
+ Performs forward or inverse FFT on the input tensor x.
109
+
110
+ Args:
111
+ - x (torch.Tensor): Input tensor of shape (B, F, T, D).
112
+ - time_dim (int): Input size of time dimension.
113
+
114
+ Returns:
115
+ - torch.Tensor: Output tensor after FFT or its inverse operation.
116
+ """
117
+ dtype = x.dtype
118
+ B, F, T, D = x.shape
119
+
120
+ # RuntimeError: cuFFT only supports dimensions whose sizes are powers of two when computing in half precision
121
+ x = x.float()
122
+
123
+ if not self.inverse:
124
+ x = torch.fft.rfft(x, dim=2)
125
+ x = torch.view_as_real(x)
126
+ x = x.reshape(B, F, T // 2 + 1, D * 2)
127
+ else:
128
+ x = x.reshape(B, F, T, D // 2, 2)
129
+ x = torch.view_as_complex(x)
130
+ x = torch.fft.irfft(x, n=time_dim, dim=2)
131
+
132
+ x = x.to(dtype)
133
+ return x
134
+
135
+ def extra_repr(self) -> str:
136
+ """
137
+ Returns extra representation string with module's configuration.
138
+ """
139
+ return f"inverse={self.inverse}"
140
+
141
+
142
+ class DualPathRNN(nn.Module):
143
+ """
144
+ DualPathRNN class implements a neural network with alternating layers of RNNModule and RFFTModule.
145
+
146
+ Args:
147
+ - n_layers (int): Number of layers in the network.
148
+ - input_dim (int): Dimensionality of the input features.
149
+ - hidden_dim (int): Dimensionality of the hidden state of the RNNModule.
150
+
151
+ Shapes:
152
+ - Input: (B, F, T, D) where
153
+ B is batch size,
154
+ F is the number of features (frequency dimension),
155
+ T is sequence length (time dimension),
156
+ D is input dimensionality (channel dimension).
157
+ - Output: (B, F, T, D) where
158
+ B is batch size,
159
+ F is the number of features (frequency dimension),
160
+ T is sequence length (time dimension),
161
+ D is input dimensionality (channel dimension).
162
+ """
163
+
164
+ def __init__(
165
+ self,
166
+ n_layers: int,
167
+ input_dim: int,
168
+ hidden_dim: int,
169
+
170
+ use_mamba: bool = False,
171
+ d_state: int = 16,
172
+ d_conv: int = 4,
173
+ d_expand: int = 2
174
+ ):
175
+ """
176
+ Initializes DualPathRNN with the specified number of layers, input dimension, and hidden dimension.
177
+ """
178
+ super().__init__()
179
+
180
+ if use_mamba:
181
+ from mamba_ssm.modules.mamba_simple import Mamba
182
+ net = MambaModule
183
+ dkwargs = {"d_model": input_dim, "d_state": d_state, "d_conv": d_conv, "d_expand": d_expand}
184
+ ukwargs = {"d_model": input_dim * 2, "d_state": d_state, "d_conv": d_conv, "d_expand": d_expand * 2}
185
+ else:
186
+ net = RNNModule
187
+ dkwargs = {"input_dim": input_dim, "hidden_dim": hidden_dim}
188
+ ukwargs = {"input_dim": input_dim * 2, "hidden_dim": hidden_dim * 2}
189
+
190
+ self.layers = nn.ModuleList()
191
+ for i in range(1, n_layers + 1):
192
+ kwargs = dkwargs if i % 2 == 1 else ukwargs
193
+ layer = nn.ModuleList([
194
+ net(**kwargs),
195
+ net(**kwargs),
196
+ RFFTModule(inverse=(i % 2 == 0)),
197
+ ])
198
+ self.layers.append(layer)
199
+
200
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
201
+ """
202
+ Performs forward pass through the DualPathRNN.
203
+
204
+ Args:
205
+ - x (torch.Tensor): Input tensor of shape (B, F, T, D).
206
+
207
+ Returns:
208
+ - torch.Tensor: Output tensor of shape (B, F, T, D).
209
+ """
210
+
211
+ time_dim = x.shape[2]
212
+
213
+ for time_layer, freq_layer, rfft_layer in self.layers:
214
+ B, F, T, D = x.shape
215
+
216
+ x = x.reshape((B * F), T, D)
217
+ x = time_layer(x)
218
+ x = x.reshape(B, F, T, D)
219
+ x = x.permute(0, 2, 1, 3)
220
+
221
+ x = x.reshape((B * T), F, D)
222
+ x = freq_layer(x)
223
+ x = x.reshape(B, T, F, D)
224
+ x = x.permute(0, 2, 1, 3)
225
+
226
+ x = rfft_layer(x, time_dim)
227
+
228
+ return x
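
The alternation between dkwargs and ukwargs mirrors the RFFTModule placement: odd layers feed a forward rFFT (D doubles, T shrinks to T // 2 + 1), even layers feed the inverse and restore the original layout. A shape sketch with the LSTM path (PyTorch assumed; dimensions arbitrary, T kept even):

import torch

dprnn = DualPathRNN(n_layers=2, input_dim=64, hidden_dim=64)
x = torch.randn(1, 16, 100, 64)   # (B, F, T, D)
y = dprnn(x)
print(y.shape)                    # torch.Size([1, 16, 100, 64]); after layer 1 the tensor is
                                  # (1, 16, 51, 128), layer 2's inverse rFFT restores (100, 64)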
models/scnet_unofficial/modules/sd_encoder.py ADDED
@@ -0,0 +1,285 @@
1
+ from typing import List, Tuple
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ from models.scnet_unofficial.utils import create_intervals
7
+
8
+
9
+ class Downsample(nn.Module):
10
+ """
11
+ Downsample class implements a module for downsampling input tensors using 2D convolution.
12
+
13
+ Args:
14
+ - input_dim (int): Dimensionality of the input channels.
15
+ - output_dim (int): Dimensionality of the output channels.
16
+ - stride (int): Stride value for the convolution operation.
17
+
18
+ Shapes:
19
+ - Input: (B, C_in, F, T) where
20
+ B is batch size,
21
+ C_in is the number of input channels,
22
+ F is the frequency dimension,
23
+ T is the time dimension.
24
+ - Output: (B, C_out, F // stride, T) where
25
+ B is batch size,
26
+ C_out is the number of output channels,
27
+ F // stride is the downsampled frequency dimension.
28
+
29
+ """
30
+
31
+ def __init__(
32
+ self,
33
+ input_dim: int,
34
+ output_dim: int,
35
+ stride: int,
36
+ ):
37
+ """
38
+ Initializes Downsample with input dimension, output dimension, and stride.
39
+ """
40
+ super().__init__()
41
+ self.conv = nn.Conv2d(input_dim, output_dim, 1, (stride, 1))
42
+
43
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
44
+ """
45
+ Performs forward pass through the Downsample module.
46
+
47
+ Args:
48
+ - x (torch.Tensor): Input tensor of shape (B, C_in, F, T).
49
+
50
+ Returns:
51
+ - torch.Tensor: Downsampled tensor of shape (B, C_out, F // stride, T).
52
+ """
53
+ return self.conv(x)
54
+
55
+
56
+ class ConvolutionModule(nn.Module):
57
+ """
58
+ ConvolutionModule class implements a module with a sequence of convolutional layers similar to Conformer.
59
+
60
+ Args:
61
+ - input_dim (int): Dimensionality of the input features.
62
+ - hidden_dim (int): Dimensionality of the hidden features.
63
+ - kernel_sizes (List[int]): List of kernel sizes for the convolutional layers.
64
+ - bias (bool, optional): If True, adds a learnable bias to the output. Default is False.
65
+
66
+ Shapes:
67
+ - Input: (B, T, D) where
68
+ B is batch size,
69
+ T is sequence length,
70
+ D is input dimensionality.
71
+ - Output: (B, T, D) where
72
+ B is batch size,
73
+ T is sequence length,
74
+ D is input dimensionality.
75
+ """
76
+
77
+ def __init__(
78
+ self,
79
+ input_dim: int,
80
+ hidden_dim: int,
81
+ kernel_sizes: List[int],
82
+ bias: bool = False,
83
+ ) -> None:
84
+ """
85
+ Initializes ConvolutionModule with input dimension, hidden dimension, kernel sizes, and bias.
86
+ """
87
+ super().__init__()
88
+ self.sequential = nn.Sequential(
89
+ nn.GroupNorm(num_groups=1, num_channels=input_dim),
90
+ nn.Conv1d(
91
+ input_dim,
92
+ 2 * hidden_dim,
93
+ kernel_sizes[0],
94
+ stride=1,
95
+ padding=(kernel_sizes[0] - 1) // 2,
96
+ bias=bias,
97
+ ),
98
+ nn.GLU(dim=1),
99
+ nn.Conv1d(
100
+ hidden_dim,
101
+ hidden_dim,
102
+ kernel_sizes[1],
103
+ stride=1,
104
+ padding=(kernel_sizes[1] - 1) // 2,
105
+ groups=hidden_dim,
106
+ bias=bias,
107
+ ),
108
+ nn.GroupNorm(num_groups=1, num_channels=hidden_dim),
109
+ nn.SiLU(),
110
+ nn.Conv1d(
111
+ hidden_dim,
112
+ input_dim,
113
+ kernel_sizes[2],
114
+ stride=1,
115
+ padding=(kernel_sizes[2] - 1) // 2,
116
+ bias=bias,
117
+ ),
118
+ )
119
+
120
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
121
+ """
122
+ Performs forward pass through the ConvolutionModule.
123
+
124
+ Args:
125
+ - x (torch.Tensor): Input tensor of shape (B, T, D).
126
+
127
+ Returns:
128
+ - torch.Tensor: Output tensor of shape (B, T, D).
129
+ """
130
+ x = x.transpose(1, 2)
131
+ x = x + self.sequential(x)
132
+ x = x.transpose(1, 2)
133
+ return x
134
+
135
+
136
+ class SDLayer(nn.Module):
137
+ """
138
+ SDLayer class implements a subband decomposition layer with downsampling and convolutional modules.
139
+
140
+ Args:
141
+ - subband_interval (Tuple[float, float]): Tuple representing the frequency interval for subband decomposition.
142
+ - input_dim (int): Dimensionality of the input channels.
143
+ - output_dim (int): Dimensionality of the output channels after downsampling.
144
+ - downsample_stride (int): Stride value for the downsampling operation.
145
+ - n_conv_modules (int): Number of convolutional modules.
146
+ - kernel_sizes (List[int]): List of kernel sizes for the convolutional layers.
147
+ - bias (bool, optional): If True, adds a learnable bias to the convolutional layers. Default is True.
148
+
149
+ Shapes:
150
+ - Input: (B, Fi, T, Ci) where
151
+ B is batch size,
152
+ Fi is the number of input subbands,
153
+ T is sequence length, and
154
+ Ci is the number of input channels.
155
+ - Output: (B, Fi+1, T, Ci+1) where
156
+ B is batch size,
157
+ Fi+1 is the number of output subbands,
158
+ T is sequence length,
159
+ Ci+1 is the number of output channels.
160
+ """
161
+
162
+ def __init__(
163
+ self,
164
+ subband_interval: Tuple[float, float],
165
+ input_dim: int,
166
+ output_dim: int,
167
+ downsample_stride: int,
168
+ n_conv_modules: int,
169
+ kernel_sizes: List[int],
170
+ bias: bool = True,
171
+ ):
172
+ """
173
+ Initializes SDLayer with subband interval, input dimension,
174
+ output dimension, downsample stride, number of convolutional modules, kernel sizes, and bias.
175
+ """
176
+ super().__init__()
177
+ self.subband_interval = subband_interval
178
+ self.downsample = Downsample(input_dim, output_dim, downsample_stride)
179
+ self.activation = nn.GELU()
180
+ conv_modules = [
181
+ ConvolutionModule(
182
+ input_dim=output_dim,
183
+ hidden_dim=output_dim // 4,
184
+ kernel_sizes=kernel_sizes,
185
+ bias=bias,
186
+ )
187
+ for _ in range(n_conv_modules)
188
+ ]
189
+ self.conv_modules = nn.Sequential(*conv_modules)
190
+
191
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
192
+ """
193
+ Performs forward pass through the SDLayer.
194
+
195
+ Args:
196
+ - x (torch.Tensor): Input tensor of shape (B, Fi, T, Ci).
197
+
198
+ Returns:
199
+ - torch.Tensor: Output tensor of shape (B, Fi+1, T, Ci+1).
200
+ """
201
+ B, F, T, C = x.shape
202
+ x = x[:, int(self.subband_interval[0] * F) : int(self.subband_interval[1] * F)]
203
+ x = x.permute(0, 3, 1, 2)
204
+ x = self.downsample(x)
205
+ x = self.activation(x)
206
+ x = x.permute(0, 2, 3, 1)
207
+
208
+ B, F, T, C = x.shape
209
+ x = x.reshape((B * F), T, C)
210
+ x = self.conv_modules(x)
211
+ x = x.reshape(B, F, T, C)
212
+
213
+ return x
214
+
215
+
216
+ class SDBlock(nn.Module):
217
+ """
218
+ SDBlock class implements a block with subband decomposition layers and global convolution.
219
+
220
+ Args:
221
+ - input_dim (int): Dimensionality of the input channels.
222
+ - output_dim (int): Dimensionality of the output channels.
223
+ - bandsplit_ratios (List[float]): List of ratios for splitting the frequency bands.
224
+ - downsample_strides (List[int]): List of stride values for downsampling in each subband layer.
225
+ - n_conv_modules (List[int]): List specifying the number of convolutional modules in each subband layer.
226
+ - kernel_sizes (List[int], optional): List of kernel sizes for the convolutional layers. Default is None.
227
+
228
+ Shapes:
229
+ - Input: (B, Fi, T, Ci) where
230
+ B is batch size,
231
+ Fi is the number of input subbands,
232
+ T is sequence length,
233
+ Ci is the number of input channels.
234
+ - Output: (B, Fi+1, T, Ci+1) where
235
+ B is batch size,
236
+ Fi+1 is the number of output subbands,
237
+ T is sequence length,
238
+ Ci+1 is the number of output channels.
239
+ """
240
+
241
+ def __init__(
242
+ self,
243
+ input_dim: int,
244
+ output_dim: int,
245
+ bandsplit_ratios: List[float],
246
+ downsample_strides: List[int],
247
+ n_conv_modules: List[int],
248
+ kernel_sizes: List[int] = None,
249
+ ):
250
+ """
251
+ Initializes SDBlock with input dimension, output dimension, band split ratios, downsample strides, number of convolutional modules, and kernel sizes.
252
+ """
253
+ super().__init__()
254
+ if kernel_sizes is None:
255
+ kernel_sizes = [3, 3, 1]
256
+ assert sum(bandsplit_ratios) == 1, "The split ratios must sum up to 1."
257
+ subband_intervals = create_intervals(bandsplit_ratios)
258
+ self.sd_layers = nn.ModuleList(
259
+ SDLayer(
260
+ input_dim=input_dim,
261
+ output_dim=output_dim,
262
+ subband_interval=sbi,
263
+ downsample_stride=dss,
264
+ n_conv_modules=ncm,
265
+ kernel_sizes=kernel_sizes,
266
+ )
267
+ for sbi, dss, ncm in zip(
268
+ subband_intervals, downsample_strides, n_conv_modules
269
+ )
270
+ )
271
+ self.global_conv2d = nn.Conv2d(output_dim, output_dim, 1, 1)
272
+
273
+ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
274
+ """
275
+ Performs forward pass through the SDBlock.
276
+
277
+ Args:
278
+ - x (torch.Tensor): Input tensor of shape (B, Fi, T, Ci).
279
+
280
+ Returns:
281
+ - Tuple[torch.Tensor, torch.Tensor]: Output tensor and skip connection tensor.
282
+ """
283
+ x_skip = torch.concat([layer(x) for layer in self.sd_layers], dim=1)
284
+ x = self.global_conv2d(x_skip.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
285
+ return x, x_skip
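
Finally, a usage sketch for SDBlock. It assumes PyTorch and that create_intervals (defined in models/scnet_unofficial/utils.py, not shown here) turns the split ratios into cumulative (start, end) fractions, e.g. [0.25, 0.25, 0.5] -> [(0, 0.25), (0.25, 0.5), (0.5, 1.0)]; the ratios below are chosen so they sum to exactly 1.0 in floating point and satisfy the assertion above:

import torch

block = SDBlock(
    input_dim=4,
    output_dim=32,
    bandsplit_ratios=[0.25, 0.25, 0.5],
    downsample_strides=[1, 4, 16],
    n_conv_modules=[3, 2, 1],
)
x = torch.randn(1, 2048, 100, 4)      # (B, F, T, C)
y, skip = block(x)
print(y.shape, skip.shape)            # with the assumed intervals: (1, 704, 100, 32) for both,
                                      # i.e. 512 + 128 + 64 downsampled bins concatenated on F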