Commit 75c6e9a by jone (0 parents).
This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. .gitattributes +29 -0
  2. LICENSE +13 -0
  3. README.md +33 -0
  4. app.py +40 -0
  5. bytesep/__init__.py +1 -0
  6. bytesep/callbacks/__init__.py +76 -0
  7. bytesep/callbacks/base_callbacks.py +44 -0
  8. bytesep/callbacks/instruments_callbacks.py +200 -0
  9. bytesep/callbacks/musdb18.py +485 -0
  10. bytesep/callbacks/voicebank_demand.py +231 -0
  11. bytesep/data/__init__.py +0 -0
  12. bytesep/data/augmentors.py +157 -0
  13. bytesep/data/batch_data_preprocessors.py +141 -0
  14. bytesep/data/data_modules.py +187 -0
  15. bytesep/data/samplers.py +188 -0
  16. bytesep/dataset_creation/__init__.py +0 -0
  17. bytesep/dataset_creation/create_evaluation_audios/__init__.py +0 -0
  18. bytesep/dataset_creation/create_evaluation_audios/piano-symphony.py +160 -0
  19. bytesep/dataset_creation/create_evaluation_audios/vctk-musdb18.py +164 -0
  20. bytesep/dataset_creation/create_evaluation_audios/violin-piano.py +162 -0
  21. bytesep/dataset_creation/create_indexes/__init__.py +0 -0
  22. bytesep/dataset_creation/create_indexes/create_indexes.py +142 -0
  23. bytesep/dataset_creation/pack_audios_to_hdf5s/__init__.py +0 -0
  24. bytesep/dataset_creation/pack_audios_to_hdf5s/instruments_solo.py +173 -0
  25. bytesep/dataset_creation/pack_audios_to_hdf5s/maestro.py +136 -0
  26. bytesep/dataset_creation/pack_audios_to_hdf5s/musdb18.py +207 -0
  27. bytesep/dataset_creation/pack_audios_to_hdf5s/vctk.py +114 -0
  28. bytesep/dataset_creation/pack_audios_to_hdf5s/voicebank-demand.py +143 -0
  29. bytesep/inference.py +404 -0
  30. bytesep/inference_many.py +163 -0
  31. bytesep/losses.py +106 -0
  32. bytesep/models/__init__.py +0 -0
  33. bytesep/models/conditional_unet.py +496 -0
  34. bytesep/models/lightning_modules.py +188 -0
  35. bytesep/models/pytorch_modules.py +204 -0
  36. bytesep/models/resunet.py +516 -0
  37. bytesep/models/resunet_ismir2021.py +534 -0
  38. bytesep/models/resunet_subbandtime.py +545 -0
  39. bytesep/models/subband_tools/__init__.py +0 -0
  40. bytesep/models/subband_tools/fDomainHelper.py +255 -0
  41. bytesep/models/subband_tools/filters/f_4_64.mat +0 -0
  42. bytesep/models/subband_tools/filters/h_4_64.mat +0 -0
  43. bytesep/models/subband_tools/pqmf.py +136 -0
  44. bytesep/models/unet.py +532 -0
  45. bytesep/models/unet_subbandtime.py +389 -0
  46. bytesep/optimizers/__init__.py +0 -0
  47. bytesep/optimizers/lr_schedulers.py +20 -0
  48. bytesep/plot_results/__init__.py +0 -0
  49. bytesep/plot_results/musdb18.py +198 -0
  50. bytesep/plot_results/plot_vctk-musdb18.py +87 -0
.gitattributes ADDED
@@ -0,0 +1,29 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bin.* filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zstandard filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
*.wav filter=lfs diff=lfs merge=lfs -text
example.wav filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,13 @@
Copyright 2021 ByteDance

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
README.md ADDED
@@ -0,0 +1,33 @@
---
title: Music_Source_Separation
emoji: ⚡
colorFrom: green
colorTo: yellow
sdk: gradio
app_file: app.py
pinned: false
---

# Configuration

`title`: _string_
Display title for the Space

`emoji`: _string_
Space emoji (emoji-only character allowed)

`colorFrom`: _string_
Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)

`colorTo`: _string_
Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)

`sdk`: _string_
Can be either `gradio` or `streamlit`

`app_file`: _string_
Path to your main application file (which contains either `gradio` or `streamlit` Python code).
Path is relative to the root of the repository.

`pinned`: _boolean_
Whether the Space stays on top of your list.
app.py ADDED
@@ -0,0 +1,40 @@
import os
os.system('pip install gradio==2.3.0a0')
os.system('pip freeze')
import sys
sys.path.append('.')
import gradio as gr
os.system('pip install -U torchtext==0.8.0')
#os.system('python setup.py install --install-dir .')
from scipy.io import wavfile

os.system('./separate_scripts/download_checkpoints.sh')

def inference(audio):
    # read the file and get the sample rate and data
    rate, data = wavfile.read(audio.name)

    # save the result
    wavfile.write('foo_left.wav', rate, data)
    os.system("""python bytesep/inference.py --config_yaml=./scripts/4_train/musdb18/configs/vocals-accompaniment,resunet_subbandtime.yaml --checkpoint_path=./downloaded_checkpoints/resunet143_subbtandtime_vocals_8.8dB_350k_steps.pth --audio_path=foo_left.wav --output_path=sep_vocals.mp3""")
    #os.system('./separate_scripts/separate_vocals.sh ' + audio.name + ' "sep_vocals.mp3"')
    os.system("""python bytesep/inference.py --config_yaml=./scripts/4_train/musdb18/configs/accompaniment-vocals,resunet_subbandtime.yaml --checkpoint_path=./downloaded_checkpoints/resunet143_subbtandtime_accompaniment_16.4dB_350k_steps.pth --audio_path=foo_left.wav --output_path=sep_accompaniment.mp3""")
    #os.system('./separate_scripts/separate_accompaniment.sh ' + audio.name + ' "sep_accompaniment.mp3"')
    #os.system('python separate_scripts/separate.py --audio_path=' +audio.name+' --source_type="accompaniment"')
    #os.system('python separate_scripts/separate.py --audio_path=' +audio.name+' --source_type="vocals"')
    return 'sep_vocals.mp3', 'sep_accompaniment.mp3'

title = "Music Source Separation"
description = "Gradio demo for Music Source Separation. To use it, simply add your audio, or click one of the examples to load them. Currently supports .wav files. Read more at the links below."
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2109.05418'>Decoupling Magnitude and Phase Estimation with Deep ResUNet for Music Source Separation</a> | <a href='https://github.com/bytedance/music_source_separation'>Github Repo</a></p>"

examples = [['example.wav']]
gr.Interface(
    inference,
    gr.inputs.Audio(type="file", label="Input"),
    [gr.outputs.Audio(type="file", label="Vocals"), gr.outputs.Audio(type="file", label="Accompaniment")],
    title=title,
    description=description,
    article=article,
    enable_queue=True,
    examples=examples
).launch(debug=True)
bytesep/__init__.py ADDED
@@ -0,0 +1 @@
from bytesep.inference import Separator
bytesep/callbacks/__init__.py ADDED
@@ -0,0 +1,76 @@
from typing import List

import pytorch_lightning as pl
import torch.nn as nn


def get_callbacks(
    task_name: str,
    config_yaml: str,
    workspace: str,
    checkpoints_dir: str,
    statistics_path: str,
    logger: pl.loggers.TensorBoardLogger,
    model: nn.Module,
    evaluate_device: str,
) -> List[pl.Callback]:
    r"""Get callbacks of a task and config yaml file.

    Args:
        task_name: str
        config_yaml: str
        workspace: str, containing useful files such as audios for evaluation
        checkpoints_dir: str, directory to save checkpoints
        statistics_path: str, path to save statistics
        logger: pl.loggers.TensorBoardLogger
        model: nn.Module
        evaluate_device: str

    Return:
        callbacks: List[pl.Callback]
    """
    if task_name == 'musdb18':

        from bytesep.callbacks.musdb18 import get_musdb18_callbacks

        return get_musdb18_callbacks(
            config_yaml=config_yaml,
            workspace=workspace,
            checkpoints_dir=checkpoints_dir,
            statistics_path=statistics_path,
            logger=logger,
            model=model,
            evaluate_device=evaluate_device,
        )

    elif task_name == 'voicebank-demand':

        from bytesep.callbacks.voicebank_demand import get_voicebank_demand_callbacks

        return get_voicebank_demand_callbacks(
            config_yaml=config_yaml,
            workspace=workspace,
            checkpoints_dir=checkpoints_dir,
            statistics_path=statistics_path,
            logger=logger,
            model=model,
            evaluate_device=evaluate_device,
        )

    elif task_name in ['vctk-musdb18', 'violin-piano', 'piano-symphony']:

        from bytesep.callbacks.instruments_callbacks import get_instruments_callbacks

        return get_instruments_callbacks(
            config_yaml=config_yaml,
            workspace=workspace,
            checkpoints_dir=checkpoints_dir,
            statistics_path=statistics_path,
            logger=logger,
            model=model,
            evaluate_device=evaluate_device,
        )

    else:
        raise NotImplementedError
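For context, a short sketch of the dispatch behavior above: each task's callback module is imported lazily inside its branch, so heavy task-specific dependencies (musdb, museval, pysepm) are only loaded when that task is selected, and an unknown task fails immediately. All argument values in this sketch are placeholders, not values from the repository's configs.

from bytesep.callbacks import get_callbacks

# An unsupported task name hits the final else branch before any argument is used.
try:
    get_callbacks(
        task_name='unknown-task',          # hypothetical name, not a real task
        config_yaml='config.yaml',         # placeholder arguments below; none are touched
        workspace='./workspace',
        checkpoints_dir='./checkpoints',
        statistics_path='./statistics.pkl',
        logger=None,
        model=None,
        evaluate_device='cpu',
    )
except NotImplementedError:
    print("task not supported")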
bytesep/callbacks/base_callbacks.py ADDED
@@ -0,0 +1,44 @@
import logging
import os
from typing import NoReturn

import pytorch_lightning as pl
import torch
import torch.nn as nn
from pytorch_lightning.utilities import rank_zero_only


class SaveCheckpointsCallback(pl.Callback):
    def __init__(
        self,
        model: nn.Module,
        checkpoints_dir: str,
        save_step_frequency: int,
    ):
        r"""Callback to save checkpoints every #save_step_frequency steps.

        Args:
            model: nn.Module
            checkpoints_dir: str, directory to save checkpoints
            save_step_frequency: int
        """
        self.model = model
        self.checkpoints_dir = checkpoints_dir
        self.save_step_frequency = save_step_frequency
        os.makedirs(self.checkpoints_dir, exist_ok=True)

    @rank_zero_only
    def on_batch_end(self, trainer: pl.Trainer, _) -> NoReturn:
        r"""Save checkpoint."""
        global_step = trainer.global_step

        if global_step % self.save_step_frequency == 0:

            checkpoint_path = os.path.join(
                self.checkpoints_dir, "step={}.pth".format(global_step)
            )

            checkpoint = {'step': global_step, 'model': self.model.state_dict()}

            torch.save(checkpoint, checkpoint_path)
            logging.info("Save checkpoint to {}".format(checkpoint_path))
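A minimal sketch of wiring SaveCheckpointsCallback into a PyTorch Lightning trainer, assuming a placeholder model, directory, and save frequency (none of these values come from the repository's configs):

import pytorch_lightning as pl
import torch.nn as nn

from bytesep.callbacks.base_callbacks import SaveCheckpointsCallback

model = nn.Linear(10, 10)  # placeholder model; the repo trains UNet/ResUNet variants

save_callback = SaveCheckpointsCallback(
    model=model,
    checkpoints_dir="./workspace/checkpoints/demo",  # placeholder path
    save_step_frequency=10000,                       # placeholder; configs set this per task
)

trainer = pl.Trainer(callbacks=[save_callback], max_steps=100000)
# trainer.fit(lightning_module, datamodule=data_module)  # module/datamodule come from the repo's train script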
bytesep/callbacks/instruments_callbacks.py ADDED
@@ -0,0 +1,200 @@
import logging
import os
import time
from typing import List, NoReturn

import librosa
import numpy as np
import pytorch_lightning as pl
import torch.nn as nn
from pytorch_lightning.utilities import rank_zero_only

from bytesep.callbacks.base_callbacks import SaveCheckpointsCallback
from bytesep.inference import Separator
from bytesep.utils import StatisticsContainer, calculate_sdr, read_yaml


def get_instruments_callbacks(
    config_yaml: str,
    workspace: str,
    checkpoints_dir: str,
    statistics_path: str,
    logger: pl.loggers.TensorBoardLogger,
    model: nn.Module,
    evaluate_device: str,
) -> List[pl.Callback]:
    """Get instrument separation callbacks of a config yaml.

    Args:
        config_yaml: str
        workspace: str
        checkpoints_dir: str, directory to save checkpoints
        statistics_path: str, path to save statistics
        logger: pl.loggers.TensorBoardLogger
        model: nn.Module
        evaluate_device: str

    Return:
        callbacks: List[pl.Callback]
    """
    configs = read_yaml(config_yaml)
    task_name = configs['task_name']
    target_source_types = configs['train']['target_source_types']
    input_channels = configs['train']['channels']
    mono = True if input_channels == 1 else False
    test_audios_dir = os.path.join(workspace, "evaluation_audios", task_name, "test")
    sample_rate = configs['train']['sample_rate']
    evaluate_step_frequency = configs['train']['evaluate_step_frequency']
    save_step_frequency = configs['train']['save_step_frequency']
    test_batch_size = configs['evaluate']['batch_size']
    test_segment_seconds = configs['evaluate']['segment_seconds']

    test_segment_samples = int(test_segment_seconds * sample_rate)
    assert len(target_source_types) == 1
    target_source_type = target_source_types[0]

    # save checkpoint callback
    save_checkpoints_callback = SaveCheckpointsCallback(
        model=model,
        checkpoints_dir=checkpoints_dir,
        save_step_frequency=save_step_frequency,
    )

    # statistics container
    statistics_container = StatisticsContainer(statistics_path)

    # evaluation callback
    evaluate_test_callback = EvaluationCallback(
        model=model,
        target_source_type=target_source_type,
        input_channels=input_channels,
        sample_rate=sample_rate,
        mono=mono,
        evaluation_audios_dir=test_audios_dir,
        segment_samples=test_segment_samples,
        batch_size=test_batch_size,
        device=evaluate_device,
        evaluate_step_frequency=evaluate_step_frequency,
        logger=logger,
        statistics_container=statistics_container,
    )

    callbacks = [save_checkpoints_callback, evaluate_test_callback]
    # callbacks = [save_checkpoints_callback]

    return callbacks


class EvaluationCallback(pl.Callback):
    def __init__(
        self,
        model: nn.Module,
        input_channels: int,
        evaluation_audios_dir: str,
        target_source_type: str,
        sample_rate: int,
        mono: bool,
        segment_samples: int,
        batch_size: int,
        device: str,
        evaluate_step_frequency: int,
        logger: pl.loggers.TensorBoardLogger,
        statistics_container: StatisticsContainer,
    ):
        r"""Callback to evaluate every #save_step_frequency steps.

        Args:
            model: nn.Module
            input_channels: int
            evaluation_audios_dir: str, directory containing audios for evaluation
            target_source_type: str, e.g., 'violin'
            sample_rate: int
            mono: bool
            segment_samples: int, length of segments to be input to a model, e.g., 44100*30
            batch_size: int, e.g., 12
            device: str, e.g., 'cuda'
            evaluate_step_frequency: int, evaluate every #save_step_frequency steps
            logger: pl.loggers.TensorBoardLogger
            statistics_container: StatisticsContainer
        """
        self.model = model
        self.target_source_type = target_source_type
        self.sample_rate = sample_rate
        self.mono = mono
        self.segment_samples = segment_samples
        self.evaluate_step_frequency = evaluate_step_frequency
        self.logger = logger
        self.statistics_container = statistics_container

        self.evaluation_audios_dir = evaluation_audios_dir

        # separator
        self.separator = Separator(model, self.segment_samples, batch_size, device)

    @rank_zero_only
    def on_batch_end(self, trainer: pl.Trainer, _) -> NoReturn:
        r"""Evaluate separation SDRs of audio recordings."""

        global_step = trainer.global_step

        if global_step % self.evaluate_step_frequency == 0:

            mixture_audios_dir = os.path.join(self.evaluation_audios_dir, 'mixture')
            clean_audios_dir = os.path.join(
                self.evaluation_audios_dir, self.target_source_type
            )

            audio_names = sorted(os.listdir(mixture_audios_dir))

            error_str = "Directory {} does not contain audios for evaluation!".format(
                self.evaluation_audios_dir
            )
            assert len(audio_names) > 0, error_str

            logging.info("--- Step {} ---".format(global_step))
            logging.info("Total {} pieces for evaluation:".format(len(audio_names)))

            eval_time = time.time()

            sdrs = []

            for n, audio_name in enumerate(audio_names):

                # Load audio.
                mixture_path = os.path.join(mixture_audios_dir, audio_name)
                clean_path = os.path.join(clean_audios_dir, audio_name)

                mixture, origin_fs = librosa.core.load(
                    mixture_path, sr=self.sample_rate, mono=self.mono
                )

                # Target
                clean, origin_fs = librosa.core.load(
                    clean_path, sr=self.sample_rate, mono=self.mono
                )

                if mixture.ndim == 1:
                    mixture = mixture[None, :]
                    # (channels_num, audio_length)

                input_dict = {'waveform': mixture}

                # separate
                sep_wav = self.separator.separate(input_dict)
                # (channels_num, audio_length)

                sdr = calculate_sdr(ref=clean, est=sep_wav)

                print("{} SDR: {:.3f}".format(audio_name, sdr))
                sdrs.append(sdr)

            logging.info("-----------------------------")
            logging.info('Avg SDR: {:.3f}'.format(np.mean(sdrs)))

            logging.info("Evaluation time: {:.3f}".format(time.time() - eval_time))

            statistics = {"sdr": np.mean(sdrs)}
            self.statistics_container.append(global_step, statistics, 'test')
            self.statistics_container.dump()
bytesep/callbacks/musdb18.py ADDED
@@ -0,0 +1,485 @@
import logging
import os
import time
from typing import Dict, List, NoReturn

import librosa
import musdb
import museval
import numpy as np
import pytorch_lightning as pl
import torch.nn as nn
from pytorch_lightning.utilities import rank_zero_only

from bytesep.callbacks.base_callbacks import SaveCheckpointsCallback
from bytesep.dataset_creation.pack_audios_to_hdf5s.musdb18 import preprocess_audio
from bytesep.inference import Separator
from bytesep.utils import StatisticsContainer, read_yaml


def get_musdb18_callbacks(
    config_yaml: str,
    workspace: str,
    checkpoints_dir: str,
    statistics_path: str,
    logger: pl.loggers.TensorBoardLogger,
    model: nn.Module,
    evaluate_device: str,
) -> List[pl.Callback]:
    r"""Get MUSDB18 callbacks of a config yaml.

    Args:
        config_yaml: str
        workspace: str
        checkpoints_dir: str, directory to save checkpoints
        statistics_path: str, path to save statistics
        logger: pl.loggers.TensorBoardLogger
        model: nn.Module
        evaluate_device: str

    Return:
        callbacks: List[pl.Callback]
    """
    configs = read_yaml(config_yaml)
    task_name = configs['task_name']
    evaluation_callback = configs['train']['evaluation_callback']
    target_source_types = configs['train']['target_source_types']
    input_channels = configs['train']['channels']
    evaluation_audios_dir = os.path.join(workspace, "evaluation_audios", task_name)
    test_segment_seconds = configs['evaluate']['segment_seconds']
    sample_rate = configs['train']['sample_rate']
    test_segment_samples = int(test_segment_seconds * sample_rate)
    test_batch_size = configs['evaluate']['batch_size']

    evaluate_step_frequency = configs['train']['evaluate_step_frequency']
    save_step_frequency = configs['train']['save_step_frequency']

    # save checkpoint callback
    save_checkpoints_callback = SaveCheckpointsCallback(
        model=model,
        checkpoints_dir=checkpoints_dir,
        save_step_frequency=save_step_frequency,
    )

    # evaluation callback
    EvaluationCallback = _get_evaluation_callback_class(evaluation_callback)

    # statistics container
    statistics_container = StatisticsContainer(statistics_path)

    # evaluation callback
    evaluate_train_callback = EvaluationCallback(
        dataset_dir=evaluation_audios_dir,
        model=model,
        target_source_types=target_source_types,
        input_channels=input_channels,
        sample_rate=sample_rate,
        split='train',
        segment_samples=test_segment_samples,
        batch_size=test_batch_size,
        device=evaluate_device,
        evaluate_step_frequency=evaluate_step_frequency,
        logger=logger,
        statistics_container=statistics_container,
    )

    evaluate_test_callback = EvaluationCallback(
        dataset_dir=evaluation_audios_dir,
        model=model,
        target_source_types=target_source_types,
        input_channels=input_channels,
        sample_rate=sample_rate,
        split='test',
        segment_samples=test_segment_samples,
        batch_size=test_batch_size,
        device=evaluate_device,
        evaluate_step_frequency=evaluate_step_frequency,
        logger=logger,
        statistics_container=statistics_container,
    )

    # callbacks = [save_checkpoints_callback, evaluate_train_callback, evaluate_test_callback]
    callbacks = [save_checkpoints_callback, evaluate_test_callback]

    return callbacks


def _get_evaluation_callback_class(evaluation_callback) -> pl.Callback:
    r"""Get evaluation callback class."""
    if evaluation_callback == "Musdb18EvaluationCallback":
        return Musdb18EvaluationCallback

    if evaluation_callback == 'Musdb18ConditionalEvaluationCallback':
        return Musdb18ConditionalEvaluationCallback

    else:
        raise NotImplementedError


class Musdb18EvaluationCallback(pl.Callback):
    def __init__(
        self,
        dataset_dir: str,
        model: nn.Module,
        target_source_types: str,
        input_channels: int,
        split: str,
        sample_rate: int,
        segment_samples: int,
        batch_size: int,
        device: str,
        evaluate_step_frequency: int,
        logger: pl.loggers.TensorBoardLogger,
        statistics_container: StatisticsContainer,
    ):
        r"""Callback to evaluate every #save_step_frequency steps.

        Args:
            dataset_dir: str
            model: nn.Module
            target_source_types: List[str], e.g., ['vocals', 'bass', ...]
            input_channels: int
            split: 'train' | 'test'
            sample_rate: int
            segment_samples: int, length of segments to be input to a model, e.g., 44100*30
            batch_size: int, e.g., 12
            device: str, e.g., 'cuda'
            evaluate_step_frequency: int, evaluate every #save_step_frequency steps
            logger: object
            statistics_container: StatisticsContainer
        """
        self.model = model
        self.target_source_types = target_source_types
        self.input_channels = input_channels
        self.sample_rate = sample_rate
        self.split = split
        self.segment_samples = segment_samples
        self.evaluate_step_frequency = evaluate_step_frequency
        self.logger = logger
        self.statistics_container = statistics_container
        self.mono = input_channels == 1
        self.resample_type = "kaiser_fast"

        self.mus = musdb.DB(root=dataset_dir, subsets=[split])

        error_msg = "The directory {} is empty!".format(dataset_dir)
        assert len(self.mus) > 0, error_msg

        # separator
        self.separator = Separator(model, self.segment_samples, batch_size, device)

    @rank_zero_only
    def on_batch_end(self, trainer: pl.Trainer, _) -> NoReturn:
        r"""Evaluate separation SDRs of audio recordings."""
        global_step = trainer.global_step

        if global_step % self.evaluate_step_frequency == 0:

            sdr_dict = {}

            logging.info("--- Step {} ---".format(global_step))
            logging.info("Total {} pieces for evaluation:".format(len(self.mus.tracks)))

            eval_time = time.time()

            for track in self.mus.tracks:

                audio_name = track.name

                # Get waveform of mixture.
                mixture = track.audio.T
                # (channels_num, audio_samples)

                mixture = preprocess_audio(
                    audio=mixture,
                    mono=self.mono,
                    origin_sr=track.rate,
                    sr=self.sample_rate,
                    resample_type=self.resample_type,
                )
                # (channels_num, audio_samples)

                target_dict = {}
                sdr_dict[audio_name] = {}

                # Get waveform of all target source types.
                for j, source_type in enumerate(self.target_source_types):
                    # E.g., ['vocals', 'bass', ...]

                    audio = track.targets[source_type].audio.T

                    audio = preprocess_audio(
                        audio=audio,
                        mono=self.mono,
                        origin_sr=track.rate,
                        sr=self.sample_rate,
                        resample_type=self.resample_type,
                    )
                    # (channels_num, audio_samples)

                    target_dict[source_type] = audio
                    # (channels_num, audio_samples)

                # Separate.
                input_dict = {'waveform': mixture}

                sep_wavs = self.separator.separate(input_dict)
                # sep_wavs: (target_sources_num * channels_num, audio_samples)

                # Post process separation results.
                sep_wavs = preprocess_audio(
                    audio=sep_wavs,
                    mono=self.mono,
                    origin_sr=self.sample_rate,
                    sr=track.rate,
                    resample_type=self.resample_type,
                )
                # sep_wavs: (target_sources_num * channels_num, audio_samples)

                sep_wavs = librosa.util.fix_length(
                    sep_wavs, size=mixture.shape[1], axis=1
                )
                # sep_wavs: (target_sources_num * channels_num, audio_samples)

                sep_wav_dict = get_separated_wavs_from_simo_output(
                    sep_wavs, self.input_channels, self.target_source_types
                )
                # sep_wav_dict: dict, e.g., {
                #     'vocals': (channels_num, audio_samples),
                #     'bass': (channels_num, audio_samples),
                #     ...,
                # }

                # Evaluate for all target source types.
                for source_type in self.target_source_types:
                    # E.g., ['vocals', 'bass', ...]

                    # Calculate SDR using museval, input shape should be: (nsrc, nsampl, nchan).
                    (sdrs, _, _, _) = museval.evaluate(
                        [target_dict[source_type].T], [sep_wav_dict[source_type].T]
                    )

                    sdr = np.nanmedian(sdrs)
                    sdr_dict[audio_name][source_type] = sdr

                    logging.info(
                        "{}, {}, sdr: {:.3f}".format(audio_name, source_type, sdr)
                    )

            logging.info("-----------------------------")
            median_sdr_dict = {}

            # Calculate median SDRs of all songs.
            for source_type in self.target_source_types:
                # E.g., ['vocals', 'bass', ...]

                median_sdr = np.median(
                    [
                        sdr_dict[audio_name][source_type]
                        for audio_name in sdr_dict.keys()
                    ]
                )

                median_sdr_dict[source_type] = median_sdr

                logging.info(
                    "Step: {}, {}, Median SDR: {:.3f}".format(
                        global_step, source_type, median_sdr
                    )
                )

            logging.info("Evaluation time: {:.3f}".format(time.time() - eval_time))

            statistics = {"sdr_dict": sdr_dict, "median_sdr_dict": median_sdr_dict}
            self.statistics_container.append(global_step, statistics, self.split)
            self.statistics_container.dump()


def get_separated_wavs_from_simo_output(x, input_channels, target_source_types) -> Dict:
    r"""Get separated waveforms of target sources from a single input multiple
    output (SIMO) system.

    Args:
        x: (target_sources_num * channels_num, audio_samples)
        input_channels: int
        target_source_types: List[str], e.g., ['vocals', 'bass', ...]

    Returns:
        output_dict: dict, e.g., {
            'vocals': (channels_num, audio_samples),
            'bass': (channels_num, audio_samples),
            ...,
        }
    """
    output_dict = {}

    for j, source_type in enumerate(target_source_types):
        output_dict[source_type] = x[j * input_channels : (j + 1) * input_channels]

    return output_dict


class Musdb18ConditionalEvaluationCallback(pl.Callback):
    def __init__(
        self,
        dataset_dir: str,
        model: nn.Module,
        target_source_types: str,
        input_channels: int,
        split: str,
        sample_rate: int,
        segment_samples: int,
        batch_size: int,
        device: str,
        evaluate_step_frequency: int,
        logger: pl.loggers.TensorBoardLogger,
        statistics_container: StatisticsContainer,
    ):
        r"""Callback to evaluate every #save_step_frequency steps.

        Args:
            dataset_dir: str
            model: nn.Module
            target_source_types: List[str], e.g., ['vocals', 'bass', ...]
            input_channels: int
            split: 'train' | 'test'
            sample_rate: int
            segment_samples: int, length of segments to be input to a model, e.g., 44100*30
            batch_size: int, e.g., 12
            device: str, e.g., 'cuda'
            evaluate_step_frequency: int, evaluate every #save_step_frequency steps
            logger: object
            statistics_container: StatisticsContainer
        """
        self.model = model
        self.target_source_types = target_source_types
        self.input_channels = input_channels
        self.sample_rate = sample_rate
        self.split = split
        self.segment_samples = segment_samples
        self.evaluate_step_frequency = evaluate_step_frequency
        self.logger = logger
        self.statistics_container = statistics_container
        self.mono = input_channels == 1
        self.resample_type = "kaiser_fast"

        self.mus = musdb.DB(root=dataset_dir, subsets=[split])

        error_msg = "The directory {} is empty!".format(dataset_dir)
        assert len(self.mus) > 0, error_msg

        # separator
        self.separator = Separator(model, self.segment_samples, batch_size, device)

    @rank_zero_only
    def on_batch_end(self, trainer: pl.Trainer, _) -> NoReturn:
        r"""Evaluate separation SDRs of audio recordings."""
        global_step = trainer.global_step

        if global_step % self.evaluate_step_frequency == 0:

            sdr_dict = {}

            logging.info("--- Step {} ---".format(global_step))
            logging.info("Total {} pieces for evaluation:".format(len(self.mus.tracks)))

            eval_time = time.time()

            for track in self.mus.tracks:

                audio_name = track.name

                # Get waveform of mixture.
                mixture = track.audio.T
                # (channels_num, audio_samples)

                mixture = preprocess_audio(
                    audio=mixture,
                    mono=self.mono,
                    origin_sr=track.rate,
                    sr=self.sample_rate,
                    resample_type=self.resample_type,
                )
                # (channels_num, audio_samples)

                target_dict = {}
                sdr_dict[audio_name] = {}

                # Get waveform of all target source types.
                for j, source_type in enumerate(self.target_source_types):
                    # E.g., ['vocals', 'bass', ...]

                    audio = track.targets[source_type].audio.T

                    audio = preprocess_audio(
                        audio=audio,
                        mono=self.mono,
                        origin_sr=track.rate,
                        sr=self.sample_rate,
                        resample_type=self.resample_type,
                    )
                    # (channels_num, audio_samples)

                    target_dict[source_type] = audio
                    # (channels_num, audio_samples)

                    condition = np.zeros(len(self.target_source_types))
                    condition[j] = 1

                    input_dict = {'waveform': mixture, 'condition': condition}

                    sep_wav = self.separator.separate(input_dict)
                    # sep_wav: (channels_num, audio_samples)

                    sep_wav = preprocess_audio(
                        audio=sep_wav,
                        mono=self.mono,
                        origin_sr=self.sample_rate,
                        sr=track.rate,
                        resample_type=self.resample_type,
                    )
                    # sep_wav: (channels_num, audio_samples)

                    sep_wav = librosa.util.fix_length(
                        sep_wav, size=mixture.shape[1], axis=1
                    )
                    # sep_wav: (channels_num, audio_samples)

                    # Calculate SDR using museval, input shape should be: (nsrc, nsampl, nchan)
                    (sdrs, _, _, _) = museval.evaluate(
                        [target_dict[source_type].T], [sep_wav.T]
                    )

                    sdr = np.nanmedian(sdrs)
                    sdr_dict[audio_name][source_type] = sdr

                    logging.info(
                        "{}, {}, sdr: {:.3f}".format(audio_name, source_type, sdr)
                    )

            logging.info("-----------------------------")
            median_sdr_dict = {}

            # Calculate median SDRs of all songs.
            for source_type in self.target_source_types:

                median_sdr = np.median(
                    [
                        sdr_dict[audio_name][source_type]
                        for audio_name in sdr_dict.keys()
                    ]
                )

                median_sdr_dict[source_type] = median_sdr

                logging.info(
                    "Step: {}, {}, Median SDR: {:.3f}".format(
                        global_step, source_type, median_sdr
                    )
                )

            logging.info("Evaluation time: {:.3f}".format(time.time() - eval_time))

            statistics = {"sdr_dict": sdr_dict, "median_sdr_dict": median_sdr_dict}
            self.statistics_container.append(global_step, statistics, self.split)
            self.statistics_container.dump()
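As a quick illustration of how get_separated_wavs_from_simo_output splits the stacked SIMO output back into per-source waveforms, a small sketch with made-up shapes (4 sources, stereo, 10 samples) might look like this:

import numpy as np

x = np.random.randn(4 * 2, 10)  # (target_sources_num * channels_num, audio_samples)
sep_wav_dict = get_separated_wavs_from_simo_output(
    x, input_channels=2, target_source_types=['vocals', 'bass', 'drums', 'other']
)
print(sep_wav_dict['vocals'].shape)  # (2, 10): rows 0-1 of x
print(sep_wav_dict['other'].shape)   # (2, 10): rows 6-7 of x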
bytesep/callbacks/voicebank_demand.py ADDED
@@ -0,0 +1,231 @@
import logging
import os
import time
from typing import List, NoReturn

import librosa
import numpy as np
import pysepm
import pytorch_lightning as pl
import torch.nn as nn
from pesq import pesq
from pytorch_lightning.utilities import rank_zero_only

from bytesep.callbacks.base_callbacks import SaveCheckpointsCallback
from bytesep.inference import Separator
from bytesep.utils import StatisticsContainer, read_yaml


def get_voicebank_demand_callbacks(
    config_yaml: str,
    workspace: str,
    checkpoints_dir: str,
    statistics_path: str,
    logger: pl.loggers.TensorBoardLogger,
    model: nn.Module,
    evaluate_device: str,
) -> List[pl.Callback]:
    """Get Voicebank-Demand callbacks of a config yaml.

    Args:
        config_yaml: str
        workspace: str
        checkpoints_dir: str, directory to save checkpoints
        statistics_path: str, path to save statistics
        logger: pl.loggers.TensorBoardLogger
        model: nn.Module
        evaluate_device: str

    Return:
        callbacks: List[pl.Callback]
    """
    configs = read_yaml(config_yaml)
    task_name = configs['task_name']
    target_source_types = configs['train']['target_source_types']
    input_channels = configs['train']['channels']
    evaluation_audios_dir = os.path.join(workspace, "evaluation_audios", task_name)
    sample_rate = configs['train']['sample_rate']
    evaluate_step_frequency = configs['train']['evaluate_step_frequency']
    save_step_frequency = configs['train']['save_step_frequency']
    test_batch_size = configs['evaluate']['batch_size']
    test_segment_seconds = configs['evaluate']['segment_seconds']

    test_segment_samples = int(test_segment_seconds * sample_rate)
    assert len(target_source_types) == 1
    target_source_type = target_source_types[0]
    assert target_source_type == 'speech'

    # save checkpoint callback
    save_checkpoints_callback = SaveCheckpointsCallback(
        model=model,
        checkpoints_dir=checkpoints_dir,
        save_step_frequency=save_step_frequency,
    )

    # statistics container
    statistics_container = StatisticsContainer(statistics_path)

    # evaluation callback
    evaluate_test_callback = EvaluationCallback(
        model=model,
        input_channels=input_channels,
        sample_rate=sample_rate,
        evaluation_audios_dir=evaluation_audios_dir,
        segment_samples=test_segment_samples,
        batch_size=test_batch_size,
        device=evaluate_device,
        evaluate_step_frequency=evaluate_step_frequency,
        logger=logger,
        statistics_container=statistics_container,
    )

    callbacks = [save_checkpoints_callback, evaluate_test_callback]

    return callbacks


class EvaluationCallback(pl.Callback):
    def __init__(
        self,
        model: nn.Module,
        input_channels: int,
        evaluation_audios_dir,
        sample_rate: int,
        segment_samples: int,
        batch_size: int,
        device: str,
        evaluate_step_frequency: int,
        logger: pl.loggers.TensorBoardLogger,
        statistics_container: StatisticsContainer,
    ):
        r"""Callback to evaluate every #save_step_frequency steps.

        Args:
            model: nn.Module
            input_channels: int
            evaluation_audios_dir: str, directory containing audios for evaluation
            sample_rate: int
            segment_samples: int, length of segments to be input to a model, e.g., 44100*30
            batch_size: int, e.g., 12
            device: str, e.g., 'cuda'
            evaluate_step_frequency: int, evaluate every #save_step_frequency steps
            logger: pl.loggers.TensorBoardLogger
            statistics_container: StatisticsContainer
        """
        self.model = model
        self.mono = True
        self.sample_rate = sample_rate
        self.segment_samples = segment_samples
        self.evaluate_step_frequency = evaluate_step_frequency
        self.logger = logger
        self.statistics_container = statistics_container

        self.clean_dir = os.path.join(evaluation_audios_dir, "clean_testset_wav")
        self.noisy_dir = os.path.join(evaluation_audios_dir, "noisy_testset_wav")

        self.EVALUATION_SAMPLE_RATE = 16000  # Evaluation sample rate of the
        # Voicebank-Demand task.

        # separator
        self.separator = Separator(model, self.segment_samples, batch_size, device)

    @rank_zero_only
    def on_batch_end(self, trainer: pl.Trainer, _) -> NoReturn:
        r"""Evaluate PESQ, CSIG, CBAK, COVL, and SSNR on the evaluation set."""

        global_step = trainer.global_step

        if global_step % self.evaluate_step_frequency == 0:

            audio_names = sorted(
                [
                    audio_name
                    for audio_name in sorted(os.listdir(self.clean_dir))
                    if audio_name.endswith('.wav')
                ]
            )

            error_str = "Directory {} does not contain audios for evaluation!".format(
                self.clean_dir
            )
            assert len(audio_names) > 0, error_str

            pesqs, csigs, cbaks, covls, ssnrs = [], [], [], [], []

            logging.info("--- Step {} ---".format(global_step))
            logging.info("Total {} pieces for evaluation:".format(len(audio_names)))

            eval_time = time.time()

            for n, audio_name in enumerate(audio_names):

                # Load audio.
                clean_path = os.path.join(self.clean_dir, audio_name)
                mixture_path = os.path.join(self.noisy_dir, audio_name)

                mixture, _ = librosa.core.load(
                    mixture_path, sr=self.sample_rate, mono=self.mono
                )

                if mixture.ndim == 1:
                    mixture = mixture[None, :]
                    # (channels_num, audio_length)

                # Separate.
                input_dict = {'waveform': mixture}

                sep_wav = self.separator.separate(input_dict)
                # (channels_num, audio_length)

                # Target
                clean, _ = librosa.core.load(
                    clean_path, sr=self.EVALUATION_SAMPLE_RATE, mono=self.mono
                )

                # to mono
                sep_wav = np.squeeze(sep_wav)

                # Resample for evaluation.
                sep_wav = librosa.resample(
                    sep_wav,
                    orig_sr=self.sample_rate,
                    target_sr=self.EVALUATION_SAMPLE_RATE,
                )

                sep_wav = librosa.util.fix_length(sep_wav, size=len(clean), axis=0)
                # (audio_length,)

                # Evaluate metrics
                pesq_ = pesq(self.EVALUATION_SAMPLE_RATE, clean, sep_wav, 'wb')

                (csig, cbak, covl) = pysepm.composite(
                    clean, sep_wav, self.EVALUATION_SAMPLE_RATE
                )

                ssnr = pysepm.SNRseg(clean, sep_wav, self.EVALUATION_SAMPLE_RATE)

                pesqs.append(pesq_)
                csigs.append(csig)
                cbaks.append(cbak)
                covls.append(covl)
                ssnrs.append(ssnr)
                print(
                    '{}, {}, PESQ: {:.3f}, CSIG: {:.3f}, CBAK: {:.3f}, COVL: {:.3f}, SSNR: {:.3f}'.format(
                        n, audio_name, pesq_, csig, cbak, covl, ssnr
                    )
                )

            logging.info("-----------------------------")
            logging.info('Avg PESQ: {:.3f}'.format(np.mean(pesqs)))
            logging.info('Avg CSIG: {:.3f}'.format(np.mean(csigs)))
            logging.info('Avg CBAK: {:.3f}'.format(np.mean(cbaks)))
            logging.info('Avg COVL: {:.3f}'.format(np.mean(covls)))
            logging.info('Avg SSNR: {:.3f}'.format(np.mean(ssnrs)))

            logging.info("Evaluation time: {:.3f}".format(time.time() - eval_time))

            statistics = {"pesq": np.mean(pesqs)}
            self.statistics_container.append(global_step, statistics, 'test')
            self.statistics_container.dump()
bytesep/data/__init__.py ADDED
File without changes
bytesep/data/augmentors.py ADDED
@@ -0,0 +1,157 @@
from typing import Dict

import librosa
import numpy as np

from bytesep.utils import db_to_magnitude, get_pitch_shift_factor, magnitude_to_db


class Augmentor:
    def __init__(self, augmentations: Dict, random_seed=1234):
        r"""Augmentor for data augmentation of a waveform.

        Args:
            augmentations: Dict, e.g., {
                'mixaudio': {'vocals': 2, 'accompaniment': 2},
                'pitch_shift': {'vocals': 4, 'accompaniment': 4},
                ...,
            }
            random_seed: int
        """
        self.augmentations = augmentations
        self.random_state = np.random.RandomState(random_seed)

    def __call__(self, waveform: np.array, source_type: str) -> np.array:
        r"""Augment a waveform.

        Args:
            waveform: (channels_num, audio_samples)
            source_type: str

        Returns:
            new_waveform: (channels_num, new_audio_samples)
        """
        if 'pitch_shift' in self.augmentations.keys():
            waveform = self.pitch_shift(waveform, source_type)

        if 'magnitude_scale' in self.augmentations.keys():
            waveform = self.magnitude_scale(waveform, source_type)

        if 'swap_channel' in self.augmentations.keys():
            waveform = self.swap_channel(waveform, source_type)

        if 'flip_axis' in self.augmentations.keys():
            waveform = self.flip_axis(waveform, source_type)

        return waveform

    def pitch_shift(self, waveform: np.array, source_type: str) -> np.array:
        r"""Shift the pitch of a waveform. We use resampling for fast pitch
        shifting, so the speed will also be changed. The length of the returned
        waveform will be changed.

        Args:
            waveform: (channels_num, audio_samples)
            source_type: str

        Returns:
            new_waveform: (channels_num, new_audio_samples)
        """

        # maximum pitch shift in semitones
        max_pitch_shift = self.augmentations['pitch_shift'][source_type]

        if max_pitch_shift == 0:  # No pitch shift augmentations.
            return waveform

        # random pitch shift
        rand_pitch = self.random_state.uniform(
            low=-max_pitch_shift, high=max_pitch_shift
        )

        # We use librosa.resample instead of librosa.effects.pitch_shift
        # because it is 10x times faster.
        pitch_shift_factor = get_pitch_shift_factor(rand_pitch)
        dummy_sample_rate = 10000  # Dummy constant.

        channels_num = waveform.shape[0]

        if channels_num == 1:
            waveform = np.squeeze(waveform)

        new_waveform = librosa.resample(
            y=waveform,
            orig_sr=dummy_sample_rate,
            target_sr=dummy_sample_rate / pitch_shift_factor,
            res_type='linear',
            axis=-1,
        )

        if channels_num == 1:
            new_waveform = new_waveform[None, :]

        return new_waveform

    def magnitude_scale(self, waveform: np.array, source_type: str) -> np.array:
        r"""Scale the magnitude of a waveform.

        Args:
            waveform: (channels_num, audio_samples)
            source_type: str

        Returns:
            new_waveform: (channels_num, audio_samples)
        """
        lower_db = self.augmentations['magnitude_scale'][source_type]['lower_db']
        higher_db = self.augmentations['magnitude_scale'][source_type]['higher_db']

        if lower_db == 0 and higher_db == 0:  # No magnitude scale augmentation.
            return waveform

        # The magnitude (in dB) of the sample with the maximum value.
        waveform_db = magnitude_to_db(np.max(np.abs(waveform)))

        new_waveform_db = self.random_state.uniform(
            waveform_db + lower_db, min(waveform_db + higher_db, 0)
        )

        relative_db = new_waveform_db - waveform_db

        relative_scale = db_to_magnitude(relative_db)

        new_waveform = waveform * relative_scale

        return new_waveform

    def swap_channel(self, waveform: np.array, source_type: str) -> np.array:
        r"""Randomly swap channels.

        Args:
            waveform: (channels_num, audio_samples)
            source_type: str

        Returns:
            new_waveform: (channels_num, audio_samples)
        """
        ndim = waveform.shape[0]

        if ndim == 1:
            return waveform
        else:
            random_axes = self.random_state.permutation(ndim)
            return waveform[random_axes, :]

    def flip_axis(self, waveform: np.array, source_type: str) -> np.array:
        r"""Randomly flip the waveform along x-axis.

        Args:
            waveform: (channels_num, audio_samples)
            source_type: str

        Returns:
            new_waveform: (channels_num, audio_samples)
        """
        ndim = waveform.shape[0]
        random_values = self.random_state.choice([-1, 1], size=ndim)

        return waveform * random_values[:, None]
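A minimal usage sketch of the Augmentor above; the augmentation settings and waveform here are invented for illustration, not taken from the repository's config files:

import numpy as np

from bytesep.data.augmentors import Augmentor

# Hypothetical augmentation settings; real values live in the task's config yaml.
augmentations = {
    'pitch_shift': {'vocals': 4, 'accompaniment': 4},
    'magnitude_scale': {
        'vocals': {'lower_db': -20, 'higher_db': 0},
        'accompaniment': {'lower_db': -20, 'higher_db': 0},
    },
}

augmentor = Augmentor(augmentations=augmentations, random_seed=1234)

waveform = 0.1 * np.random.randn(2, 44100).astype(np.float32)  # (channels_num, audio_samples)
augmented = augmentor(waveform, source_type='vocals')
print(augmented.shape)  # channel count preserved; length may change after pitch shift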
bytesep/data/batch_data_preprocessors.py ADDED
@@ -0,0 +1,141 @@
from typing import Dict, List

import torch


class BasicBatchDataPreprocessor:
    def __init__(self, target_source_types: List[str]):
        r"""Batch data preprocessor. Used for preparing mixtures and targets for
        training. If there are multiple target source types, the waveforms of
        those sources will be stacked along the channel dimension.

        Args:
            target_source_types: List[str], e.g., ['vocals', 'bass', ...]
        """
        self.target_source_types = target_source_types

    def __call__(self, batch_data_dict: Dict) -> List[Dict]:
        r"""Format waveforms and targets for training.

        Args:
            batch_data_dict: dict, e.g., {
                'mixture': (batch_size, channels_num, segment_samples),
                'vocals': (batch_size, channels_num, segment_samples),
                'bass': (batch_size, channels_num, segment_samples),
                ...,
            }

        Returns:
            input_dict: dict, e.g., {
                'waveform': (batch_size, channels_num, segment_samples),
            }
            target_dict: dict, e.g., {
                'waveform': (batch_size, target_sources_num * channels_num, segment_samples)
            }
        """
        mixtures = batch_data_dict['mixture']
        # mixtures: (batch_size, channels_num, segment_samples)

        # Concatenate waveforms of multiple targets along the channel axis.
        targets = torch.cat(
            [batch_data_dict[source_type] for source_type in self.target_source_types],
            dim=1,
        )
        # targets: (batch_size, target_sources_num * channels_num, segment_samples)

        input_dict = {'waveform': mixtures}
        target_dict = {'waveform': targets}

        return input_dict, target_dict


class ConditionalSisoBatchDataPreprocessor:
    def __init__(self, target_source_types: List[str]):
        r"""Conditional single input single output (SISO) batch data
        preprocessor. Select one target source from several target sources as
        training target and prepare the corresponding conditional vector.

        Args:
            target_source_types: List[str], e.g., ['vocals', 'bass', ...]
        """
        self.target_source_types = target_source_types

    def __call__(self, batch_data_dict: Dict) -> List[Dict]:
        r"""Format waveforms and targets for training.

        Args:
            batch_data_dict: dict, e.g., {
                'mixture': (batch_size, channels_num, segment_samples),
                'vocals': (batch_size, channels_num, segment_samples),
                'bass': (batch_size, channels_num, segment_samples),
                ...,
            }

        Returns:
            input_dict: dict, e.g., {
                'waveform': (batch_size, channels_num, segment_samples),
                'condition': (batch_size, target_sources_num),
            }
            target_dict: dict, e.g., {
                'waveform': (batch_size, channels_num, segment_samples)
            }
        """

        batch_size = len(batch_data_dict['mixture'])
        target_sources_num = len(self.target_source_types)

        assert (
            batch_size % target_sources_num == 0
        ), "Batch size should be evenly divided by target sources number."

        mixtures = batch_data_dict['mixture']
        # mixtures: (batch_size, channels_num, segment_samples)

        conditions = torch.zeros(batch_size, target_sources_num).to(mixtures.device)
        # conditions: (batch_size, target_sources_num)

        targets = []

        for n in range(batch_size):

            k = n % target_sources_num  # source class index
            source_type = self.target_source_types[k]

            targets.append(batch_data_dict[source_type][n])

            conditions[n, k] = 1

        # conditions will look like:
        # [[1, 0, 0, 0],
        #  [0, 1, 0, 0],
        #  [0, 0, 1, 0],
        #  [0, 0, 0, 1],
        #  [1, 0, 0, 0],
        #  [0, 1, 0, 0],
        #  ...,
        # ]

        targets = torch.stack(targets, dim=0)
        # targets: (batch_size, channels_num, segment_samples)

        input_dict = {
            'waveform': mixtures,
            'condition': conditions,
        }

        target_dict = {'waveform': targets}

        return input_dict, target_dict


def get_batch_data_preprocessor_class(batch_data_preprocessor_type: str) -> object:
    r"""Get batch data preprocessor class."""
    if batch_data_preprocessor_type == 'BasicBatchDataPreprocessor':
        return BasicBatchDataPreprocessor

    elif batch_data_preprocessor_type == 'ConditionalSisoBatchDataPreprocessor':
        return ConditionalSisoBatchDataPreprocessor

    else:
        raise NotImplementedError
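A small sketch of how the conditional preprocessor cycles a batch through the target sources; the tensor sizes below are invented for illustration:

import torch

from bytesep.data.batch_data_preprocessors import ConditionalSisoBatchDataPreprocessor

source_types = ['vocals', 'bass', 'drums', 'other']
preprocessor = ConditionalSisoBatchDataPreprocessor(source_types)

# Hypothetical batch: 8 examples, stereo, 100-sample segments.
batch = {name: torch.randn(8, 2, 100) for name in ['mixture'] + source_types}

input_dict, target_dict = preprocessor(batch)
print(input_dict['condition'][:4])     # identity-like rows cycling over the 4 sources
print(target_dict['waveform'].shape)   # torch.Size([8, 2, 100])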
bytesep/data/data_modules.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, NoReturn, Optional
2
+
3
+ import h5py
4
+ import librosa
5
+ import numpy as np
6
+ import torch
7
+ from pytorch_lightning.core.datamodule import LightningDataModule
8
+
9
+ from bytesep.data.samplers import DistributedSamplerWrapper
10
+ from bytesep.utils import int16_to_float32
11
+
12
+
13
+ class DataModule(LightningDataModule):
14
+ def __init__(
15
+ self,
16
+ train_sampler: object,
17
+ train_dataset: object,
18
+ num_workers: int,
19
+ distributed: bool,
20
+ ):
21
+ r"""Data module.
22
+
23
+ Args:
24
+ train_sampler: Sampler object
25
+ train_dataset: Dataset object
26
+ num_workers: int
27
+ distributed: bool
28
+ """
29
+ super().__init__()
30
+ self._train_sampler = train_sampler
31
+ self.train_dataset = train_dataset
32
+ self.num_workers = num_workers
33
+ self.distributed = distributed
34
+
35
+ def setup(self, stage: Optional[str] = None) -> NoReturn:
36
+ r"""called on every device."""
37
+
38
+ # SegmentSampler is used for selecting segments for training.
39
+ # On multiple devices, each SegmentSampler samples a part of mini-batch
40
+ # data.
41
+ if self.distributed:
42
+ self.train_sampler = DistributedSamplerWrapper(self._train_sampler)
43
+
44
+ else:
45
+ self.train_sampler = self._train_sampler
46
+
47
+ def train_dataloader(self) -> torch.utils.data.DataLoader:
48
+ r"""Get train loader."""
49
+ train_loader = torch.utils.data.DataLoader(
50
+ dataset=self.train_dataset,
51
+ batch_sampler=self.train_sampler,
52
+ collate_fn=collate_fn,
53
+ num_workers=self.num_workers,
54
+ pin_memory=True,
55
+ )
56
+
57
+ return train_loader
58
+
59
+
60
+ class Dataset:
61
+ def __init__(self, augmentor: object, segment_samples: int):
62
+ r"""Used for getting data according to a meta.
63
+
64
+ Args:
65
+ augmentor: Augmentor class
66
+ segment_samples: int
67
+ """
68
+ self.augmentor = augmentor
69
+ self.segment_samples = segment_samples
70
+
71
+ def __getitem__(self, meta: Dict) -> Dict:
72
+ r"""Return data according to a meta. E.g., an input meta looks like: {
73
+ 'vocals': [['song_A.h5', 6332760, 6465060], ['song_B.h5', 198450, 330750]],
74
+ 'accompaniment': [['song_C.h5', 24232920, 24365250], ['song_D.h5', 1569960, 1702260]]}.
75
+ }
76
+
77
+ Then, vocals segments of song_A and song_B will be mixed (mix-audio augmentation).
78
+ Accompaniment segments of song_C and song_B will be mixed (mix-audio augmentation).
79
+ Finally, mixture is created by summing vocals and accompaniment.
80
+
81
+ Args:
82
+ meta: dict, e.g., {
83
+ 'vocals': [['song_A.h5', 6332760, 6465060], ['song_B.h5', 198450, 330750]],
84
+ 'accompaniment': [['song_C.h5', 24232920, 24365250], ['song_D.h5', 1569960, 1702260]],
+ }
86
+
87
+ Returns:
88
+ data_dict: dict, e.g., {
89
+ 'vocals': (channels_num, segment_samples),
+ 'accompaniment': (channels_num, segment_samples),
+ 'mixture': (channels_num, segment_samples),
92
+ }
93
+ """
94
+ source_types = meta.keys()
95
+ data_dict = {}
96
+
97
+ for source_type in source_types:
98
+ # E.g., ['vocals', 'bass', ...]
99
+
100
+ waveforms = [] # Audio segments to be mix-audio augmented.
101
+
102
+ for m in meta[source_type]:
103
+ # E.g., {
104
+ # 'hdf5_path': '.../song_A.h5',
105
+ # 'key_in_hdf5': 'vocals',
106
+ # 'begin_sample': 13406400,
107
+ # 'end_sample': 13538700,
108
+ # }
109
+
110
+ hdf5_path = m['hdf5_path']
111
+ key_in_hdf5 = m['key_in_hdf5']
112
+ bgn_sample = m['begin_sample']
113
+ end_sample = m['end_sample']
114
+
115
+ with h5py.File(hdf5_path, 'r') as hf:
116
+
117
+ if source_type == 'audioset':
118
+ index_in_hdf5 = m['index_in_hdf5']
119
+ waveform = int16_to_float32(
120
+ hf['waveform'][index_in_hdf5][bgn_sample:end_sample]
121
+ )
122
+ waveform = waveform[None, :]
123
+ else:
124
+ waveform = int16_to_float32(
125
+ hf[key_in_hdf5][:, bgn_sample:end_sample]
126
+ )
127
+
128
+ if self.augmentor:
129
+ waveform = self.augmentor(waveform, source_type)
130
+
131
+ waveform = librosa.util.fix_length(
132
+ waveform, size=self.segment_samples, axis=1
133
+ )
134
+ # (channels_num, segment_samples)
135
+
136
+ waveforms.append(waveform)
137
+ # E.g., waveforms: [(channels_num, audio_samples), (channels_num, audio_samples)]
138
+
139
+ # mix-audio augmentation
140
+ data_dict[source_type] = np.sum(waveforms, axis=0)
141
+ # data_dict[source_type]: (channels_num, audio_samples)
142
+
143
+ # data_dict looks like: {
144
+ # 'vocals': (channels_num, audio_samples),
145
+ # 'accompaniment': (channels_num, audio_samples)
146
+ # }
147
+
148
+ # Mix segments from different sources.
149
+ mixture = np.sum(
150
+ [data_dict[source_type] for source_type in source_types], axis=0
151
+ )
152
+ data_dict['mixture'] = mixture
153
+ # shape: (channels_num, audio_samples)
154
+
155
+ return data_dict
156
+
157
+
158
+ def collate_fn(list_data_dict: List[Dict]) -> Dict:
159
+ r"""Collate mini-batch data to inputs and targets for training.
160
+
161
+ Args:
162
+ list_data_dict: e.g., [
163
+ {'vocals': (channels_num, segment_samples),
164
+ 'accompaniment': (channels_num, segment_samples),
165
+ 'mixture': (channels_num, segment_samples)
166
+ },
167
+ {'vocals': (channels_num, segment_samples),
168
+ 'accompaniment': (channels_num, segment_samples),
169
+ 'mixture': (channels_num, segment_samples)
170
+ },
171
+ ...]
172
+
173
+ Returns:
174
+ data_dict: e.g. {
175
+ 'vocals': (batch_size, channels_num, segment_samples),
176
+ 'accompaniment': (batch_size, channels_num, segment_samples),
177
+ 'mixture': (batch_size, channels_num, segment_samples)
178
+ }
179
+ """
180
+ data_dict = {}
181
+
182
+ for key in list_data_dict[0].keys():
183
+ data_dict[key] = torch.Tensor(
184
+ np.array([dd[key] for dd in list_data_dict])
185
+ )
186
+
187
+ return data_dict
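
A minimal usage sketch of collate_fn above, with made-up shapes (channels_num=2, segment_samples=4) purely for illustration:

import numpy as np

from bytesep.data.data_modules import collate_fn

# Two fake training examples, shaped like the dicts returned by Dataset.__getitem__.
example_a = {'vocals': np.zeros((2, 4)), 'mixture': np.zeros((2, 4))}
example_b = {'vocals': np.ones((2, 4)), 'mixture': np.ones((2, 4))}

batch = collate_fn([example_a, example_b])
print(batch['vocals'].shape)  # torch.Size([2, 2, 4]): (batch_size, channels_num, segment_samples)
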
bytesep/data/samplers.py ADDED
@@ -0,0 +1,188 @@
1
+ import pickle
2
+ from typing import Dict, List, NoReturn
3
+
4
+ import numpy as np
5
+ import torch.distributed as dist
6
+
7
+
8
+ class SegmentSampler:
9
+ def __init__(
10
+ self,
11
+ indexes_path: str,
12
+ segment_samples: int,
13
+ mixaudio_dict: Dict,
14
+ batch_size: int,
15
+ steps_per_epoch: int,
16
+ random_seed=1234,
17
+ ):
18
+ r"""Sample training indexes of sources.
19
+
20
+ Args:
21
+ indexes_path: str, path of indexes dict
22
+ segment_samples: int
23
+ mixaudio_dict: dict, hyper-parameters for mix-audio data
+ augmentation, e.g., {'vocals': 2, 'accompaniment': 2}
25
+ batch_size: int
26
+ steps_per_epoch: int, the number of training steps that counts as one `epoch`
27
+ random_seed: int
28
+ """
29
+ self.segment_samples = segment_samples
30
+ self.mixaudio_dict = mixaudio_dict
31
+ self.batch_size = batch_size
32
+ self.steps_per_epoch = steps_per_epoch
33
+
34
+ self.meta_dict = pickle.load(open(indexes_path, "rb"))
35
+ # E.g., {
36
+ # 'vocals': [
37
+ # {'hdf5_path': 'songA.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 0, 'end_sample': 132300},
38
+ # {'hdf5_path': 'songB.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 4410, 'end_sample': 445410},
39
+ # ...
40
+ # ],
41
+ # 'accompaniment': [
42
+ # {'hdf5_path': 'songA.h5', 'key_in_hdf5': 'accompaniment', 'begin_sample': 0, 'end_sample': 132300},
+ # {'hdf5_path': 'songB.h5', 'key_in_hdf5': 'accompaniment', 'begin_sample': 4410, 'end_sample': 445410},
44
+ # ...
45
+ # ]
46
+ # }
47
+
48
+ self.source_types = self.meta_dict.keys()
49
+ # E.g., ['vocals', 'accompaniment']
50
+
51
+ self.pointers_dict = {source_type: 0 for source_type in self.source_types}
52
+ # E.g., {'vocals': 0, 'accompaniment': 0}
53
+
54
+ self.indexes_dict = {
55
+ source_type: np.arange(len(self.meta_dict[source_type]))
56
+ for source_type in self.source_types
57
+ }
58
+ # E.g. {
59
+ # 'vocals': [0, 1, ..., 225751],
60
+ # 'accompaniment': [0, 1, ..., 225751]
61
+ # }
62
+
63
+ self.random_state = np.random.RandomState(random_seed)
64
+
65
+ # Shuffle indexes.
66
+ for source_type in self.source_types:
67
+ self.random_state.shuffle(self.indexes_dict[source_type])
68
+ print("{}: {}".format(source_type, len(self.indexes_dict[source_type])))
69
+
70
+ def __iter__(self) -> List[Dict]:
71
+ r"""Yield a batch of meta info.
72
+
73
+ Returns:
74
+ batch_meta_list: (batch_size,) e.g., when mix-audio is 2, looks like [
75
+ {'vocals': [
76
+ {'hdf5_path': 'songA.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 13406400, 'end_sample': 13538700},
77
+ {'hdf5_path': 'songB.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 4440870, 'end_sample': 4573170}]
78
+ 'accompaniment': [
79
+ {'hdf5_path': 'songE.h5', 'key_in_hdf5': 'accompaniment', 'begin_sample': 14579460, 'end_sample': 14711760},
80
+ {'hdf5_path': 'songF.h5', 'key_in_hdf5': 'accompaniment', 'begin_sample': 3995460, 'end_sample': 4127760}]
81
+ }
82
+ ...
83
+ ]
84
+ """
85
+ batch_size = self.batch_size
86
+
87
+ while True:
88
+ batch_meta_dict = {source_type: [] for source_type in self.source_types}
89
+
90
+ for source_type in self.source_types:
91
+ # E.g., ['vocals', 'accompaniment']
92
+
93
+ # Loop until get a mini-batch.
94
+ while len(batch_meta_dict[source_type]) != batch_size:
95
+
96
+ largest_index = (
97
+ len(self.indexes_dict[source_type])
98
+ - self.mixaudio_dict[source_type]
99
+ )
100
+ # E.g., 225750 = 225752 - 2
101
+
102
+ if self.pointers_dict[source_type] > largest_index:
103
+
104
+ # Reset pointer, and shuffle indexes.
105
+ self.pointers_dict[source_type] = 0
106
+ self.random_state.shuffle(self.indexes_dict[source_type])
107
+
108
+ source_metas = []
109
+ mix_audios_num = self.mixaudio_dict[source_type]
110
+
111
+ for _ in range(mix_audios_num):
112
+
113
+ pointer = self.pointers_dict[source_type]
114
+ # E.g., 1
115
+
116
+ index = self.indexes_dict[source_type][pointer]
117
+ # E.g., 12231
118
+
119
+ self.pointers_dict[source_type] += 1
120
+
121
+ source_meta = self.meta_dict[source_type][index]
122
+ # E.g., {'hdf5_path': 'song_A.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 198450, 'end_sample': 330750}
123
+
124
+ # source_metas.append(new_source_meta)
125
+ source_metas.append(source_meta)
126
+
127
+ batch_meta_dict[source_type].append(source_metas)
128
+ # When mix-audio is 2, batch_meta_dict looks like: {
129
+ # 'vocals': [
130
+ # [{'hdf5_path': 'songA.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 13406400, 'end_sample': 13538700},
131
+ # {'hdf5_path': 'songB.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 4440870, 'end_sample': 4573170}],
132
+ # [{'hdf5_path': 'songC.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 1186290, 'end_sample': 1318590},
133
+ # {'hdf5_path': 'songD.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 8462790, 'end_sample': 8595090}]
134
+ # ]
135
+ # 'accompaniment': [
136
+ # [{'hdf5_path': 'songE.h5', 'key_in_hdf5': 'accompaniment', 'begin_sample': 24232950, 'end_sample': 24365250},
+ # {'hdf5_path': 'songF.h5', 'key_in_hdf5': 'accompaniment', 'begin_sample': 1569960, 'end_sample': 1702260}],
+ # [{'hdf5_path': 'songG.h5', 'key_in_hdf5': 'accompaniment', 'begin_sample': 2795940, 'end_sample': 2928240},
+ # {'hdf5_path': 'songH.h5', 'key_in_hdf5': 'accompaniment', 'begin_sample': 10923570, 'end_sample': 11055870}]
140
+ # ]
141
+ # }
142
+
143
+ batch_meta_list = [
144
+ {
145
+ source_type: batch_meta_dict[source_type][i]
146
+ for source_type in self.source_types
147
+ }
148
+ for i in range(batch_size)
149
+ ]
150
+ # When mix-audio is 2, batch_meta_list looks like: [
151
+ # {'vocals': [
152
+ # {'hdf5_path': 'songA.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 13406400, 'end_sample': 13538700},
153
+ # {'hdf5_path': 'songB.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 4440870, 'end_sample': 4573170}]
154
+ # 'accompaniment': [
155
+ # {'hdf5_path': 'songE.h5', 'key_in_hdf5': 'accompaniment', 'begin_sample': 14579460, 'end_sample': 14711760},
+ # {'hdf5_path': 'songF.h5', 'key_in_hdf5': 'accompaniment', 'begin_sample': 3995460, 'end_sample': 4127760}]
157
+ # }
158
+ # ...
159
+ # ]
160
+
161
+ yield batch_meta_list
162
+
163
+ def __len__(self) -> int:
164
+ return self.steps_per_epoch
165
+
166
+ def state_dict(self) -> Dict:
167
+ state = {'pointers_dict': self.pointers_dict, 'indexes_dict': self.indexes_dict}
168
+ return state
169
+
170
+ def load_state_dict(self, state) -> NoReturn:
171
+ self.pointers_dict = state['pointers_dict']
172
+ self.indexes_dict = state['indexes_dict']
173
+
174
+
175
+ class DistributedSamplerWrapper:
176
+ def __init__(self, sampler):
177
+ r"""Distributed wrapper of sampler."""
178
+ self.sampler = sampler
179
+
180
+ def __iter__(self):
181
+ num_replicas = dist.get_world_size()
182
+ rank = dist.get_rank()
183
+
184
+ for indices in self.sampler:
185
+ yield indices[rank::num_replicas]
186
+
187
+ def __len__(self) -> int:
188
+ return len(self.sampler)
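
A minimal wiring sketch of SegmentSampler, Dataset and DataModule; the indexes path, batch size and mix-audio settings are placeholders, and an indexes pickle written by create_indexes.py is assumed to exist:

from bytesep.data.data_modules import DataModule, Dataset
from bytesep.data.samplers import SegmentSampler

segment_samples = 44100 * 3  # 3-second segments at 44.1 kHz (illustrative)

# The keys of mixaudio_dict must match the source types stored in the indexes pickle.
train_sampler = SegmentSampler(
    indexes_path='workspace/indexes/train_indexes.pkl',  # placeholder path
    segment_samples=segment_samples,
    mixaudio_dict={'vocals': 2, 'accompaniment': 2},
    batch_size=16,
    steps_per_epoch=10000,
)

train_dataset = Dataset(augmentor=None, segment_samples=segment_samples)

data_module = DataModule(
    train_sampler=train_sampler,
    train_dataset=train_dataset,
    num_workers=8,
    distributed=False,  # set True under DDP to wrap the sampler with DistributedSamplerWrapper
)
data_module.setup()
train_loader = data_module.train_dataloader()
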
bytesep/dataset_creation/__init__.py ADDED
File without changes
bytesep/dataset_creation/create_evaluation_audios/__init__.py ADDED
File without changes
bytesep/dataset_creation/create_evaluation_audios/piano-symphony.py ADDED
@@ -0,0 +1,160 @@
1
+ import argparse
2
+ import os
3
+ from typing import NoReturn
4
+
5
+ import librosa
6
+ import numpy as np
7
+ import soundfile
8
+
9
+ from bytesep.dataset_creation.pack_audios_to_hdf5s.instruments_solo import (
10
+ read_csv as read_instruments_solo_csv,
11
+ )
12
+ from bytesep.dataset_creation.pack_audios_to_hdf5s.maestro import (
13
+ read_csv as read_maestro_csv,
14
+ )
15
+ from bytesep.utils import load_random_segment
16
+
17
+
18
+ def create_evaluation(args) -> NoReturn:
19
+ r"""Random mix and write out audios for evaluation.
20
+
21
+ Args:
22
+ piano_dataset_dir: str, the directory of the piano dataset
23
+ symphony_dataset_dir: str, the directory of the symphony dataset
24
+ evaluation_audios_dir: str, the directory to write out randomly selected and mixed audio segments
25
+ sample_rate: int
26
+ channels: int, e.g., 1 | 2
27
+ evaluation_segments_num: int
28
+ mono: bool
29
+
30
+ Returns:
31
+ NoReturn
32
+ """
33
+
34
+ # arguments & parameters
35
+ piano_dataset_dir = args.piano_dataset_dir
36
+ symphony_dataset_dir = args.symphony_dataset_dir
37
+ evaluation_audios_dir = args.evaluation_audios_dir
38
+ sample_rate = args.sample_rate
39
+ channels = args.channels
40
+ evaluation_segments_num = args.evaluation_segments_num
41
+ mono = True if channels == 1 else False
42
+
43
+ split = 'test'
44
+ segment_seconds = 10.0
45
+
46
+ random_state = np.random.RandomState(1234)
47
+
48
+ piano_meta_csv = os.path.join(piano_dataset_dir, 'maestro-v2.0.0.csv')
49
+ piano_names_dict = read_maestro_csv(piano_meta_csv)
50
+ piano_audio_names = piano_names_dict[split]
51
+
52
+ symphony_meta_csv = os.path.join(symphony_dataset_dir, 'validation.csv')
53
+ symphony_names_dict = read_instruments_solo_csv(symphony_meta_csv)
54
+ symphony_audio_names = symphony_names_dict[split]
55
+
56
+ for source_type in ['piano', 'symphony', 'mixture']:
57
+ output_dir = os.path.join(evaluation_audios_dir, split, source_type)
58
+ os.makedirs(output_dir, exist_ok=True)
59
+
60
+ for n in range(evaluation_segments_num):
61
+
62
+ print('{} / {}'.format(n, evaluation_segments_num))
63
+
64
+ # Randomly select and write out a clean piano segment.
65
+ piano_audio_name = random_state.choice(piano_audio_names)
66
+ piano_audio_path = os.path.join(piano_dataset_dir, piano_audio_name)
67
+
68
+ piano_audio = load_random_segment(
69
+ audio_path=piano_audio_path,
70
+ random_state=random_state,
71
+ segment_seconds=segment_seconds,
72
+ mono=mono,
73
+ sample_rate=sample_rate,
74
+ )
75
+
76
+ output_piano_path = os.path.join(
77
+ evaluation_audios_dir, split, 'piano', '{:04d}.wav'.format(n)
78
+ )
79
+ soundfile.write(
80
+ file=output_piano_path, data=piano_audio.T, samplerate=sample_rate
81
+ )
82
+ print("Write out to {}".format(output_piano_path))
83
+
84
+ # Randomly select and write out a clean symphony segment.
85
+ symphony_audio_name = random_state.choice(symphony_audio_names)
86
+ symphony_audio_path = os.path.join(
87
+ symphony_dataset_dir, "mp3s", symphony_audio_name
88
+ )
89
+
90
+ symphony_audio = load_random_segment(
91
+ audio_path=symphony_audio_path,
92
+ random_state=random_state,
93
+ segment_seconds=segment_seconds,
94
+ mono=mono,
95
+ sample_rate=sample_rate,
96
+ )
97
+
98
+ output_symphony_path = os.path.join(
99
+ evaluation_audios_dir, split, 'symphony', '{:04d}.wav'.format(n)
100
+ )
101
+ soundfile.write(
102
+ file=output_symphony_path, data=symphony_audio.T, samplerate=sample_rate
103
+ )
104
+ print("Write out to {}".format(output_symphony_path))
105
+
106
+ # Mix piano and symphony segments and write out a mixture segment.
107
+ mixture_audio = symphony_audio + piano_audio
108
+ output_mixture_path = os.path.join(
109
+ evaluation_audios_dir, split, 'mixture', '{:04d}.wav'.format(n)
110
+ )
111
+ soundfile.write(
112
+ file=output_mixture_path, data=mixture_audio.T, samplerate=sample_rate
113
+ )
114
+ print("Write out to {}".format(output_mixture_path))
115
+
116
+
117
+ if __name__ == "__main__":
118
+ parser = argparse.ArgumentParser()
119
+
120
+ parser.add_argument(
121
+ "--piano_dataset_dir",
122
+ type=str,
123
+ required=True,
124
+ help="The directory of the piano dataset.",
125
+ )
126
+ parser.add_argument(
127
+ "--symphony_dataset_dir",
128
+ type=str,
129
+ required=True,
130
+ help="The directory of the symphony dataset.",
131
+ )
132
+ parser.add_argument(
133
+ "--evaluation_audios_dir",
134
+ type=str,
135
+ required=True,
136
+ help="The directory to write out randomly selected and mixed audio segments.",
137
+ )
138
+ parser.add_argument(
139
+ "--sample_rate",
140
+ type=int,
141
+ required=True,
142
+ help="Sample rate.",
143
+ )
144
+ parser.add_argument(
145
+ "--channels",
146
+ type=int,
147
+ required=True,
148
+ help="Audio channels, e.g, 1 or 2.",
149
+ )
150
+ parser.add_argument(
151
+ "--evaluation_segments_num",
152
+ type=int,
153
+ required=True,
154
+ help="The number of segments to create for evaluation.",
155
+ )
156
+
157
+ # Parse arguments.
158
+ args = parser.parse_args()
159
+
160
+ create_evaluation(args)
bytesep/dataset_creation/create_evaluation_audios/vctk-musdb18.py ADDED
@@ -0,0 +1,164 @@
1
+ import argparse
2
+ import os
3
+ import soundfile
4
+ from typing import NoReturn
5
+
6
+ import musdb
7
+ import numpy as np
8
+
9
+ from bytesep.utils import load_audio
10
+
11
+
12
+ def create_evaluation(args) -> NoReturn:
13
+ r"""Random mix and write out audios for evaluation.
14
+
15
+ Args:
16
+ vctk_dataset_dir: str, the directory of the VCTK dataset
17
+ musdb18_dataset_dir: str, the directory of the MUSDB18 dataset
18
+ evaluation_audios_dir: str, the directory to write out randomly selected and mixed audio segments
19
+ sample_rate: int
20
+ channels: int, e.g., 1 | 2
21
+ evaluation_segments_num: int
22
+ mono: bool
23
+
24
+ Returns:
25
+ NoReturn
26
+ """
27
+
28
+ # arguments & parameters
29
+ vctk_dataset_dir = args.vctk_dataset_dir
30
+ musdb18_dataset_dir = args.musdb18_dataset_dir
31
+ evaluation_audios_dir = args.evaluation_audios_dir
32
+ sample_rate = args.sample_rate
33
+ channels = args.channels
34
+ evaluation_segments_num = args.evaluation_segments_num
35
+ mono = True if channels == 1 else False
36
+
37
+ split = 'test'
38
+ random_state = np.random.RandomState(1234)
39
+
40
+ # paths
41
+ audios_dir = os.path.join(vctk_dataset_dir, "wav48", split)
42
+
43
+ for source_type in ['speech', 'music', 'mixture']:
44
+ output_dir = os.path.join(evaluation_audios_dir, split, source_type)
45
+ os.makedirs(output_dir, exist_ok=True)
46
+
47
+ # Get VCTK audio paths.
48
+ speech_audio_paths = []
49
+ speaker_ids = sorted(os.listdir(audios_dir))
50
+
51
+ for speaker_id in speaker_ids:
52
+ speaker_audios_dir = os.path.join(audios_dir, speaker_id)
53
+
54
+ audio_names = sorted(os.listdir(speaker_audios_dir))
55
+
56
+ for audio_name in audio_names:
57
+ speaker_audio_path = os.path.join(speaker_audios_dir, audio_name)
58
+ speech_audio_paths.append(speaker_audio_path)
59
+
60
+ # Get Musdb18 audio paths.
61
+ mus = musdb.DB(root=musdb18_dataset_dir, subsets=[split])
62
+ track_indexes = np.arange(len(mus.tracks))
63
+
64
+ for n in range(evaluation_segments_num):
65
+
66
+ print('{} / {}'.format(n, evaluation_segments_num))
67
+
68
+ # Randomly select and write out a clean speech segment.
69
+ speech_audio_path = random_state.choice(speech_audio_paths)
70
+
71
+ speech_audio = load_audio(
72
+ audio_path=speech_audio_path, mono=mono, sample_rate=sample_rate
73
+ )
74
+ # (channels_num, audio_samples)
75
+
76
+ if channels == 2:
77
+ speech_audio = np.tile(speech_audio, (2, 1))
78
+ # (channels_num, audio_samples)
79
+
80
+ output_speech_path = os.path.join(
81
+ evaluation_audios_dir, split, 'speech', '{:04d}.wav'.format(n)
82
+ )
83
+ soundfile.write(
84
+ file=output_speech_path, data=speech_audio.T, samplerate=sample_rate
85
+ )
86
+ print("Write out to {}".format(output_speech_path))
87
+
88
+ # Randomly select and write out a clean music segment.
89
+ track_index = random_state.choice(track_indexes)
90
+ track = mus[track_index]
91
+
92
+ segment_samples = speech_audio.shape[1]
93
+ start_sample = int(
94
+ random_state.uniform(0.0, track.audio.shape[0] - segment_samples)
95
+ )
96
+
97
+ music_audio = track.audio[start_sample : start_sample + segment_samples, :].T
98
+ # (channels_num, audio_samples)
99
+
100
+ output_music_path = os.path.join(
101
+ evaluation_audios_dir, split, 'music', '{:04d}.wav'.format(n)
102
+ )
103
+ soundfile.write(
104
+ file=output_music_path, data=music_audio.T, samplerate=sample_rate
105
+ )
106
+ print("Write out to {}".format(output_music_path))
107
+
108
+ # Mix speech and music segments and write out a mixture segment.
109
+ mixture_audio = speech_audio + music_audio
110
+ # (channels_num, audio_samples)
111
+
112
+ output_mixture_path = os.path.join(
113
+ evaluation_audios_dir, split, 'mixture', '{:04d}.wav'.format(n)
114
+ )
115
+ soundfile.write(
116
+ file=output_mixture_path, data=mixture_audio.T, samplerate=sample_rate
117
+ )
118
+ print("Write out to {}".format(output_mixture_path))
119
+
120
+
121
+ if __name__ == "__main__":
122
+ parser = argparse.ArgumentParser()
123
+
124
+ parser.add_argument(
125
+ "--vctk_dataset_dir",
126
+ type=str,
127
+ required=True,
128
+ help="The directory of the VCTK dataset.",
129
+ )
130
+ parser.add_argument(
131
+ "--musdb18_dataset_dir",
132
+ type=str,
133
+ required=True,
134
+ help="The directory of the MUSDB18 dataset.",
135
+ )
136
+ parser.add_argument(
137
+ "--evaluation_audios_dir",
138
+ type=str,
139
+ required=True,
140
+ help="The directory to write out randomly selected and mixed audio segments.",
141
+ )
142
+ parser.add_argument(
143
+ "--sample_rate",
144
+ type=int,
145
+ required=True,
146
+ help="Sample rate",
147
+ )
148
+ parser.add_argument(
149
+ "--channels",
150
+ type=int,
151
+ required=True,
152
+ help="Audio channels, e.g, 1 or 2.",
153
+ )
154
+ parser.add_argument(
155
+ "--evaluation_segments_num",
156
+ type=int,
157
+ required=True,
158
+ help="The number of segments to create for evaluation.",
159
+ )
160
+
161
+ # Parse arguments.
162
+ args = parser.parse_args()
163
+
164
+ create_evaluation(args)
bytesep/dataset_creation/create_evaluation_audios/violin-piano.py ADDED
@@ -0,0 +1,162 @@
1
+ import argparse
2
+ import os
3
+ from typing import NoReturn
4
+
5
+ import librosa
6
+ import numpy as np
7
+ import soundfile
8
+
9
+ from bytesep.dataset_creation.pack_audios_to_hdf5s.instruments_solo import (
10
+ read_csv as read_instruments_solo_csv,
11
+ )
12
+ from bytesep.dataset_creation.pack_audios_to_hdf5s.maestro import (
13
+ read_csv as read_maestro_csv,
14
+ )
15
+ from bytesep.utils import load_random_segment
16
+
17
+
18
+ def create_evaluation(args) -> NoReturn:
19
+ r"""Random mix and write out audios for evaluation.
20
+
21
+ Args:
22
+ violin_dataset_dir: str, the directory of the violin dataset
23
+ piano_dataset_dir: str, the directory of the piano dataset
24
+ evaluation_audios_dir: str, the directory to write out randomly selected and mixed audio segments
25
+ sample_rate: int
26
+ channels: int, e.g., 1 | 2
27
+ evaluation_segments_num: int
28
+ mono: bool
29
+
30
+ Returns:
31
+ NoReturn
32
+ """
33
+
34
+ # arguments & parameters
35
+ violin_dataset_dir = args.violin_dataset_dir
36
+ piano_dataset_dir = args.piano_dataset_dir
37
+ evaluation_audios_dir = args.evaluation_audios_dir
38
+ sample_rate = args.sample_rate
39
+ channels = args.channels
40
+ evaluation_segments_num = args.evaluation_segments_num
41
+ mono = True if channels == 1 else False
42
+
43
+ split = 'test'
44
+ segment_seconds = 10.0
45
+
46
+ random_state = np.random.RandomState(1234)
47
+
48
+ violin_meta_csv = os.path.join(violin_dataset_dir, 'validation.csv')
49
+ violin_names_dict = read_instruments_solo_csv(violin_meta_csv)
50
+ violin_audio_names = violin_names_dict['{}'.format(split)]
51
+
52
+ piano_meta_csv = os.path.join(piano_dataset_dir, 'maestro-v2.0.0.csv')
53
+ piano_names_dict = read_maestro_csv(piano_meta_csv)
54
+ piano_audio_names = piano_names_dict['{}'.format(split)]
55
+
56
+ for source_type in ['violin', 'piano', 'mixture']:
57
+ output_dir = os.path.join(evaluation_audios_dir, split, source_type)
58
+ os.makedirs(output_dir, exist_ok=True)
59
+
60
+ for n in range(evaluation_segments_num):
61
+
62
+ print('{} / {}'.format(n, evaluation_segments_num))
63
+
64
+ # Randomly select and write out a clean violin segment.
65
+ violin_audio_name = random_state.choice(violin_audio_names)
66
+ violin_audio_path = os.path.join(violin_dataset_dir, "mp3s", violin_audio_name)
67
+
68
+ violin_audio = load_random_segment(
69
+ audio_path=violin_audio_path,
70
+ random_state=random_state,
71
+ segment_seconds=segment_seconds,
72
+ mono=mono,
73
+ sample_rate=sample_rate,
74
+ )
75
+ # (channels_num, audio_samples)
76
+
77
+ output_violin_path = os.path.join(
78
+ evaluation_audios_dir, split, 'violin', '{:04d}.wav'.format(n)
79
+ )
80
+ soundfile.write(
81
+ file=output_violin_path, data=violin_audio.T, samplerate=sample_rate
82
+ )
83
+ print("Write out to {}".format(output_violin_path))
84
+
85
+ # Randomly select and write out a clean piano segment.
86
+ piano_audio_name = random_state.choice(piano_audio_names)
87
+ piano_audio_path = os.path.join(piano_dataset_dir, piano_audio_name)
88
+
89
+ piano_audio = load_random_segment(
90
+ audio_path=piano_audio_path,
91
+ random_state=random_state,
92
+ segment_seconds=segment_seconds,
93
+ mono=mono,
94
+ sample_rate=sample_rate,
95
+ )
96
+ # (channels_num, audio_samples)
97
+
98
+ output_piano_path = os.path.join(
99
+ evaluation_audios_dir, split, 'piano', '{:04d}.wav'.format(n)
100
+ )
101
+ soundfile.write(
102
+ file=output_piano_path, data=piano_audio.T, samplerate=sample_rate
103
+ )
104
+ print("Write out to {}".format(output_piano_path))
105
+
106
+ # Mix violin and piano segments and write out a mixture segment.
107
+ mixture_audio = violin_audio + piano_audio
108
+ # (channels_num, audio_samples)
109
+
110
+ output_mixture_path = os.path.join(
111
+ evaluation_audios_dir, split, 'mixture', '{:04d}.wav'.format(n)
112
+ )
113
+ soundfile.write(
114
+ file=output_mixture_path, data=mixture_audio.T, samplerate=sample_rate
115
+ )
116
+ print("Write out to {}".format(output_mixture_path))
117
+
118
+
119
+ if __name__ == "__main__":
120
+ parser = argparse.ArgumentParser()
121
+
122
+ parser.add_argument(
123
+ "--violin_dataset_dir",
124
+ type=str,
125
+ required=True,
126
+ help="The directory of the violin dataset.",
127
+ )
128
+ parser.add_argument(
129
+ "--piano_dataset_dir",
130
+ type=str,
131
+ required=True,
132
+ help="The directory of the piano dataset.",
133
+ )
134
+ parser.add_argument(
135
+ "--evaluation_audios_dir",
136
+ type=str,
137
+ required=True,
138
+ help="The directory to write out randomly selected and mixed audio segments.",
139
+ )
140
+ parser.add_argument(
141
+ "--sample_rate",
142
+ type=int,
143
+ required=True,
144
+ help="Sample rate",
145
+ )
146
+ parser.add_argument(
147
+ "--channels",
148
+ type=int,
149
+ required=True,
150
+ help="Audio channels, e.g, 1 or 2.",
151
+ )
152
+ parser.add_argument(
153
+ "--evaluation_segments_num",
154
+ type=int,
155
+ required=True,
156
+ help="The number of segments to create for evaluation.",
157
+ )
158
+
159
+ # Parse arguments.
160
+ args = parser.parse_args()
161
+
162
+ create_evaluation(args)
bytesep/dataset_creation/create_indexes/__init__.py ADDED
File without changes
bytesep/dataset_creation/create_indexes/create_indexes.py ADDED
@@ -0,0 +1,142 @@
1
+ import argparse
2
+ import os
3
+ import pickle
4
+ from typing import NoReturn
5
+
6
+ import h5py
7
+
8
+ from bytesep.utils import read_yaml
9
+
10
+
11
+ def create_indexes(args) -> NoReturn:
12
+ r"""Create and write out training indexes into disk. The indexes may contain
13
+ information from multiple datasets. During training, training indexes will
14
+ be shuffled and iterated for selecting segments to be mixed. E.g., the
15
+ training indexes_dict looks like: {
16
+ 'vocals': [
17
+ {'hdf5_path': '.../songA.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 0, 'end_sample': 132300}
18
+ {'hdf5_path': '.../songB.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 4410, 'end_sample': 136710}
19
+ ...
20
+ ]
21
+ 'accompaniment': [
22
+ {'hdf5_path': '.../songA.h5', 'key_in_hdf5': 'accompaniment', 'begin_sample': 0, 'end_sample': 132300}
23
+ {'hdf5_path': '.../songB.h5', 'key_in_hdf5': 'accompaniment', 'begin_sample': 4410, 'end_sample': 136710}
24
+ ...
25
+ ]
26
+ }
27
+ """
28
+
29
+ # Arguments & parameters
30
+ workspace = args.workspace
31
+ config_yaml = args.config_yaml
32
+
33
+ # Only create indexes for training, because evaluation is on entire pieces.
34
+ split = "train"
35
+
36
+ # Read config file.
37
+ configs = read_yaml(config_yaml)
38
+
39
+ sample_rate = configs["sample_rate"]
40
+ segment_samples = int(configs["segment_seconds"] * sample_rate)
41
+
42
+ # Path to write out index.
43
+ indexes_path = os.path.join(workspace, configs[split]["indexes"])
44
+ os.makedirs(os.path.dirname(indexes_path), exist_ok=True)
45
+
46
+ source_types = configs[split]["source_types"].keys()
47
+ # E.g., ['vocals', 'accompaniment']
48
+
49
+ indexes_dict = {source_type: [] for source_type in source_types}
50
+ # E.g., indexes_dict will look like: {
51
+ # 'vocals': [
52
+ # {'hdf5_path': '.../songA.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 0, 'end_sample': 132300}
53
+ # {'hdf5_path': '.../songB.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 4410, 'end_sample': 136710}
54
+ # ...
55
+ # ]
56
+ # 'accompaniment': [
57
+ # {'hdf5_path': '.../songA.h5', 'key_in_hdf5': 'accompaniment', 'begin_sample': 0, 'end_sample': 132300}
58
+ # {'hdf5_path': '.../songB.h5', 'key_in_hdf5': 'accompaniment', 'begin_sample': 4410, 'end_sample': 136710}
59
+ # ...
60
+ # ]
61
+ # }
62
+
63
+ # Get training indexes for each source type.
64
+ for source_type in source_types:
65
+ # E.g., ['vocals', 'bass', ...]
66
+
67
+ print("--- {} ---".format(source_type))
68
+
69
+ dataset_types = configs[split]["source_types"][source_type]
70
+ # E.g., ['musdb18', ...]
71
+
72
+ # Each source can come from multiple datasets.
73
+ for dataset_type in dataset_types:
74
+
75
+ hdf5s_dir = os.path.join(
76
+ workspace, dataset_types[dataset_type]["hdf5s_directory"]
77
+ )
78
+
79
+ hop_samples = int(dataset_types[dataset_type]["hop_seconds"] * sample_rate)
80
+
81
+ key_in_hdf5 = dataset_types[dataset_type]["key_in_hdf5"]
82
+ # E.g., 'vocals'
83
+
84
+ hdf5_names = sorted(os.listdir(hdf5s_dir))
85
+ print("Hdf5 files num: {}".format(len(hdf5_names)))
86
+
87
+ # Traverse all packed hdf5 files of a dataset.
88
+ for n, hdf5_name in enumerate(hdf5_names):
89
+
90
+ print(n, hdf5_name)
91
+ hdf5_path = os.path.join(hdf5s_dir, hdf5_name)
92
+
93
+ with h5py.File(hdf5_path, "r") as hf:
94
+
95
+ bgn_sample = 0
96
+ while bgn_sample + segment_samples < hf[key_in_hdf5].shape[-1]:
97
+ meta = {
98
+ 'hdf5_path': hdf5_path,
99
+ 'key_in_hdf5': key_in_hdf5,
100
+ 'begin_sample': bgn_sample,
101
+ 'end_sample': bgn_sample + segment_samples,
102
+ }
103
+ indexes_dict[source_type].append(meta)
104
+
105
+ bgn_sample += hop_samples
106
+
107
+ # If the audio length is shorter than the segment length,
108
+ # then use the entire audio as a segment.
109
+ if bgn_sample == 0:
110
+ meta = {
111
+ 'hdf5_path': hdf5_path,
112
+ 'key_in_hdf5': key_in_hdf5,
113
+ 'begin_sample': 0,
114
+ 'end_sample': segment_samples,
115
+ }
116
+ indexes_dict[source_type].append(meta)
117
+
118
+ print(
119
+ "Total indexes for {}: {}".format(
120
+ source_type, len(indexes_dict[source_type])
121
+ )
122
+ )
123
+
124
+ pickle.dump(indexes_dict, open(indexes_path, "wb"))
125
+ print("Write index dict to {}".format(indexes_path))
126
+
127
+
128
+ if __name__ == "__main__":
129
+ parser = argparse.ArgumentParser()
130
+
131
+ parser.add_argument(
132
+ "--workspace", type=str, required=True, help="Directory of workspace."
133
+ )
134
+ parser.add_argument(
135
+ "--config_yaml", type=str, required=True, help="User defined config file."
136
+ )
137
+
138
+ # Parse arguments.
139
+ args = parser.parse_args()
140
+
141
+ # Create training indexes.
142
+ create_indexes(args)
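
A sketch of the config structure that create_indexes() reads; the field names mirror the lookups in the code above, while the dataset name, paths and numbers are placeholders:

import yaml  # pyyaml

config_text = """
sample_rate: 44100
segment_seconds: 3.0
train:
  indexes: "indexes/train_indexes.pkl"
  source_types:
    vocals:
      musdb18:
        hdf5s_directory: "hdf5s/musdb18/train"
        hop_seconds: 1.0
        key_in_hdf5: "vocals"
    accompaniment:
      musdb18:
        hdf5s_directory: "hdf5s/musdb18/train"
        hop_seconds: 1.0
        key_in_hdf5: "accompaniment"
"""

configs = yaml.safe_load(config_text)
segment_samples = int(configs["segment_seconds"] * configs["sample_rate"])
print(segment_samples)  # 132300
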
bytesep/dataset_creation/pack_audios_to_hdf5s/__init__.py ADDED
File without changes
bytesep/dataset_creation/pack_audios_to_hdf5s/instruments_solo.py ADDED
@@ -0,0 +1,173 @@
1
+ import argparse
2
+ import os
3
+ import pathlib
4
+ import time
5
+ from concurrent.futures import ProcessPoolExecutor
6
+ from typing import Dict, List, NoReturn
7
+
8
+ import h5py
9
+ import numpy as np
10
+ import pandas as pd
11
+
12
+ from bytesep.utils import float32_to_int16, load_audio
13
+
14
+
15
+ def read_csv(meta_csv) -> Dict:
16
+ r"""Get train & test names from csv.
17
+
18
+ Args:
19
+ meta_csv: str
20
+
21
+ Returns:
22
+ names_dict: dict, e.g., {
23
+ 'train': ['songA.mp3', 'songB.mp3', ...],
24
+ 'test': ['songE.mp3', 'songF.mp3', ...]
25
+ }
26
+ """
27
+ df = pd.read_csv(meta_csv, sep=',')
28
+
29
+ names_dict = {}
30
+
31
+ for split in ['train', 'test']:
32
+ audio_indexes = df['split'] == split
33
+ audio_names = list(df['audio_name'][audio_indexes])
34
+ audio_names = [
35
+ '{}.mp3'.format(pathlib.Path(audio_name).stem) for audio_name in audio_names
36
+ ]
37
+ names_dict[split] = audio_names
38
+
39
+ return names_dict
40
+
41
+
42
+ def pack_audios_to_hdf5s(args) -> NoReturn:
43
+ r"""Pack (resampled) audio files into hdf5 files to speed up loading.
44
+
45
+ Args:
46
+ dataset_dir: str
47
+ split: str, 'train' | 'test'
48
+ source_type: str
49
+ hdf5s_dir: str, directory to write out hdf5 files
50
+ sample_rate: int
51
+ channels_num: int
52
+ mono: bool
53
+
54
+ Returns:
55
+ NoReturn
56
+ """
57
+
58
+ # arguments & parameters
59
+ dataset_dir = args.dataset_dir
60
+ split = args.split
61
+ source_type = args.source_type
62
+ hdf5s_dir = args.hdf5s_dir
63
+ sample_rate = args.sample_rate
64
+ channels = args.channels
65
+ mono = True if channels == 1 else False
66
+
67
+ # Only pack data for training data.
68
+ assert split == "train"
69
+
70
+ # paths
71
+ audios_dir = os.path.join(dataset_dir, 'mp3s')
72
+ meta_csv = os.path.join(dataset_dir, 'validation.csv')
73
+
74
+ os.makedirs(hdf5s_dir, exist_ok=True)
75
+
76
+ # Read train & test names.
77
+ names_dict = read_csv(meta_csv)
78
+
79
+ audio_names = names_dict[split]
80
+
81
+ params = []
82
+
83
+ for audio_index, audio_name in enumerate(audio_names):
84
+
85
+ audio_path = os.path.join(audios_dir, audio_name)
86
+
87
+ hdf5_path = os.path.join(
88
+ hdf5s_dir, "{}.h5".format(pathlib.Path(audio_name).stem)
89
+ )
90
+
91
+ param = (
92
+ audio_index,
93
+ audio_name,
94
+ source_type,
95
+ audio_path,
96
+ mono,
97
+ sample_rate,
98
+ hdf5_path,
99
+ )
100
+ params.append(param)
101
+
102
+ # Uncomment for debug.
103
+ # write_single_audio_to_hdf5(params[0])
104
+ # os._exit(0)
105
+
106
+ pack_hdf5s_time = time.time()
107
+
108
+ with ProcessPoolExecutor(max_workers=None) as pool:
109
+ # Use the maximum number of workers on the machine.
110
+ pool.map(write_single_audio_to_hdf5, params)
111
+
112
+ print("Pack hdf5 time: {:.3f} s".format(time.time() - pack_hdf5s_time))
113
+
114
+
115
+ def write_single_audio_to_hdf5(param: List) -> NoReturn:
116
+ r"""Write single audio into hdf5 file."""
117
+
118
+ (
119
+ audio_index,
120
+ audio_name,
121
+ source_type,
122
+ audio_path,
123
+ mono,
124
+ sample_rate,
125
+ hdf5_path,
126
+ ) = param
127
+
128
+ with h5py.File(hdf5_path, "w") as hf:
129
+
130
+ hf.attrs.create("audio_name", data=audio_name, dtype="S100")
131
+ hf.attrs.create("sample_rate", data=sample_rate, dtype=np.int32)
132
+
133
+ audio = load_audio(audio_path=audio_path, mono=mono, sample_rate=sample_rate)
134
+ # audio: (channels_num, audio_samples)
135
+
136
+ hf.create_dataset(
137
+ name=source_type, data=float32_to_int16(audio), dtype=np.int16
138
+ )
139
+
140
+ print('{} Write hdf5 to {}'.format(audio_index, hdf5_path))
141
+
142
+
143
+ if __name__ == "__main__":
144
+ parser = argparse.ArgumentParser()
145
+
146
+ parser.add_argument(
147
+ "--dataset_dir",
148
+ type=str,
149
+ required=True,
150
+ help="Directory of the instruments solo dataset.",
151
+ )
152
+ parser.add_argument("--split", type=str, required=True, choices=["train", "test"])
153
+ parser.add_argument(
154
+ "--source_type",
155
+ type=str,
156
+ required=True,
157
+ )
158
+ parser.add_argument(
159
+ "--hdf5s_dir",
160
+ type=str,
161
+ required=True,
162
+ help="Directory to write out hdf5 files.",
163
+ )
164
+ parser.add_argument("--sample_rate", type=int, required=True, help="Sample rate.")
165
+ parser.add_argument(
166
+ "--channels", type=int, required=True, help="Use 1 for mono, 2 for stereo."
167
+ )
168
+
169
+ # Parse arguments.
170
+ args = parser.parse_args()
171
+
172
+ # Pack audios to hdf5 files.
173
+ pack_audios_to_hdf5s(args)
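
A small sketch of the validation.csv layout that read_csv() above assumes, i.e. at least an 'audio_name' and a 'split' column; the rows are made up:

import pandas as pd

from bytesep.dataset_creation.pack_audios_to_hdf5s.instruments_solo import read_csv

# Made-up metadata with the two columns that read_csv relies on.
pd.DataFrame(
    {
        'audio_name': ['songA.wav', 'songB.wav', 'songE.wav'],
        'split': ['train', 'train', 'test'],
    }
).to_csv('validation.csv', index=False)

names_dict = read_csv('validation.csv')
print(names_dict['train'])  # ['songA.mp3', 'songB.mp3'] -- stems are re-suffixed with .mp3
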
bytesep/dataset_creation/pack_audios_to_hdf5s/maestro.py ADDED
@@ -0,0 +1,136 @@
1
+ import argparse
2
+ import os
3
+ import pathlib
4
+ import time
5
+ from concurrent.futures import ProcessPoolExecutor
6
+ from typing import Dict, NoReturn
7
+
8
+ import pandas as pd
9
+
10
+ from bytesep.dataset_creation.pack_audios_to_hdf5s.instruments_solo import (
11
+ write_single_audio_to_hdf5,
12
+ )
13
+
14
+
15
+ def read_csv(meta_csv) -> Dict:
16
+ r"""Get train & test names from csv.
17
+
18
+ Args:
19
+ meta_csv: str
20
+
21
+ Returns:
22
+ names_dict: dict, e.g., {
23
+ 'train': ['a1.mp3', 'a2.mp3'],
24
+ 'test': ['b1.mp3', 'b2.mp3']
25
+ }
26
+ """
27
+ df = pd.read_csv(meta_csv, sep=',')
28
+
29
+ names_dict = {}
30
+
31
+ for split in ['train', 'test']:
32
+ audio_indexes = df['split'] == split
33
+ audio_names = list(df['audio_filename'][audio_indexes])
34
+ names_dict[split] = audio_names
35
+
36
+ return names_dict
37
+
38
+
39
+ def pack_audios_to_hdf5s(args) -> NoReturn:
40
+ r"""Pack (resampled) audio files into hdf5 files to speed up loading.
41
+
42
+ Args:
43
+ dataset_dir: str
44
+ split: str, 'train' | 'test'
45
+ hdf5s_dir: str, directory to write out hdf5 files
46
+ sample_rate: int
47
+ channels_num: int
48
+ mono: bool
49
+
50
+ Returns:
51
+ NoReturn
52
+ """
53
+
54
+ # arguments & parameters
55
+ dataset_dir = args.dataset_dir
56
+ split = args.split
57
+ hdf5s_dir = args.hdf5s_dir
58
+ sample_rate = args.sample_rate
59
+ channels = args.channels
60
+ mono = True if channels == 1 else False
61
+
62
+ source_type = "piano"
63
+
64
+ # Only pack data for training data.
65
+ assert split == "train"
66
+
67
+ # paths
68
+ meta_csv = os.path.join(dataset_dir, 'maestro-v2.0.0.csv')
69
+
70
+ os.makedirs(hdf5s_dir, exist_ok=True)
71
+
72
+ # Read train & test names.
73
+ names_dict = read_csv(meta_csv)
74
+
75
+ audio_names = names_dict['{}'.format(split)]
76
+
77
+ params = []
78
+
79
+ for audio_index, audio_name in enumerate(audio_names):
80
+
81
+ audio_path = os.path.join(dataset_dir, audio_name)
82
+
83
+ hdf5_path = os.path.join(
84
+ hdf5s_dir, "{}.h5".format(pathlib.Path(audio_name).stem)
85
+ )
86
+
87
+ param = (
88
+ audio_index,
89
+ audio_name,
90
+ source_type,
91
+ audio_path,
92
+ mono,
93
+ sample_rate,
94
+ hdf5_path,
95
+ )
96
+ params.append(param)
97
+
98
+ # Uncomment for debug.
99
+ # write_single_audio_to_hdf5(params[0])
100
+ # os._exit(0)
101
+
102
+ pack_hdf5s_time = time.time()
103
+
104
+ with ProcessPoolExecutor(max_workers=None) as pool:
105
+ # Use the maximum number of workers on the machine.
106
+ pool.map(write_single_audio_to_hdf5, params)
107
+
108
+ print("Pack hdf5 time: {:.3f} s".format(time.time() - pack_hdf5s_time))
109
+
110
+
111
+ if __name__ == "__main__":
112
+ parser = argparse.ArgumentParser()
113
+
114
+ parser.add_argument(
115
+ "--dataset_dir",
116
+ type=str,
117
+ required=True,
118
+ help="Directory of the MAESTRO dataset.",
119
+ )
120
+ parser.add_argument("--split", type=str, required=True, choices=["train", "test"])
121
+ parser.add_argument(
122
+ "--hdf5s_dir",
123
+ type=str,
124
+ required=True,
125
+ help="Directory to write out hdf5 files.",
126
+ )
127
+ parser.add_argument("--sample_rate", type=int, required=True, help="Sample rate.")
128
+ parser.add_argument(
129
+ "--channels", type=int, required=True, help="Use 1 for mono, 2 for stereo."
130
+ )
131
+
132
+ # Parse arguments.
133
+ args = parser.parse_args()
134
+
135
+ # Pack audios to hdf5 files.
136
+ pack_audios_to_hdf5s(args)
bytesep/dataset_creation/pack_audios_to_hdf5s/musdb18.py ADDED
@@ -0,0 +1,207 @@
1
+ import argparse
2
+ import os
3
+ import time
4
+ from concurrent.futures import ProcessPoolExecutor
5
+ from typing import NoReturn
6
+
7
+ import h5py
8
+ import librosa
9
+ import musdb
10
+ import numpy as np
11
+
12
+ from bytesep.utils import float32_to_int16
13
+
14
+ # Source types of the MUSDB18 dataset.
15
+ SOURCE_TYPES = ["vocals", "drums", "bass", "other", "accompaniment"]
16
+
17
+
18
+ def pack_audios_to_hdf5s(args) -> NoReturn:
19
+ r"""Pack (resampled) audio files into hdf5 files to speed up loading.
20
+
21
+ Args:
22
+ dataset_dir: str
23
+ subset: str, 'train' | 'test'
24
+ split: str, '' | 'train' | 'valid'
25
+ hdf5s_dir: str, directory to write out hdf5 files
26
+ sample_rate: int
27
+ channels_num: int
28
+ mono: bool
29
+
30
+ Returns:
31
+ NoReturn
32
+ """
33
+
34
+ # arguments & parameters
35
+ dataset_dir = args.dataset_dir
36
+ subset = args.subset
37
+ split = None if args.split == "" else args.split
38
+ hdf5s_dir = args.hdf5s_dir
39
+ sample_rate = args.sample_rate
40
+ channels = args.channels
41
+
42
+ mono = True if channels == 1 else False
43
+ source_types = SOURCE_TYPES
44
+ resample_type = "kaiser_fast"
45
+
46
+ # Paths
47
+ os.makedirs(hdf5s_dir, exist_ok=True)
48
+
49
+ # Dataset of corresponding subset and split.
50
+ mus = musdb.DB(root=dataset_dir, subsets=[subset], split=split)
51
+ print("Subset: {}, Split: {}, Total pieces: {}".format(subset, split, len(mus)))
52
+
53
+ params = [] # A list of params for multiple processing.
54
+
55
+ for track_index in range(len(mus.tracks)):
56
+
57
+ param = (
58
+ dataset_dir,
59
+ subset,
60
+ split,
61
+ track_index,
62
+ source_types,
63
+ mono,
64
+ sample_rate,
65
+ resample_type,
66
+ hdf5s_dir,
67
+ )
68
+
69
+ params.append(param)
70
+
71
+ # Uncomment for debug.
72
+ # write_single_audio_to_hdf5(params[0])
73
+ # os._exit(0)
74
+
75
+ pack_hdf5s_time = time.time()
76
+
77
+ with ProcessPoolExecutor(max_workers=None) as pool:
78
+ # Use the maximum number of workers on the machine.
79
+ pool.map(write_single_audio_to_hdf5, params)
80
+
81
+ print("Pack hdf5 time: {:.3f} s".format(time.time() - pack_hdf5s_time))
82
+
83
+
84
+ def write_single_audio_to_hdf5(param) -> NoReturn:
85
+ r"""Write single audio into hdf5 file."""
86
+ (
87
+ dataset_dir,
88
+ subset,
89
+ split,
90
+ track_index,
91
+ source_types,
92
+ mono,
93
+ sample_rate,
94
+ resample_type,
95
+ hdf5s_dir,
96
+ ) = param
97
+
98
+ # Dataset of corresponding subset and split.
99
+ mus = musdb.DB(root=dataset_dir, subsets=[subset], split=split)
100
+ track = mus.tracks[track_index]
101
+
102
+ # Path to write out hdf5 file.
103
+ hdf5_path = os.path.join(hdf5s_dir, "{}.h5".format(track.name))
104
+
105
+ with h5py.File(hdf5_path, "w") as hf:
106
+
107
+ hf.attrs.create("audio_name", data=track.name.encode(), dtype="S100")
108
+ hf.attrs.create("sample_rate", data=sample_rate, dtype=np.int32)
109
+
110
+ for source_type in source_types:
111
+
112
+ audio = track.targets[source_type].audio.T
113
+ # (channels_num, audio_samples)
114
+
115
+ # Preprocess audio to mono / stereo, and resample.
116
+ audio = preprocess_audio(
117
+ audio, mono, track.rate, sample_rate, resample_type
118
+ )
119
+ # audio = load_audio(audio_path=audio_path, mono=mono, sample_rate=sample_rate)
120
+ # (channels_num, audio_samples) | (audio_samples,)
121
+
122
+ hf.create_dataset(
123
+ name=source_type, data=float32_to_int16(audio), dtype=np.int16
124
+ )
125
+
126
+ # Mixture
127
+ audio = track.audio.T
128
+ # (channels_num, audio_samples)
129
+
130
+ # Preprocess audio to mono / stereo, and resample.
131
+ audio = preprocess_audio(audio, mono, track.rate, sample_rate, resample_type)
132
+ # (channels_num, audio_samples)
133
+
134
+ hf.create_dataset(name="mixture", data=float32_to_int16(audio), dtype=np.int16)
135
+
136
+ print("{} Write to {}, {}".format(track_index, hdf5_path, audio.shape))
137
+
138
+
139
+ def preprocess_audio(audio, mono, origin_sr, sr, resample_type) -> np.array:
140
+ r"""Preprocess audio to mono / stereo, and resample.
141
+
142
+ Args:
143
+ audio: (channels_num, audio_samples), input audio
144
+ mono: bool
145
+ origin_sr: float, original sample rate
146
+ sr: float, target sample rate
147
+ resample_type: str, e.g., 'kaiser_fast'
148
+
149
+ Returns:
150
+ output: ndarray, output audio
151
+ """
152
+ if mono:
153
+ audio = np.mean(audio, axis=0)
154
+ # (audio_samples,)
155
+
156
+ output = librosa.core.resample(
157
+ audio, orig_sr=origin_sr, target_sr=sr, res_type=resample_type
158
+ )
159
+ # (audio_samples,) | (channels_num, audio_samples)
160
+
161
+ if output.ndim == 1:
162
+ output = output[None, :]
163
+ # (1, audio_samples,)
164
+
165
+ return output
166
+
167
+
168
+ if __name__ == "__main__":
169
+ parser = argparse.ArgumentParser()
170
+
171
+ parser.add_argument(
172
+ "--dataset_dir",
173
+ type=str,
174
+ required=True,
175
+ help="Directory of the MUSDB18 dataset.",
176
+ )
177
+ parser.add_argument(
178
+ "--subset",
179
+ type=str,
180
+ required=True,
181
+ choices=["train", "test"],
182
+ help="Train subset: 100 pieces; test subset: 50 pieces.",
183
+ )
184
+ parser.add_argument(
185
+ "--split",
186
+ type=str,
187
+ required=True,
188
+ choices=["", "train", "valid"],
189
+ help="Use '' to use all 100 pieces to train. Use 'train' to use 86 \
190
+ pieces for training, and use 'valid' to use 14 pieces for validation.",
191
+ )
192
+ parser.add_argument(
193
+ "--hdf5s_dir",
194
+ type=str,
195
+ required=True,
196
+ help="Directory to write out hdf5 files.",
197
+ )
198
+ parser.add_argument("--sample_rate", type=int, required=True, help="Sample rate.")
199
+ parser.add_argument(
200
+ "--channels", type=int, required=True, help="Use 1 for mono, 2 for stereo."
201
+ )
202
+
203
+ # Parse arguments.
204
+ args = parser.parse_args()
205
+
206
+ # Pack audios into hdf5 files.
207
+ pack_audios_to_hdf5s(args)
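
A quick sketch of preprocess_audio() on synthetic stereo audio (one second of random samples); it downmixes when mono=True, resamples, and always returns a (channels_num, audio_samples) array:

import numpy as np

from bytesep.dataset_creation.pack_audios_to_hdf5s.musdb18 import preprocess_audio

# One second of synthetic stereo audio at 44.1 kHz.
stereo = np.random.RandomState(0).uniform(-1.0, 1.0, (2, 44100)).astype(np.float32)

mono_16k = preprocess_audio(stereo, mono=True, origin_sr=44100, sr=16000, resample_type='kaiser_fast')
print(mono_16k.shape)    # (1, 16000): downmixed to mono and resampled

stereo_16k = preprocess_audio(stereo, mono=False, origin_sr=44100, sr=16000, resample_type='kaiser_fast')
print(stereo_16k.shape)  # (2, 16000): multi-channel resampling needs librosa >= 0.9
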
bytesep/dataset_creation/pack_audios_to_hdf5s/vctk.py ADDED
@@ -0,0 +1,114 @@
1
+ import argparse
2
+ import os
3
+ import pathlib
4
+ import time
5
+ from concurrent.futures import ProcessPoolExecutor
6
+ from typing import NoReturn
7
+
8
+ from bytesep.dataset_creation.pack_audios_to_hdf5s.instruments_solo import (
9
+ write_single_audio_to_hdf5,
10
+ )
11
+
12
+
13
+ def pack_audios_to_hdf5s(args) -> NoReturn:
14
+ r"""Pack (resampled) audio files into hdf5 files to speed up loading.
15
+
16
+ Args:
17
+ dataset_dir: str
18
+ split: str, 'train' | 'test'
19
+ hdf5s_dir: str, directory to write out hdf5 files
20
+ sample_rate: int
21
+ channels_num: int
22
+ mono: bool
23
+
24
+ Returns:
25
+ NoReturn
26
+ """
27
+
28
+ # arguments & parameters
29
+ dataset_dir = args.dataset_dir
30
+ split = args.split
31
+ hdf5s_dir = args.hdf5s_dir
32
+ sample_rate = args.sample_rate
33
+ channels = args.channels
34
+ mono = True if channels == 1 else False
35
+
36
+ source_type = "speech"
37
+
38
+ # Only pack data for training data.
39
+ assert split == "train"
40
+
41
+ audios_dir = os.path.join(dataset_dir, 'wav48', split)
42
+ os.makedirs(hdf5s_dir, exist_ok=True)
43
+
44
+ speaker_ids = sorted(os.listdir(audios_dir))
45
+
46
+ params = []
47
+ audio_index = 0
48
+
49
+ for speaker_id in speaker_ids:
50
+
51
+ speaker_audios_dir = os.path.join(audios_dir, speaker_id)
52
+
53
+ audio_names = sorted(os.listdir(speaker_audios_dir))
54
+
55
+ for audio_name in audio_names:
56
+
57
+ audio_path = os.path.join(speaker_audios_dir, audio_name)
58
+
59
+ hdf5_path = os.path.join(
60
+ hdf5s_dir, "{}.h5".format(pathlib.Path(audio_name).stem)
61
+ )
62
+
63
+ param = (
64
+ audio_index,
65
+ audio_name,
66
+ source_type,
67
+ audio_path,
68
+ mono,
69
+ sample_rate,
70
+ hdf5_path,
71
+ )
72
+ params.append(param)
73
+
74
+ audio_index += 1
75
+
76
+ # Uncomment for debug.
77
+ # write_single_audio_to_hdf5(params[0])
78
+ # os._exit(0)
79
+
80
+ pack_hdf5s_time = time.time()
81
+
82
+ with ProcessPoolExecutor(max_workers=None) as pool:
83
+ # Use the maximum number of workers on the machine.
84
+ pool.map(write_single_audio_to_hdf5, params)
85
+
86
+ print("Pack hdf5 time: {:.3f} s".format(time.time() - pack_hdf5s_time))
87
+
88
+
89
+ if __name__ == "__main__":
90
+ parser = argparse.ArgumentParser()
91
+
92
+ parser.add_argument(
93
+ "--dataset_dir",
94
+ type=str,
95
+ required=True,
96
+ help="Directory of the VCTK dataset.",
97
+ )
98
+ parser.add_argument("--split", type=str, required=True, choices=["train", "test"])
99
+ parser.add_argument(
100
+ "--hdf5s_dir",
101
+ type=str,
102
+ required=True,
103
+ help="Directory to write out hdf5 files.",
104
+ )
105
+ parser.add_argument("--sample_rate", type=int, required=True, help="Sample rate.")
106
+ parser.add_argument(
107
+ "--channels", type=int, required=True, help="Use 1 for mono, 2 for stereo."
108
+ )
109
+
110
+ # Parse arguments.
111
+ args = parser.parse_args()
112
+
113
+ # Pack audios into hdf5 files.
114
+ pack_audios_to_hdf5s(args)
bytesep/dataset_creation/pack_audios_to_hdf5s/voicebank-demand.py ADDED
@@ -0,0 +1,143 @@
1
+ import argparse
2
+ import os
3
+ import pathlib
4
+ import time
5
+ from concurrent.futures import ProcessPoolExecutor
6
+ from typing import List, NoReturn
7
+
8
+ import h5py
9
+ import numpy as np
10
+
11
+ from bytesep.utils import float32_to_int16, load_audio
12
+
13
+
14
+ def pack_audios_to_hdf5s(args) -> NoReturn:
15
+ r"""Pack (resampled) audio files into hdf5 files to speed up loading.
16
+
17
+ Args:
18
+ dataset_dir: str
19
+ split: str, 'train' | 'test'
20
+ hdf5s_dir: str, directory to write out hdf5 files
21
+ sample_rate: int
22
+ channels_num: int
23
+ mono: bool
24
+
25
+ Returns:
26
+ NoReturn
27
+ """
28
+
29
+ # arguments & parameters
30
+ dataset_dir = args.dataset_dir
31
+ split = args.split
32
+ hdf5s_dir = args.hdf5s_dir
33
+ sample_rate = args.sample_rate
34
+ channels = args.channels
35
+ mono = True if channels == 1 else False
36
+
37
+ # Only pack data for training data.
38
+ assert split == "train"
39
+
40
+ speech_dir = os.path.join(dataset_dir, "clean_{}set_wav".format(split))
41
+ mixture_dir = os.path.join(dataset_dir, "noisy_{}set_wav".format(split))
42
+
43
+ os.makedirs(hdf5s_dir, exist_ok=True)
44
+
45
+ # Read names.
46
+ audio_names = sorted(os.listdir(speech_dir))
47
+
48
+ params = []
49
+
50
+ for audio_index, audio_name in enumerate(audio_names):
51
+
52
+ speech_path = os.path.join(speech_dir, audio_name)
53
+ mixture_path = os.path.join(mixture_dir, audio_name)
54
+
55
+ hdf5_path = os.path.join(
56
+ hdf5s_dir, "{}.h5".format(pathlib.Path(audio_name).stem)
57
+ )
58
+
59
+ param = (
60
+ audio_index,
61
+ audio_name,
62
+ speech_path,
63
+ mixture_path,
64
+ mono,
65
+ sample_rate,
66
+ hdf5_path,
67
+ )
68
+ params.append(param)
69
+
70
+ # Uncomment for debug.
71
+ # write_single_audio_to_hdf5(params[0])
72
+ # os._exit(0)
73
+
74
+ pack_hdf5s_time = time.time()
75
+
76
+ with ProcessPoolExecutor(max_workers=None) as pool:
77
+ # Use the maximum number of workers on the machine.
78
+ pool.map(write_single_audio_to_hdf5, params)
79
+
80
+ print("Pack hdf5 time: {:.3f} s".format(time.time() - pack_hdf5s_time))
81
+
82
+
83
+ def write_single_audio_to_hdf5(param: List) -> NoReturn:
84
+ r"""Write single audio into hdf5 file."""
85
+
86
+ (
87
+ audio_index,
88
+ audio_name,
89
+ speech_path,
90
+ mixture_path,
91
+ mono,
92
+ sample_rate,
93
+ hdf5_path,
94
+ ) = param
95
+
96
+ with h5py.File(hdf5_path, "w") as hf:
97
+
98
+ hf.attrs.create("audio_name", data=audio_name, dtype="S100")
99
+ hf.attrs.create("sample_rate", data=sample_rate, dtype=np.int32)
100
+
101
+ speech = load_audio(audio_path=speech_path, mono=mono, sample_rate=sample_rate)
102
+ # speech: (channels_num, audio_samples)
103
+
104
+ mixture = load_audio(
105
+ audio_path=mixture_path, mono=mono, sample_rate=sample_rate
106
+ )
107
+ # mixture: (channels_num, audio_samples)
108
+
109
+ noise = mixture - speech
110
+ # noise: (channels_num, audio_samples)
111
+
112
+ hf.create_dataset(name='speech', data=float32_to_int16(speech), dtype=np.int16)
113
+ hf.create_dataset(name='noise', data=float32_to_int16(noise), dtype=np.int16)
114
+
115
+ print('{} Write hdf5 to {}'.format(audio_index, hdf5_path))
116
+
117
+
118
+ if __name__ == "__main__":
119
+ parser = argparse.ArgumentParser()
120
+
121
+ parser.add_argument(
122
+ "--dataset_dir",
123
+ type=str,
124
+ required=True,
125
+ help="Directory of the Voicebank-Demand dataset.",
126
+ )
127
+ parser.add_argument("--split", type=str, required=True, choices=["train", "test"])
128
+ parser.add_argument(
129
+ "--hdf5s_dir",
130
+ type=str,
131
+ required=True,
132
+ help="Directory to write out hdf5 files.",
133
+ )
134
+ parser.add_argument("--sample_rate", type=int, required=True, help="Sample rate.")
135
+ parser.add_argument(
136
+ "--channels", type=int, required=True, help="Use 1 for mono, 2 for stereo."
137
+ )
138
+
139
+ # Parse arguments.
140
+ args = parser.parse_args()
141
+
142
+ # Pack audios into hdf5 files.
143
+ pack_audios_to_hdf5s(args)
bytesep/inference.py ADDED
@@ -0,0 +1,404 @@
1
+ import sys
2
+ sys.path.append('.')
3
+ import argparse
4
+ import os
5
+ import time
6
+ from typing import Dict
7
+ import pathlib
8
+
9
+ import librosa
10
+ import numpy as np
11
+ import soundfile
12
+ import torch
13
+ import torch.nn as nn
14
+
15
+ from bytesep.models.lightning_modules import get_model_class
16
+ from bytesep.utils import read_yaml
17
+
18
+
19
+ class Separator:
20
+ def __init__(
21
+ self, model: nn.Module, segment_samples: int, batch_size: int, device: str
22
+ ):
23
+ r"""Separator for separating an audio clip into a target source.
24
+
25
+ Args:
26
+ model: nn.Module, trained model
27
+ segment_samples: int, length of segments to be input to a model, e.g., 44100*30
28
+ batch_size: int, e.g., 12
29
+ device: str, e.g., 'cuda'
30
+ """
31
+ self.model = model
32
+ self.segment_samples = segment_samples
33
+ self.batch_size = batch_size
34
+ self.device = device
35
+
36
+ def separate(self, input_dict: Dict) -> np.array:
37
+ r"""Separate an audio clip into a target source.
38
+
39
+ Args:
40
+ input_dict: dict, e.g., {
41
+ waveform: (channels_num, audio_samples),
42
+ ...,
43
+ }
44
+
45
+ Returns:
46
+ sep_audio: (channels_num, audio_samples) | (target_sources_num, channels_num, audio_samples)
47
+ """
48
+ audio = input_dict['waveform']
49
+
50
+ audio_samples = audio.shape[-1]
51
+
52
+ # Pad the audio with zero in the end so that the length of audio can be
53
+ # evenly divided by segment_samples.
54
+ audio = self.pad_audio(audio)
55
+
56
+ # Enframe long audio into segments.
57
+ segments = self.enframe(audio, self.segment_samples)
58
+ # (segments_num, channels_num, segment_samples)
59
+
60
+ segments_input_dict = {'waveform': segments}
61
+
62
+ if 'condition' in input_dict.keys():
63
+ segments_num = len(segments)
64
+ segments_input_dict['condition'] = np.tile(
65
+ input_dict['condition'][None, :], (segments_num, 1)
66
+ )
67
+ # segments_input_dict['condition']: (segments_num, condition_size)
68
+
69
+ # Separate in mini-batches.
70
+ sep_segments = self._forward_in_mini_batches(
71
+ self.model, segments_input_dict, self.batch_size
72
+ )['waveform']
73
+ # (segments_num, channels_num, segment_samples)
74
+
75
+ # Deframe segments into long audio.
76
+ sep_audio = self.deframe(sep_segments)
77
+ # (channels_num, padded_audio_samples)
78
+
79
+ sep_audio = sep_audio[:, 0:audio_samples]
80
+ # (channels_num, audio_samples)
81
+
82
+ return sep_audio
83
+
84
+ def pad_audio(self, audio: np.array) -> np.array:
85
+ r"""Pad the audio with zero in the end so that the length of audio can
86
+ be evenly divided by segment_samples.
87
+
88
+ Args:
89
+ audio: (channels_num, audio_samples)
90
+
91
+ Returns:
92
+ padded_audio: (channels_num, audio_samples)
93
+ """
94
+ channels_num, audio_samples = audio.shape
95
+
96
+ # Number of segments
97
+ segments_num = int(np.ceil(audio_samples / self.segment_samples))
98
+
99
+ pad_samples = segments_num * self.segment_samples - audio_samples
100
+
101
+ padded_audio = np.concatenate(
102
+ (audio, np.zeros((channels_num, pad_samples))), axis=1
103
+ )
104
+ # (channels_num, padded_audio_samples)
105
+
106
+ return padded_audio
107
+
108
+ def enframe(self, audio: np.array, segment_samples: int) -> np.array:
109
+ r"""Enframe long audio into segments.
110
+
111
+ Args:
112
+ audio: (channels_num, audio_samples)
113
+ segment_samples: int
114
+
115
+ Returns:
116
+ segments: (segments_num, channels_num, segment_samples)
117
+ """
118
+ audio_samples = audio.shape[1]
119
+ assert audio_samples % segment_samples == 0
120
+
121
+ hop_samples = segment_samples // 2
122
+ segments = []
123
+
124
+ pointer = 0
125
+ while pointer + segment_samples <= audio_samples:
126
+ segments.append(audio[:, pointer : pointer + segment_samples])
127
+ pointer += hop_samples
128
+
129
+ segments = np.array(segments)
130
+
131
+ return segments
132
+
133
+ def deframe(self, segments: np.array) -> np.array:
134
+ r"""Deframe segments into long audio.
135
+
136
+ Args:
137
+ segments: (segments_num, channels_num, segment_samples)
138
+
139
+ Returns:
140
+ output: (channels_num, audio_samples)
141
+ """
142
+ (segments_num, _, segment_samples) = segments.shape
143
+
144
+ if segments_num == 1:
145
+ return segments[0]
146
+
147
+ assert self._is_integer(segment_samples * 0.25)
148
+ assert self._is_integer(segment_samples * 0.75)
149
+
150
+ output = []
151
+
152
+ output.append(segments[0, :, 0 : int(segment_samples * 0.75)])
153
+
154
+ for i in range(1, segments_num - 1):
155
+ output.append(
156
+ segments[
157
+ i, :, int(segment_samples * 0.25) : int(segment_samples * 0.75)
158
+ ]
159
+ )
160
+
161
+ output.append(segments[-1, :, int(segment_samples * 0.25) :])
162
+
163
+ output = np.concatenate(output, axis=-1)
164
+
165
+ return output
166
+
167
+ def _is_integer(self, x: float) -> bool:
168
+ if x - int(x) < 1e-10:
169
+ return True
170
+ else:
171
+ return False
172
+
173
+ def _forward_in_mini_batches(
174
+ self, model: nn.Module, segments_input_dict: Dict, batch_size: int
175
+ ) -> Dict:
176
+ r"""Forward data to model in mini-batch.
177
+
178
+ Args:
179
+ model: nn.Module
180
+ segments_input_dict: dict, e.g., {
181
+ 'waveform': (segments_num, channels_num, segment_samples),
182
+ ...,
183
+ }
184
+ batch_size: int
185
+
186
+ Returns:
187
+ output_dict: dict, e.g. {
188
+ 'waveform': (segments_num, channels_num, segment_samples),
189
+ }
190
+ """
191
+ output_dict = {}
192
+
193
+ pointer = 0
194
+ segments_num = len(segments_input_dict['waveform'])
195
+
196
+ while True:
197
+ if pointer >= segments_num:
198
+ break
199
+
200
+ batch_input_dict = {}
201
+
202
+ for key in segments_input_dict.keys():
203
+ batch_input_dict[key] = torch.Tensor(
204
+ segments_input_dict[key][pointer : pointer + batch_size]
205
+ ).to(self.device)
206
+
207
+ pointer += batch_size
208
+
209
+ with torch.no_grad():
210
+ model.eval()
211
+ batch_output_dict = model(batch_input_dict)
212
+
213
+ for key in batch_output_dict.keys():
214
+ self._append_to_dict(
215
+ output_dict, key, batch_output_dict[key].data.cpu().numpy()
216
+ )
217
+
218
+ for key in output_dict.keys():
219
+ output_dict[key] = np.concatenate(output_dict[key], axis=0)
220
+
221
+ return output_dict
222
+
223
+ def _append_to_dict(self, dict, key, value):
224
+ if key in dict.keys():
225
+ dict[key].append(value)
226
+ else:
227
+ dict[key] = [value]
228
+
229
+
230
+ class SeparatorWrapper:
231
+ def __init__(
232
+ self, source_type='vocals', model=None, checkpoint_path=None, device='cuda'
233
+ ):
234
+
235
+ input_channels = 2
236
+ target_sources_num = 1
237
+ model_type = "ResUNet143_Subbandtime"
238
+ segment_samples = 44100 * 10
239
+ batch_size = 1
240
+
241
+ self.checkpoint_path = self.download_checkpoints(checkpoint_path, source_type)
242
+
243
+ if device == 'cuda' and torch.cuda.is_available():
244
+ self.device = 'cuda'
245
+ else:
246
+ self.device = 'cpu'
247
+
248
+ # Get model class.
249
+ Model = get_model_class(model_type)
250
+
251
+ # Create model.
252
+ self.model = Model(
253
+ input_channels=input_channels, target_sources_num=target_sources_num
254
+ )
255
+
256
+ # Load checkpoint.
257
+ checkpoint = torch.load(self.checkpoint_path, map_location='cpu')
258
+ self.model.load_state_dict(checkpoint["model"])
259
+
260
+ # Move model to device.
261
+ self.model.to(self.device)
262
+
263
+ # Create separator.
264
+ self.separator = Separator(
265
+ model=self.model,
266
+ segment_samples=segment_samples,
267
+ batch_size=batch_size,
268
+ device=self.device,
269
+ )
270
+
271
+ def download_checkpoints(self, checkpoint_path, source_type):
272
+
273
+ if source_type == "vocals":
274
+ checkpoint_bare_name = "resunet143_subbtandtime_vocals_8.8dB_350k_steps"
275
+
276
+ elif source_type == "accompaniment":
277
+ checkpoint_bare_name = (
278
+ "resunet143_subbtandtime_accompaniment_16.4dB_350k_steps.pth"
279
+ )
280
+
281
+ else:
282
+ raise NotImplementedError
283
+
284
+ if not checkpoint_path:
285
+ checkpoint_path = '{}/bytesep_data/{}.pth'.format(
286
+ str(pathlib.Path.home()), checkpoint_bare_name
287
+ )
288
+
289
+ print('Checkpoint path: {}'.format(checkpoint_path))
290
+
291
+ if (
292
+ not os.path.exists(checkpoint_path)
293
+ or os.path.getsize(checkpoint_path) < 4e8
294
+ ):
295
+
296
+ os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
297
+
298
+ zenodo_dir = "https://zenodo.org/record/5507029/files"
299
+ zenodo_path = os.path.join(
300
+ zenodo_dir, "{}?download=1".format(checkpoint_bare_name)
301
+ )
302
+
303
+ os.system('wget -O "{}" "{}"'.format(checkpoint_path, zenodo_path))
304
+
305
+ return checkpoint_path
306
+
307
+ def separate(self, audio):
308
+
309
+ input_dict = {'waveform': audio}
310
+
311
+ sep_wav = self.separator.separate(input_dict)
312
+
313
+ return sep_wav
314
+
315
+
316
+ def inference(args):
317
+
318
+ # Need to use torch.distributed if models contain inplace_abn.abn.InPlaceABNSync.
319
+ import torch.distributed as dist
320
+
321
+ dist.init_process_group(
322
+ 'gloo', init_method='file:///tmp/somefile', rank=0, world_size=1
323
+ )
324
+
325
+ # Arguments & parameters
326
+ config_yaml = args.config_yaml
327
+ checkpoint_path = args.checkpoint_path
328
+ audio_path = args.audio_path
329
+ output_path = args.output_path
330
+ device = (
331
+ torch.device('cuda')
332
+ if args.cuda and torch.cuda.is_available()
333
+ else torch.device('cpu')
334
+ )
335
+
336
+ configs = read_yaml(config_yaml)
337
+ sample_rate = configs['train']['sample_rate']
338
+ input_channels = configs['train']['channels']
339
+ target_source_types = configs['train']['target_source_types']
340
+ target_sources_num = len(target_source_types)
341
+ model_type = configs['train']['model_type']
342
+
343
+ segment_samples = int(30 * sample_rate)
344
+ batch_size = 1
345
+
346
+ print("Using {} for separating ..".format(device))
347
+
348
+ # paths
349
+ if os.path.dirname(output_path) != "":
350
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
351
+
352
+ # Get model class.
353
+ Model = get_model_class(model_type)
354
+
355
+ # Create model.
356
+ model = Model(input_channels=input_channels, target_sources_num=target_sources_num)
357
+
358
+ # Load checkpoint.
359
+ checkpoint = torch.load(checkpoint_path, map_location='cpu')
360
+ model.load_state_dict(checkpoint["model"])
361
+
362
+ # Move model to device.
363
+ model.to(device)
364
+
365
+ # Create separator.
366
+ separator = Separator(
367
+ model=model,
368
+ segment_samples=segment_samples,
369
+ batch_size=batch_size,
370
+ device=device,
371
+ )
372
+
373
+ # Load audio.
374
+ audio, _ = librosa.load(audio_path, sr=sample_rate, mono=False)
375
+
376
+ # audio = audio[None, :]
377
+
378
+ input_dict = {'waveform': audio}
379
+
380
+ # Separate
381
+ separate_time = time.time()
382
+
383
+ sep_wav = separator.separate(input_dict)
384
+ # (channels_num, audio_samples)
385
+
386
+ print('Separate time: {:.3f} s'.format(time.time() - separate_time))
387
+
388
+ # Write out separated audio.
389
+ soundfile.write(file='_zz.wav', data=sep_wav.T, samplerate=sample_rate)
390
+ os.system("ffmpeg -y -loglevel panic -i _zz.wav {}".format(output_path))
391
+ print('Write out to {}'.format(output_path))
392
+
393
+
394
+ if __name__ == "__main__":
395
+
396
+ parser = argparse.ArgumentParser(description="")
397
+ parser.add_argument("--config_yaml", type=str, required=True)
398
+ parser.add_argument("--checkpoint_path", type=str, required=True)
399
+ parser.add_argument("--audio_path", type=str, required=True)
400
+ parser.add_argument("--output_path", type=str, required=True)
401
+ parser.add_argument("--cuda", action='store_true', default=True)
402
+
403
+ args = parser.parse_args()
404
+ inference(args)
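A minimal usage sketch of the SeparatorWrapper defined above, assuming a stereo 44.1 kHz input file (the wrapper is hard-coded to a 2-channel, 44100 Hz subband-time model) and that the checkpoint can be fetched from Zenodo on first use; 'song.wav' is a placeholder path.
import librosa

from bytesep.inference import SeparatorWrapper

# Load a stereo file at 44.1 kHz; audio: (2, audio_samples).
audio, _ = librosa.load('song.wav', sr=44100, mono=False)

# Build the wrapper (downloads the vocals checkpoint if missing) and separate.
separator = SeparatorWrapper(source_type='vocals', device='cuda')
sep_vocals = separator.separate(audio)  # (2, audio_samples)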
bytesep/inference_many.py ADDED
@@ -0,0 +1,163 @@
1
+ import argparse
2
+ import os
3
+ import pathlib
4
+ import time
5
+ from typing import NoReturn
6
+
7
+ import librosa
8
+ import numpy as np
9
+ import soundfile
10
+ import torch
11
+
12
+ from bytesep.inference import Separator
13
+ from bytesep.models.lightning_modules import get_model_class
14
+ from bytesep.utils import read_yaml
15
+
16
+
17
+ def inference(args) -> NoReturn:
18
+ r"""Separate all audios in a directory.
19
+
20
+ Args:
21
+ config_yaml: str, the config file of a model being trained
22
+ checkpoint_path: str, the path of checkpoint to be loaded
23
+ audios_dir: str, the directory of audios to be separated
24
+ output_dir: str, the directory to write out separated audios
25
+ scale_volume: bool, if True then the volume is scaled to the maximum value of 1.
26
+
27
+ Returns:
28
+ NoReturn
29
+ """
30
+
31
+ # Arguments & parameters
32
+ config_yaml = args.config_yaml
33
+ checkpoint_path = args.checkpoint_path
34
+ audios_dir = args.audios_dir
35
+ output_dir = args.output_dir
36
+ scale_volume = args.scale_volume
37
+ device = (
38
+ torch.device('cuda')
39
+ if args.cuda and torch.cuda.is_available()
40
+ else torch.device('cpu')
41
+ )
42
+
43
+ configs = read_yaml(config_yaml)
44
+ sample_rate = configs['train']['sample_rate']
45
+ input_channels = configs['train']['channels']
46
+ target_source_types = configs['train']['target_source_types']
47
+ target_sources_num = len(target_source_types)
48
+ model_type = configs['train']['model_type']
49
+ mono = input_channels == 1
50
+
51
+ segment_samples = int(30 * sample_rate)
52
+ batch_size = 1
53
+ device = "cuda"
54
+
55
+ models_contains_inplaceabn = True
56
+
57
+ # Need to use torch.distributed if models contain inplace_abn.abn.InPlaceABNSync.
58
+ if models_contains_inplaceabn:
59
+
60
+ import torch.distributed as dist
61
+
62
+ dist.init_process_group(
63
+ 'gloo', init_method='file:///tmp/somefile', rank=0, world_size=1
64
+ )
65
+
66
+ print("Using {} for separating ..".format(device))
67
+
68
+ # paths
69
+ os.makedirs(output_dir, exist_ok=True)
70
+
71
+ # Get model class.
72
+ Model = get_model_class(model_type)
73
+
74
+ # Create model.
75
+ model = Model(input_channels=input_channels, target_sources_num=target_sources_num)
76
+
77
+ # Load checkpoint.
78
+ checkpoint = torch.load(checkpoint_path, map_location='cpu')
79
+ model.load_state_dict(checkpoint["model"])
80
+
81
+ # Move model to device.
82
+ model.to(device)
83
+
84
+ # Create separator.
85
+ separator = Separator(
86
+ model=model,
87
+ segment_samples=segment_samples,
88
+ batch_size=batch_size,
89
+ device=device,
90
+ )
91
+
92
+ audio_names = sorted(os.listdir(audios_dir))
93
+
94
+ for audio_name in audio_names:
95
+ audio_path = os.path.join(audios_dir, audio_name)
96
+
97
+ # Load audio.
98
+ audio, _ = librosa.load(audio_path, sr=sample_rate, mono=mono)
99
+
100
+ if audio.ndim == 1:
101
+ audio = audio[None, :]
102
+
103
+ input_dict = {'waveform': audio}
104
+
105
+ # Separate
106
+ separate_time = time.time()
107
+
108
+ sep_wav = separator.separate(input_dict)
109
+ # (channels_num, audio_samples)
110
+
111
+ print('Separate time: {:.3f} s'.format(time.time() - separate_time))
112
+
113
+ # Write out separated audio.
114
+ if scale_volume:
115
+ sep_wav /= np.max(np.abs(sep_wav))
116
+
117
+ soundfile.write(file='_zz.wav', data=sep_wav.T, samplerate=sample_rate)
118
+
119
+ output_path = os.path.join(
120
+ output_dir, '{}.mp3'.format(pathlib.Path(audio_name).stem)
121
+ )
122
+ os.system('ffmpeg -y -loglevel panic -i _zz.wav "{}"'.format(output_path))
123
+ print('Write out to {}'.format(output_path))
124
+
125
+
126
+ if __name__ == "__main__":
127
+
128
+ parser = argparse.ArgumentParser(description="")
129
+ parser.add_argument(
130
+ "--config_yaml",
131
+ type=str,
132
+ required=True,
133
+ help="The config file of a model being trained.",
134
+ )
135
+ parser.add_argument(
136
+ "--checkpoint_path",
137
+ type=str,
138
+ required=True,
139
+ help="The path of checkpoint to be loaded.",
140
+ )
141
+ parser.add_argument(
142
+ "--audios_dir",
143
+ type=str,
144
+ required=True,
145
+ help="The directory of audios to be separated.",
146
+ )
147
+ parser.add_argument(
148
+ "--output_dir",
149
+ type=str,
150
+ required=True,
151
+ help="The directory to write out separated audios.",
152
+ )
153
+ parser.add_argument(
154
+ '--scale_volume',
155
+ action='store_true',
156
+ default=False,
157
+ help="set to True if separated audios are scaled to the maximum value of 1.",
158
+ )
159
+ parser.add_argument("--cuda", action='store_true', default=True)
160
+
161
+ args = parser.parse_args()
162
+
163
+ inference(args)
bytesep/losses.py ADDED
@@ -0,0 +1,106 @@
1
+ import math
2
+ from typing import Callable
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from torchlibrosa.stft import STFT
7
+
8
+ from bytesep.models.pytorch_modules import Base
9
+
10
+
11
+ def l1(output: torch.Tensor, target: torch.Tensor, **kwargs) -> torch.Tensor:
12
+ r"""L1 loss.
13
+
14
+ Args:
15
+ output: torch.Tensor
16
+ target: torch.Tensor
17
+
18
+ Returns:
19
+ loss: torch.float
20
+ """
21
+ return torch.mean(torch.abs(output - target))
22
+
23
+
24
+ def l1_wav(output: torch.Tensor, target: torch.Tensor, **kwargs) -> torch.Tensor:
25
+ r"""L1 loss in the time-domain.
26
+
27
+ Args:
28
+ output: torch.Tensor
29
+ target: torch.Tensor
30
+
31
+ Returns:
32
+ loss: torch.float
33
+ """
34
+ return l1(output, target)
35
+
36
+
37
+ class L1_Wav_L1_Sp(nn.Module, Base):
38
+ def __init__(self):
39
+ r"""L1 loss in the time-domain and L1 loss on the spectrogram."""
40
+ super(L1_Wav_L1_Sp, self).__init__()
41
+
42
+ self.window_size = 2048
43
+ hop_size = 441
44
+ center = True
45
+ pad_mode = "reflect"
46
+ window = "hann"
47
+
48
+ self.stft = STFT(
49
+ n_fft=self.window_size,
50
+ hop_length=hop_size,
51
+ win_length=self.window_size,
52
+ window=window,
53
+ center=center,
54
+ pad_mode=pad_mode,
55
+ freeze_parameters=True,
56
+ )
57
+
58
+ def __call__(
59
+ self, output: torch.Tensor, target: torch.Tensor, **kwargs
60
+ ) -> torch.Tensor:
61
+ r"""L1 loss in the time-domain and on the spectrogram.
62
+
63
+ Args:
64
+ output: torch.Tensor
65
+ target: torch.Tensor
66
+
67
+ Returns:
68
+ loss: torch.float
69
+ """
70
+
71
+ # L1 loss in the time-domain.
72
+ wav_loss = l1_wav(output, target)
73
+
74
+ # L1 loss on the spectrogram.
75
+ sp_loss = l1(
76
+ self.wav_to_spectrogram(output, eps=1e-8),
77
+ self.wav_to_spectrogram(target, eps=1e-8),
78
+ )
79
+
80
+ # sp_loss /= math.sqrt(self.window_size)
81
+ # sp_loss *= 1.
82
+
83
+ # Total loss.
84
+ return wav_loss + sp_loss
85
+
86
+ return sp_loss
87
+
88
+
89
+ def get_loss_function(loss_type: str) -> Callable:
90
+ r"""Get loss function.
91
+
92
+ Args:
93
+ loss_type: str
94
+
95
+ Returns:
96
+ loss function: Callable
97
+ """
98
+
99
+ if loss_type == "l1_wav":
100
+ return l1_wav
101
+
102
+ elif loss_type == "l1_wav_l1_sp":
103
+ return L1_Wav_L1_Sp()
104
+
105
+ else:
106
+ raise NotImplementedError
bytesep/models/__init__.py ADDED
File without changes
bytesep/models/conditional_unet.py ADDED
@@ -0,0 +1,496 @@
1
+ import math
2
+ from typing import List
3
+
4
+ import numpy as np
5
+ import matplotlib.pyplot as plt
6
+ import pytorch_lightning as pl
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ import torch.optim as optim
11
+ from torch.optim.lr_scheduler import LambdaLR
12
+ from torchlibrosa.stft import STFT, ISTFT, magphase
13
+
14
+ from bytesep.models.pytorch_modules import (
15
+ Base,
16
+ init_bn,
17
+ init_embedding,
18
+ init_layer,
19
+ act,
20
+ Subband,
21
+ )
22
+
23
+
24
+ class ConvBlock(nn.Module):
25
+ def __init__(
26
+ self,
27
+ in_channels,
28
+ out_channels,
29
+ condition_size,
30
+ kernel_size,
31
+ activation,
32
+ momentum,
33
+ ):
34
+ super(ConvBlock, self).__init__()
35
+
36
+ self.activation = activation
37
+ padding = (kernel_size[0] // 2, kernel_size[1] // 2)
38
+
39
+ self.conv1 = nn.Conv2d(
40
+ in_channels=in_channels,
41
+ out_channels=out_channels,
42
+ kernel_size=kernel_size,
43
+ stride=(1, 1),
44
+ dilation=(1, 1),
45
+ padding=padding,
46
+ bias=False,
47
+ )
48
+
49
+ self.bn1 = nn.BatchNorm2d(out_channels, momentum=momentum)
50
+
51
+ self.conv2 = nn.Conv2d(
52
+ in_channels=out_channels,
53
+ out_channels=out_channels,
54
+ kernel_size=kernel_size,
55
+ stride=(1, 1),
56
+ dilation=(1, 1),
57
+ padding=padding,
58
+ bias=False,
59
+ )
60
+
61
+ self.bn2 = nn.BatchNorm2d(out_channels, momentum=momentum)
62
+
63
+ self.beta1 = nn.Linear(condition_size, out_channels, bias=True)
64
+ self.beta2 = nn.Linear(condition_size, out_channels, bias=True)
65
+
66
+ self.init_weights()
67
+
68
+ def init_weights(self):
69
+ init_layer(self.conv1)
70
+ init_layer(self.conv2)
71
+ init_bn(self.bn1)
72
+ init_bn(self.bn2)
73
+ init_embedding(self.beta1)
74
+ init_embedding(self.beta2)
75
+
76
+ def forward(self, x, condition):
77
+
78
+ b1 = self.beta1(condition)[:, :, None, None]
79
+ b2 = self.beta2(condition)[:, :, None, None]
80
+
81
+ x = act(self.bn1(self.conv1(x)) + b1, self.activation)
82
+ x = act(self.bn2(self.conv2(x)) + b2, self.activation)
83
+ return x
84
+
85
+
86
+ class EncoderBlock(nn.Module):
87
+ def __init__(
88
+ self,
89
+ in_channels,
90
+ out_channels,
91
+ condition_size,
92
+ kernel_size,
93
+ downsample,
94
+ activation,
95
+ momentum,
96
+ ):
97
+ super(EncoderBlock, self).__init__()
98
+
99
+ self.conv_block = ConvBlock(
100
+ in_channels, out_channels, condition_size, kernel_size, activation, momentum
101
+ )
102
+ self.downsample = downsample
103
+
104
+ def forward(self, x, condition):
105
+ encoder = self.conv_block(x, condition)
106
+ encoder_pool = F.avg_pool2d(encoder, kernel_size=self.downsample)
107
+ return encoder_pool, encoder
108
+
109
+
110
+ class DecoderBlock(nn.Module):
111
+ def __init__(
112
+ self,
113
+ in_channels,
114
+ out_channels,
115
+ condition_size,
116
+ kernel_size,
117
+ upsample,
118
+ activation,
119
+ momentum,
120
+ ):
121
+ super(DecoderBlock, self).__init__()
122
+ self.kernel_size = kernel_size
123
+ self.stride = upsample
124
+ self.activation = activation
125
+
126
+ self.conv1 = torch.nn.ConvTranspose2d(
127
+ in_channels=in_channels,
128
+ out_channels=out_channels,
129
+ kernel_size=self.stride,
130
+ stride=self.stride,
131
+ padding=(0, 0),
132
+ bias=False,
133
+ dilation=(1, 1),
134
+ )
135
+
136
+ self.bn1 = nn.BatchNorm2d(out_channels, momentum=momentum)
137
+
138
+ self.conv_block2 = ConvBlock(
139
+ out_channels * 2,
140
+ out_channels,
141
+ condition_size,
142
+ kernel_size,
143
+ activation,
144
+ momentum,
145
+ )
146
+
147
+ self.beta1 = nn.Linear(condition_size, out_channels, bias=True)
148
+
149
+ self.init_weights()
150
+
151
+ def init_weights(self):
152
+ init_layer(self.conv1)
153
+ init_bn(self.bn1)
154
+ init_embedding(self.beta1)
155
+
156
+ def forward(self, input_tensor, concat_tensor, condition):
157
+ b1 = self.beta1(condition)[:, :, None, None]
158
+ x = act(self.bn1(self.conv1(input_tensor)) + b1, self.activation)
159
+ x = torch.cat((x, concat_tensor), dim=1)
160
+ x = self.conv_block2(x, condition)
161
+ return x
162
+
163
+
164
+ class ConditionalUNet(nn.Module, Base):
165
+ def __init__(self, input_channels, target_sources_num):
166
+ super(ConditionalUNet, self).__init__()
167
+
168
+ self.input_channels = input_channels
169
+ condition_size = target_sources_num
170
+ self.output_sources_num = 1
171
+
172
+ window_size = 2048
173
+ hop_size = 441
174
+ center = True
175
+ pad_mode = "reflect"
176
+ window = "hann"
177
+ activation = "relu"
178
+ momentum = 0.01
179
+
180
+ self.subbands_num = 4
181
+ self.K = 3 # outputs: |M|, cos∠M, sin∠M
182
+
183
+ self.downsample_ratio = 2 ** 6  # This number equals 2^{#encoder_blocks}
184
+
185
+ self.stft = STFT(
186
+ n_fft=window_size,
187
+ hop_length=hop_size,
188
+ win_length=window_size,
189
+ window=window,
190
+ center=center,
191
+ pad_mode=pad_mode,
192
+ freeze_parameters=True,
193
+ )
194
+
195
+ self.istft = ISTFT(
196
+ n_fft=window_size,
197
+ hop_length=hop_size,
198
+ win_length=window_size,
199
+ window=window,
200
+ center=center,
201
+ pad_mode=pad_mode,
202
+ freeze_parameters=True,
203
+ )
204
+
205
+ self.bn0 = nn.BatchNorm2d(window_size // 2 + 1, momentum=momentum)
206
+
207
+ self.subband = Subband(subbands_num=self.subbands_num)
208
+
209
+ self.encoder_block1 = EncoderBlock(
210
+ in_channels=input_channels * self.subbands_num,
211
+ out_channels=32,
212
+ condition_size=condition_size,
213
+ kernel_size=(3, 3),
214
+ downsample=(2, 2),
215
+ activation=activation,
216
+ momentum=momentum,
217
+ )
218
+ self.encoder_block2 = EncoderBlock(
219
+ in_channels=32,
220
+ out_channels=64,
221
+ condition_size=condition_size,
222
+ kernel_size=(3, 3),
223
+ downsample=(2, 2),
224
+ activation=activation,
225
+ momentum=momentum,
226
+ )
227
+ self.encoder_block3 = EncoderBlock(
228
+ in_channels=64,
229
+ out_channels=128,
230
+ condition_size=condition_size,
231
+ kernel_size=(3, 3),
232
+ downsample=(2, 2),
233
+ activation=activation,
234
+ momentum=momentum,
235
+ )
236
+ self.encoder_block4 = EncoderBlock(
237
+ in_channels=128,
238
+ out_channels=256,
239
+ condition_size=condition_size,
240
+ kernel_size=(3, 3),
241
+ downsample=(2, 2),
242
+ activation=activation,
243
+ momentum=momentum,
244
+ )
245
+ self.encoder_block5 = EncoderBlock(
246
+ in_channels=256,
247
+ out_channels=384,
248
+ condition_size=condition_size,
249
+ kernel_size=(3, 3),
250
+ downsample=(2, 2),
251
+ activation=activation,
252
+ momentum=momentum,
253
+ )
254
+ self.encoder_block6 = EncoderBlock(
255
+ in_channels=384,
256
+ out_channels=384,
257
+ condition_size=condition_size,
258
+ kernel_size=(3, 3),
259
+ downsample=(2, 2),
260
+ activation=activation,
261
+ momentum=momentum,
262
+ )
263
+ self.conv_block7 = ConvBlock(
264
+ in_channels=384,
265
+ out_channels=384,
266
+ condition_size=condition_size,
267
+ kernel_size=(3, 3),
268
+ activation=activation,
269
+ momentum=momentum,
270
+ )
271
+ self.decoder_block1 = DecoderBlock(
272
+ in_channels=384,
273
+ out_channels=384,
274
+ condition_size=condition_size,
275
+ kernel_size=(3, 3),
276
+ upsample=(2, 2),
277
+ activation=activation,
278
+ momentum=momentum,
279
+ )
280
+ self.decoder_block2 = DecoderBlock(
281
+ in_channels=384,
282
+ out_channels=384,
283
+ condition_size=condition_size,
284
+ kernel_size=(3, 3),
285
+ upsample=(2, 2),
286
+ activation=activation,
287
+ momentum=momentum,
288
+ )
289
+ self.decoder_block3 = DecoderBlock(
290
+ in_channels=384,
291
+ out_channels=256,
292
+ condition_size=condition_size,
293
+ kernel_size=(3, 3),
294
+ upsample=(2, 2),
295
+ activation=activation,
296
+ momentum=momentum,
297
+ )
298
+ self.decoder_block4 = DecoderBlock(
299
+ in_channels=256,
300
+ out_channels=128,
301
+ condition_size=condition_size,
302
+ kernel_size=(3, 3),
303
+ upsample=(2, 2),
304
+ activation=activation,
305
+ momentum=momentum,
306
+ )
307
+ self.decoder_block5 = DecoderBlock(
308
+ in_channels=128,
309
+ out_channels=64,
310
+ condition_size=condition_size,
311
+ kernel_size=(3, 3),
312
+ upsample=(2, 2),
313
+ activation=activation,
314
+ momentum=momentum,
315
+ )
316
+ self.decoder_block6 = DecoderBlock(
317
+ in_channels=64,
318
+ out_channels=32,
319
+ condition_size=condition_size,
320
+ kernel_size=(3, 3),
321
+ upsample=(2, 2),
322
+ activation=activation,
323
+ momentum=momentum,
324
+ )
325
+
326
+ self.after_conv_block1 = ConvBlock(
327
+ in_channels=32,
328
+ out_channels=32,
329
+ condition_size=condition_size,
330
+ kernel_size=(3, 3),
331
+ activation=activation,
332
+ momentum=momentum,
333
+ )
334
+
335
+ self.after_conv2 = nn.Conv2d(
336
+ in_channels=32,
337
+ out_channels=input_channels
338
+ * self.subbands_num
339
+ * self.output_sources_num
340
+ * self.K,
341
+ kernel_size=(1, 1),
342
+ stride=(1, 1),
343
+ padding=(0, 0),
344
+ bias=True,
345
+ )
346
+
347
+ self.init_weights()
348
+
349
+ def init_weights(self):
350
+ init_bn(self.bn0)
351
+ init_layer(self.after_conv2)
352
+
353
+ def feature_maps_to_wav(self, x, sp, sin_in, cos_in, audio_length):
354
+
355
+ batch_size, _, time_steps, freq_bins = x.shape
356
+
357
+ x = x.reshape(
358
+ batch_size,
359
+ self.output_sources_num,
360
+ self.input_channels,
361
+ self.K,
362
+ time_steps,
363
+ freq_bins,
364
+ )
365
+ # x: (batch_size, output_sources_num, input_channels, K, time_steps, freq_bins)
366
+
367
+ mask_mag = torch.sigmoid(x[:, :, :, 0, :, :])
368
+ _mask_real = torch.tanh(x[:, :, :, 1, :, :])
369
+ _mask_imag = torch.tanh(x[:, :, :, 2, :, :])
370
+ _, mask_cos, mask_sin = magphase(_mask_real, _mask_imag)
371
+ # mask_cos, mask_sin: (batch_size, output_sources_num, input_channels, time_steps, freq_bins)
372
+
373
+ # Y = |Y|cos∠Y + j|Y|sin∠Y
374
+ # = |Y|cos(∠X + ∠M) + j|Y|sin(∠X + ∠M)
375
+ # = |Y|(cos∠X cos∠M - sin∠X sin∠M) + j|Y|(sin∠X cos∠M + cos∠X sin∠M)
376
+ out_cos = (
377
+ cos_in[:, None, :, :, :] * mask_cos - sin_in[:, None, :, :, :] * mask_sin
378
+ )
379
+ out_sin = (
380
+ sin_in[:, None, :, :, :] * mask_cos + cos_in[:, None, :, :, :] * mask_sin
381
+ )
382
+ # out_cos, out_sin: (batch_size, output_sources_num, input_channels, time_steps, freq_bins)
383
+
384
+ # Calculate |Y|.
385
+ out_mag = F.relu_(sp[:, None, :, :, :] * mask_mag)
386
+ # out_mag: (batch_size, output_sources_num, input_channels, time_steps, freq_bins)
387
+
388
+ # Calculate Y_{real} and Y_{imag} for ISTFT.
389
+ out_real = out_mag * out_cos
390
+ out_imag = out_mag * out_sin
391
+ # out_real, out_imag: (batch_size, output_sources_num, input_channels, time_steps, freq_bins)
392
+
393
+ # Reformat shape to (n, 1, time_steps, freq_bins) for ISTFT.
394
+ shape = (
395
+ batch_size * self.output_sources_num * self.input_channels,
396
+ 1,
397
+ time_steps,
398
+ freq_bins,
399
+ )
400
+ out_real = out_real.reshape(shape)
401
+ out_imag = out_imag.reshape(shape)
402
+
403
+ # ISTFT.
404
+ wav_out = self.istft(out_real, out_imag, audio_length)
405
+ # (batch_size * output_sources_num * input_channels, segments_num)
406
+
407
+ # Reshape.
408
+ wav_out = wav_out.reshape(
409
+ batch_size, self.output_sources_num * self.input_channels, audio_length
410
+ )
411
+ # (batch_size, output_sources_num * input_channels, segments_num)
412
+
413
+ return wav_out
414
+
415
+ def forward(self, input_dict):
416
+ """
417
+ Args:
418
+ input_dict: {'waveform': (batch_size, channels_num, segment_samples), 'condition': (batch_size, condition_size)}
419
+
420
+ Outputs:
421
+ output_dict: {
422
+ 'waveform': (batch_size, output_sources_num * channels_num, segment_samples),
423
+ }
424
+ """
425
+
426
+ mixture = input_dict['waveform']
427
+ condition = input_dict['condition']
428
+
429
+ sp, cos_in, sin_in = self.wav_to_spectrogram_phase(mixture)
430
+ """(batch_size, channels_num, time_steps, freq_bins)"""
431
+
432
+ # Batch normalization
433
+ x = sp.transpose(1, 3)
434
+ x = self.bn0(x)
435
+ x = x.transpose(1, 3)
436
+ """(batch_size, chanenls, time_steps, freq_bins)"""
437
+
438
+ # Pad spectrogram to be evenly divided by downsample ratio.
439
+ origin_len = x.shape[2]
440
+ pad_len = (
441
+ int(np.ceil(x.shape[2] / self.downsample_ratio)) * self.downsample_ratio
442
+ - origin_len
443
+ )
444
+ x = F.pad(x, pad=(0, 0, 0, pad_len))
445
+ """(batch_size, channels, padded_time_steps, freq_bins)"""
446
+
447
+ # Trim the highest frequency bin so that freq_bins is even, e.g., 1025 -> 1024
448
+ x = x[..., 0 : x.shape[-1] - 1] # (bs, channels, T, F)
449
+
450
+ x = self.subband.analysis(x)
451
+
452
+ # UNet
453
+ (x1_pool, x1) = self.encoder_block1(
454
+ x, condition
455
+ ) # x1_pool: (bs, 32, T / 2, F / 2)
456
+ (x2_pool, x2) = self.encoder_block2(
457
+ x1_pool, condition
458
+ ) # x2_pool: (bs, 64, T / 4, F / 4)
459
+ (x3_pool, x3) = self.encoder_block3(
460
+ x2_pool, condition
461
+ ) # x3_pool: (bs, 128, T / 8, F / 8)
462
+ (x4_pool, x4) = self.encoder_block4(
463
+ x3_pool, condition
464
+ ) # x4_pool: (bs, 256, T / 16, F / 16)
465
+ (x5_pool, x5) = self.encoder_block5(
466
+ x4_pool, condition
467
+ ) # x5_pool: (bs, 384, T / 32, F / 32)
468
+ (x6_pool, x6) = self.encoder_block6(
469
+ x5_pool, condition
470
+ ) # x6_pool: (bs, 384, T / 64, F / 64)
471
+ x_center = self.conv_block7(x6_pool, condition) # (bs, 384, T / 64, F / 64)
472
+ x7 = self.decoder_block1(x_center, x6, condition) # (bs, 384, T / 32, F / 32)
473
+ x8 = self.decoder_block2(x7, x5, condition) # (bs, 384, T / 16, F / 16)
474
+ x9 = self.decoder_block3(x8, x4, condition) # (bs, 256, T / 8, F / 8)
475
+ x10 = self.decoder_block4(x9, x3, condition) # (bs, 128, T / 4, F / 4)
476
+ x11 = self.decoder_block5(x10, x2, condition) # (bs, 64, T / 2, F / 2)
477
+ x12 = self.decoder_block6(x11, x1, condition) # (bs, 32, T, F)
478
+ x = self.after_conv_block1(x12, condition) # (bs, 32, T, F)
479
+ x = self.after_conv2(x)
480
+ # (batch_size, input_channels * subbands_num * targets_num * K, T, F // subbands_num)
481
+
482
+ x = self.subband.synthesis(x)
483
+ # (batch_size, input_channels * targets_num * K, T, F)
484
+
485
+ # Recover shape
486
+ x = F.pad(x, pad=(0, 1)) # Pad frequency, e.g., 1024 -> 1025.
487
+ x = x[:, :, 0:origin_len, :] # (bs, feature_maps, T, F)
488
+
489
+ audio_length = mixture.shape[2]
490
+
491
+ separated_audio = self.feature_maps_to_wav(x, sp, sin_in, cos_in, audio_length)
492
+ # separated_audio: (batch_size, output_sources_num * input_channels, segments_num)
493
+
494
+ output_dict = {'waveform': separated_audio}
495
+
496
+ return output_dict
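A hedged smoke-test sketch for ConditionalUNet: random stereo input plus one-hot condition vectors selecting one of four hypothetical target sources. The 3-second segment length and the batch size are arbitrary illustrative values.
import torch

from bytesep.models.conditional_unet import ConditionalUNet

model = ConditionalUNet(input_channels=2, target_sources_num=4)

batch_size, segment_samples = 2, 44100 * 3
input_dict = {
    'waveform': torch.rand(batch_size, 2, segment_samples),
    'condition': torch.eye(4)[torch.tensor([0, 2])],  # one-hot, (batch_size, 4)
}

output_dict = model(input_dict)
# output_dict['waveform']: (batch_size, output_sources_num * input_channels, segment_samples),
# i.e. (2, 2, 132300) here, since output_sources_num is fixed to 1.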
bytesep/models/lightning_modules.py ADDED
@@ -0,0 +1,188 @@
1
+ from typing import Any, Callable, Dict
2
+
3
+ import pytorch_lightning as pl
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.optim as optim
7
+ from torch.optim.lr_scheduler import LambdaLR
8
+
9
+
10
+ class LitSourceSeparation(pl.LightningModule):
11
+ def __init__(
12
+ self,
13
+ batch_data_preprocessor,
14
+ model: nn.Module,
15
+ loss_function: Callable,
16
+ optimizer_type: str,
17
+ learning_rate: float,
18
+ lr_lambda: Callable,
19
+ ):
20
+ r"""Pytorch Lightning wrapper of PyTorch model, including forward,
21
+ optimization of model, etc.
22
+
23
+ Args:
24
+ batch_data_preprocessor: object, used for preparing inputs and
25
+ targets for training. E.g., BasicBatchDataPreprocessor is used
26
+ for converting a batch data dictionary into tensors.
27
+ model: nn.Module
28
+ loss_function: function
29
+ learning_rate: float
30
+ lr_lambda: function
31
+ """
32
+ super().__init__()
33
+
34
+ self.batch_data_preprocessor = batch_data_preprocessor
35
+ self.model = model
36
+ self.optimizer_type = optimizer_type
37
+ self.loss_function = loss_function
38
+ self.learning_rate = learning_rate
39
+ self.lr_lambda = lr_lambda
40
+
41
+ def training_step(self, batch_data_dict: Dict, batch_idx: int) -> torch.float:
42
+ r"""Forward a mini-batch data to model, calculate loss function, and
43
+ train for one step. A mini-batch data is evenly distributed to multiple
44
+ devices (if there are) for parallel training.
45
+
46
+ Args:
47
+ batch_data_dict: e.g. {
48
+ 'vocals': (batch_size, channels_num, segment_samples),
49
+ 'accompaniment': (batch_size, channels_num, segment_samples),
50
+ 'mixture': (batch_size, channels_num, segment_samples)
51
+ }
52
+ batch_idx: int
53
+
54
+ Returns:
55
+ loss: torch.float, loss of this mini-batch
56
+ """
57
+ input_dict, target_dict = self.batch_data_preprocessor(batch_data_dict)
58
+ # input_dict: {
59
+ # 'waveform': (batch_size, channels_num, segment_samples),
60
+ # (if_exist) 'condition': (batch_size, channels_num),
61
+ # }
62
+ # target_dict: {
63
+ # 'waveform': (batch_size, target_sources_num * channels_num, segment_samples),
64
+ # }
65
+
66
+ # Forward.
67
+ self.model.train()
68
+
69
+ output_dict = self.model(input_dict)
70
+ # output_dict: {
71
+ # 'waveform': (batch_size, target_sources_num * channels_num, segment_samples),
72
+ # }
73
+
74
+ outputs = output_dict['waveform']
75
+ # outputs, e.g., (batch_size, target_sources_num * channels_num, segment_samples)
76
+
77
+ # Calculate loss.
78
+ loss = self.loss_function(
79
+ output=outputs,
80
+ target=target_dict['waveform'],
81
+ mixture=input_dict['waveform'],
82
+ )
83
+
84
+ return loss
85
+
86
+ def configure_optimizers(self) -> Any:
87
+ r"""Configure optimizer."""
88
+
89
+ if self.optimizer_type == "Adam":
90
+ optimizer = optim.Adam(
91
+ self.model.parameters(),
92
+ lr=self.learning_rate,
93
+ betas=(0.9, 0.999),
94
+ eps=1e-08,
95
+ weight_decay=0.0,
96
+ amsgrad=True,
97
+ )
98
+
99
+ elif self.optimizer_type == "AdamW":
100
+ optimizer = optim.AdamW(
101
+ self.model.parameters(),
102
+ lr=self.learning_rate,
103
+ betas=(0.9, 0.999),
104
+ eps=1e-08,
105
+ weight_decay=0.0,
106
+ amsgrad=True,
107
+ )
108
+
109
+ else:
110
+ raise NotImplementedError
111
+
112
+ scheduler = {
113
+ 'scheduler': LambdaLR(optimizer, self.lr_lambda),
114
+ 'interval': 'step',
115
+ 'frequency': 1,
116
+ }
117
+
118
+ return [optimizer], [scheduler]
119
+
120
+
121
+ def get_model_class(model_type):
122
+ r"""Get model.
123
+
124
+ Args:
125
+ model_type: str, e.g., 'ResUNet143_DecouplePlusInplaceABN'
126
+
127
+ Returns:
128
+ nn.Module
129
+ """
130
+ if model_type == 'ResUNet143_DecouplePlusInplaceABN_ISMIR2021':
131
+ from bytesep.models.resunet_ismir2021 import (
132
+ ResUNet143_DecouplePlusInplaceABN_ISMIR2021,
133
+ )
134
+
135
+ return ResUNet143_DecouplePlusInplaceABN_ISMIR2021
136
+
137
+ elif model_type == 'UNet':
138
+ from bytesep.models.unet import UNet
139
+
140
+ return UNet
141
+
142
+ elif model_type == 'UNetSubbandTime':
143
+ from bytesep.models.unet_subbandtime import UNetSubbandTime
144
+
145
+ return UNetSubbandTime
146
+
147
+ elif model_type == 'ResUNet143_Subbandtime':
148
+ from bytesep.models.resunet_subbandtime import ResUNet143_Subbandtime
149
+
150
+ return ResUNet143_Subbandtime
151
+
152
+ elif model_type == 'ResUNet143_DecouplePlus':
153
+ from bytesep.models.resunet import ResUNet143_DecouplePlus
154
+
155
+ return ResUNet143_DecouplePlus
156
+
157
+ elif model_type == 'ConditionalUNet':
158
+ from bytesep.models.conditional_unet import ConditionalUNet
159
+
160
+ return ConditionalUNet
161
+
162
+ elif model_type == 'LevelRNN':
163
+ from bytesep.models.levelrnn import LevelRNN
164
+
165
+ return LevelRNN
166
+
167
+ elif model_type == 'WavUNet':
168
+ from bytesep.models.wavunet import WavUNet
169
+
170
+ return WavUNet
171
+
172
+ elif model_type == 'WavUNetLevelRNN':
173
+ from bytesep.models.wavunet_levelrnn import WavUNetLevelRNN
174
+
175
+ return WavUNetLevelRNN
176
+
177
+ elif model_type == 'TTnet':
178
+ from bytesep.models.ttnet import TTnet
179
+
180
+ return TTnet
181
+
182
+ elif model_type == 'TTnetNoTransformer':
183
+ from bytesep.models.ttnet_no_transformer import TTnetNoTransformer
184
+
185
+ return TTnetNoTransformer
186
+
187
+ else:
188
+ raise NotImplementedError
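get_model_class is the factory used by the inference scripts above to resolve a model class from the config's model_type string; a short sketch of the pattern, using the UNet variant as an example.
from bytesep.models.lightning_modules import get_model_class

Model = get_model_class(model_type='UNet')
model = Model(input_channels=2, target_sources_num=1)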
bytesep/models/pytorch_modules.py ADDED
@@ -0,0 +1,204 @@
1
+ from typing import List, NoReturn
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+
9
+ def init_embedding(layer: nn.Module) -> NoReturn:
10
+ r"""Initialize a Linear or Convolutional layer."""
11
+ nn.init.uniform_(layer.weight, -1.0, 1.0)
12
+
13
+ if hasattr(layer, 'bias'):
14
+ if layer.bias is not None:
15
+ layer.bias.data.fill_(0.0)
16
+
17
+
18
+ def init_layer(layer: nn.Module) -> NoReturn:
19
+ r"""Initialize a Linear or Convolutional layer."""
20
+ nn.init.xavier_uniform_(layer.weight)
21
+
22
+ if hasattr(layer, "bias"):
23
+ if layer.bias is not None:
24
+ layer.bias.data.fill_(0.0)
25
+
26
+
27
+ def init_bn(bn: nn.Module) -> NoReturn:
28
+ r"""Initialize a Batchnorm layer."""
29
+ bn.bias.data.fill_(0.0)
30
+ bn.weight.data.fill_(1.0)
31
+ bn.running_mean.data.fill_(0.0)
32
+ bn.running_var.data.fill_(1.0)
33
+
34
+
35
+ def act(x: torch.Tensor, activation: str) -> torch.Tensor:
36
+
37
+ if activation == "relu":
38
+ return F.relu_(x)
39
+
40
+ elif activation == "leaky_relu":
41
+ return F.leaky_relu_(x, negative_slope=0.01)
42
+
43
+ elif activation == "swish":
44
+ return x * torch.sigmoid(x)
45
+
46
+ else:
47
+ raise Exception("Incorrect activation!")
48
+
49
+
50
+ class Base:
51
+ def __init__(self):
52
+ r"""Base function for extracting spectrogram, cos, and sin, etc."""
53
+ pass
54
+
55
+ def spectrogram(self, input: torch.Tensor, eps: float = 0.0) -> torch.Tensor:
56
+ r"""Calculate spectrogram.
57
+
58
+ Args:
59
+ input: (batch_size, segments_num)
60
+ eps: float
61
+
62
+ Returns:
63
+ spectrogram: (batch_size, time_steps, freq_bins)
64
+ """
65
+ (real, imag) = self.stft(input)
66
+ return torch.clamp(real ** 2 + imag ** 2, eps, np.inf) ** 0.5
67
+
68
+ def spectrogram_phase(
69
+ self, input: torch.Tensor, eps: float = 0.0
70
+ ) -> List[torch.Tensor]:
71
+ r"""Calculate the magnitude, cos, and sin of the STFT of input.
72
+
73
+ Args:
74
+ input: (batch_size, segments_num)
75
+ eps: float
76
+
77
+ Returns:
78
+ mag: (batch_size, time_steps, freq_bins)
79
+ cos: (batch_size, time_steps, freq_bins)
80
+ sin: (batch_size, time_steps, freq_bins)
81
+ """
82
+ (real, imag) = self.stft(input)
83
+ mag = torch.clamp(real ** 2 + imag ** 2, eps, np.inf) ** 0.5
84
+ cos = real / mag
85
+ sin = imag / mag
86
+ return mag, cos, sin
87
+
88
+ def wav_to_spectrogram_phase(
89
+ self, input: torch.Tensor, eps: float = 1e-10
90
+ ) -> List[torch.Tensor]:
91
+ r"""Convert waveforms to magnitude, cos, and sin of STFT.
92
+
93
+ Args:
94
+ input: (batch_size, channels_num, segment_samples)
95
+ eps: float
96
+
97
+ Outputs:
98
+ mag: (batch_size, channels_num, time_steps, freq_bins)
99
+ cos: (batch_size, channels_num, time_steps, freq_bins)
100
+ sin: (batch_size, channels_num, time_steps, freq_bins)
101
+ """
102
+ batch_size, channels_num, segment_samples = input.shape
103
+
104
+ # Reshape input with shapes of (n, segments_num) to meet the
105
+ # requirements of the stft function.
106
+ x = input.reshape(batch_size * channels_num, segment_samples)
107
+
108
+ mag, cos, sin = self.spectrogram_phase(x, eps=eps)
109
+ # mag, cos, sin: (batch_size * channels_num, 1, time_steps, freq_bins)
110
+
111
+ _, _, time_steps, freq_bins = mag.shape
112
+ mag = mag.reshape(batch_size, channels_num, time_steps, freq_bins)
113
+ cos = cos.reshape(batch_size, channels_num, time_steps, freq_bins)
114
+ sin = sin.reshape(batch_size, channels_num, time_steps, freq_bins)
115
+
116
+ return mag, cos, sin
117
+
118
+ def wav_to_spectrogram(
119
+ self, input: torch.Tensor, eps: float = 1e-10
120
+ ) -> List[torch.Tensor]:
121
+
122
+ mag, cos, sin = self.wav_to_spectrogram_phase(input, eps)
123
+ return mag
124
+
125
+
126
+ class Subband:
127
+ def __init__(self, subbands_num: int):
128
+ r"""Warning!! This class is not used!!
129
+
130
+ This class does not work as well as [1], which splits subbands in the
131
+ time domain. Please refer to [1] for a formal implementation.
132
+
133
+ [1] Liu, Haohe, et al. "Channel-wise subband input for better voice and
134
+ accompaniment separation on high resolution music." arXiv preprint arXiv:2008.05216 (2020).
135
+
136
+ Args:
137
+ subbands_num: int, e.g., 4
138
+ """
139
+ self.subbands_num = subbands_num
140
+
141
+ def analysis(self, x: torch.Tensor) -> torch.Tensor:
142
+ r"""Analysis time-frequency representation into subbands. Stack the
143
+ subbands along the channel axis.
144
+
145
+ Args:
146
+ x: (batch_size, channels_num, time_steps, freq_bins)
147
+
148
+ Returns:
149
+ output: (batch_size, channels_num * subbands_num, time_steps, freq_bins // subbands_num)
150
+ """
151
+ batch_size, channels_num, time_steps, freq_bins = x.shape
152
+
153
+ x = x.reshape(
154
+ batch_size,
155
+ channels_num,
156
+ time_steps,
157
+ self.subbands_num,
158
+ freq_bins // self.subbands_num,
159
+ )
160
+ # x: (batch_size, channels_num, time_steps, subbands_num, freq_bins // subbands_num)
161
+
162
+ x = x.transpose(2, 3)
163
+
164
+ output = x.reshape(
165
+ batch_size,
166
+ channels_num * self.subbands_num,
167
+ time_steps,
168
+ freq_bins // self.subbands_num,
169
+ )
170
+ # output: (batch_size, channels_num * subbands_num, time_steps, freq_bins // subbands_num)
171
+
172
+ return output
173
+
174
+ def synthesis(self, x: torch.Tensor) -> torch.Tensor:
175
+ r"""Synthesis subband time-frequency representations into original
176
+ time-frequency representation.
177
+
178
+ Args:
179
+ x: (batch_size, channels_num * subbands_num, time_steps, freq_bins // subbands_num)
180
+
181
+ Returns:
182
+ output: (batch_size, channels_num, time_steps, freq_bins)
183
+ """
184
+ batch_size, subband_channels_num, time_steps, subband_freq_bins = x.shape
185
+
186
+ channels_num = subband_channels_num // self.subbands_num
187
+ freq_bins = subband_freq_bins * self.subbands_num
188
+
189
+ x = x.reshape(
190
+ batch_size,
191
+ channels_num,
192
+ self.subbands_num,
193
+ time_steps,
194
+ subband_freq_bins,
195
+ )
196
+ # x: (batch_size, channels_num, subbands_num, time_steps, freq_bins // subbands_num)
197
+
198
+ x = x.transpose(2, 3)
199
+ # x: (batch_size, channels_num, time_steps, subbands_num, freq_bins // subbands_num)
200
+
201
+ output = x.reshape(batch_size, channels_num, time_steps, freq_bins)
202
+ # x: (batch_size, channels_num, time_steps, freq_bins)
203
+
204
+ return output
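A round-trip sketch for the Subband helper above: analysis stacks sub-bands along the channel axis and synthesis restores the original layout exactly, since both are pure reshape/transpose operations. The tensor shapes are arbitrary illustrative values with freq_bins divisible by subbands_num.
import torch

from bytesep.models.pytorch_modules import Subband

subband = Subband(subbands_num=4)

x = torch.rand(2, 2, 301, 1024)  # (batch_size, channels_num, time_steps, freq_bins)
y = subband.analysis(x)          # (2, 8, 301, 256)
x_hat = subband.synthesis(y)     # (2, 2, 301, 1024)

assert torch.equal(x, x_hat)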
bytesep/models/resunet.py ADDED
@@ -0,0 +1,516 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from torchlibrosa.stft import ISTFT, STFT, magphase
6
+
7
+ from bytesep.models.pytorch_modules import Base, Subband, act, init_bn, init_layer
8
+
9
+
10
+ class ConvBlockRes(nn.Module):
11
+ def __init__(self, in_channels, out_channels, kernel_size, activation, momentum):
12
+ r"""Residual block."""
13
+ super(ConvBlockRes, self).__init__()
14
+
15
+ self.activation = activation
16
+ padding = [kernel_size[0] // 2, kernel_size[1] // 2]
17
+
18
+ self.bn1 = nn.BatchNorm2d(in_channels, momentum=momentum)
19
+ self.bn2 = nn.BatchNorm2d(out_channels, momentum=momentum)
20
+
21
+ self.conv1 = nn.Conv2d(
22
+ in_channels=in_channels,
23
+ out_channels=out_channels,
24
+ kernel_size=kernel_size,
25
+ stride=(1, 1),
26
+ dilation=(1, 1),
27
+ padding=padding,
28
+ bias=False,
29
+ )
30
+
31
+ self.conv2 = nn.Conv2d(
32
+ in_channels=out_channels,
33
+ out_channels=out_channels,
34
+ kernel_size=kernel_size,
35
+ stride=(1, 1),
36
+ dilation=(1, 1),
37
+ padding=padding,
38
+ bias=False,
39
+ )
40
+
41
+ if in_channels != out_channels:
42
+ self.shortcut = nn.Conv2d(
43
+ in_channels=in_channels,
44
+ out_channels=out_channels,
45
+ kernel_size=(1, 1),
46
+ stride=(1, 1),
47
+ padding=(0, 0),
48
+ )
49
+
50
+ self.is_shortcut = True
51
+ else:
52
+ self.is_shortcut = False
53
+
54
+ self.init_weights()
55
+
56
+ def init_weights(self):
57
+ init_bn(self.bn1)
58
+ init_bn(self.bn2)
59
+ init_layer(self.conv1)
60
+ init_layer(self.conv2)
61
+
62
+ if self.is_shortcut:
63
+ init_layer(self.shortcut)
64
+
65
+ def forward(self, x):
66
+ origin = x
67
+ x = self.conv1(act(self.bn1(x), self.activation))
68
+ x = self.conv2(act(self.bn2(x), self.activation))
69
+
70
+ if self.is_shortcut:
71
+ return self.shortcut(origin) + x
72
+ else:
73
+ return origin + x
74
+
75
+
76
+ class EncoderBlockRes4B(nn.Module):
77
+ def __init__(
78
+ self, in_channels, out_channels, kernel_size, downsample, activation, momentum
79
+ ):
80
+ r"""Encoder block, contains 8 convolutional layers."""
81
+ super(EncoderBlockRes4B, self).__init__()
82
+
83
+ self.conv_block1 = ConvBlockRes(
84
+ in_channels, out_channels, kernel_size, activation, momentum
85
+ )
86
+ self.conv_block2 = ConvBlockRes(
87
+ out_channels, out_channels, kernel_size, activation, momentum
88
+ )
89
+ self.conv_block3 = ConvBlockRes(
90
+ out_channels, out_channels, kernel_size, activation, momentum
91
+ )
92
+ self.conv_block4 = ConvBlockRes(
93
+ out_channels, out_channels, kernel_size, activation, momentum
94
+ )
95
+ self.downsample = downsample
96
+
97
+ def forward(self, x):
98
+ encoder = self.conv_block1(x)
99
+ encoder = self.conv_block2(encoder)
100
+ encoder = self.conv_block3(encoder)
101
+ encoder = self.conv_block4(encoder)
102
+ encoder_pool = F.avg_pool2d(encoder, kernel_size=self.downsample)
103
+ return encoder_pool, encoder
104
+
105
+
106
+ class DecoderBlockRes4B(nn.Module):
107
+ def __init__(
108
+ self, in_channels, out_channels, kernel_size, upsample, activation, momentum
109
+ ):
110
+ r"""Decoder block, contains 1 transpose convolutional and 8 convolutional layers."""
111
+ super(DecoderBlockRes4B, self).__init__()
112
+ self.kernel_size = kernel_size
113
+ self.stride = upsample
114
+ self.activation = activation
115
+
116
+ self.conv1 = torch.nn.ConvTranspose2d(
117
+ in_channels=in_channels,
118
+ out_channels=out_channels,
119
+ kernel_size=self.stride,
120
+ stride=self.stride,
121
+ padding=(0, 0),
122
+ bias=False,
123
+ dilation=(1, 1),
124
+ )
125
+
126
+ self.bn1 = nn.BatchNorm2d(in_channels, momentum=momentum)
127
+ self.conv_block2 = ConvBlockRes(
128
+ out_channels * 2, out_channels, kernel_size, activation, momentum
129
+ )
130
+ self.conv_block3 = ConvBlockRes(
131
+ out_channels, out_channels, kernel_size, activation, momentum
132
+ )
133
+ self.conv_block4 = ConvBlockRes(
134
+ out_channels, out_channels, kernel_size, activation, momentum
135
+ )
136
+ self.conv_block5 = ConvBlockRes(
137
+ out_channels, out_channels, kernel_size, activation, momentum
138
+ )
139
+
140
+ self.init_weights()
141
+
142
+ def init_weights(self):
143
+ init_bn(self.bn1)
144
+ init_layer(self.conv1)
145
+
146
+ def forward(self, input_tensor, concat_tensor):
147
+ x = self.conv1(act(self.bn1(input_tensor), self.activation))
148
+ x = torch.cat((x, concat_tensor), dim=1)
149
+ x = self.conv_block2(x)
150
+ x = self.conv_block3(x)
151
+ x = self.conv_block4(x)
152
+ x = self.conv_block5(x)
153
+ return x
154
+
155
+
156
+ class ResUNet143_DecouplePlus(nn.Module, Base):
157
+ def __init__(self, input_channels, target_sources_num):
158
+ super(ResUNet143_DecouplePlus, self).__init__()
159
+
160
+ self.input_channels = input_channels
161
+ self.target_sources_num = target_sources_num
162
+
163
+ window_size = 2048
164
+ hop_size = 441
165
+ center = True
166
+ pad_mode = "reflect"
167
+ window = "hann"
168
+ activation = "relu"
169
+ momentum = 0.01
170
+
171
+ self.subbands_num = 4
172
+ self.K = 4 # outputs: |M|, cos∠M, sin∠M, |M2|
173
+
174
+ self.downsample_ratio = 2 ** 6  # This number equals 2^{#encoder_blocks}
175
+
176
+ self.stft = STFT(
177
+ n_fft=window_size,
178
+ hop_length=hop_size,
179
+ win_length=window_size,
180
+ window=window,
181
+ center=center,
182
+ pad_mode=pad_mode,
183
+ freeze_parameters=True,
184
+ )
185
+
186
+ self.istft = ISTFT(
187
+ n_fft=window_size,
188
+ hop_length=hop_size,
189
+ win_length=window_size,
190
+ window=window,
191
+ center=center,
192
+ pad_mode=pad_mode,
193
+ freeze_parameters=True,
194
+ )
195
+
196
+ self.bn0 = nn.BatchNorm2d(window_size // 2 + 1, momentum=momentum)
197
+
198
+ self.subband = Subband(subbands_num=self.subbands_num)
199
+
200
+ self.encoder_block1 = EncoderBlockRes4B(
201
+ in_channels=input_channels * self.subbands_num,
202
+ out_channels=32,
203
+ kernel_size=(3, 3),
204
+ downsample=(2, 2),
205
+ activation=activation,
206
+ momentum=momentum,
207
+ )
208
+ self.encoder_block2 = EncoderBlockRes4B(
209
+ in_channels=32,
210
+ out_channels=64,
211
+ kernel_size=(3, 3),
212
+ downsample=(2, 2),
213
+ activation=activation,
214
+ momentum=momentum,
215
+ )
216
+ self.encoder_block3 = EncoderBlockRes4B(
217
+ in_channels=64,
218
+ out_channels=128,
219
+ kernel_size=(3, 3),
220
+ downsample=(2, 2),
221
+ activation=activation,
222
+ momentum=momentum,
223
+ )
224
+ self.encoder_block4 = EncoderBlockRes4B(
225
+ in_channels=128,
226
+ out_channels=256,
227
+ kernel_size=(3, 3),
228
+ downsample=(2, 2),
229
+ activation=activation,
230
+ momentum=momentum,
231
+ )
232
+ self.encoder_block5 = EncoderBlockRes4B(
233
+ in_channels=256,
234
+ out_channels=384,
235
+ kernel_size=(3, 3),
236
+ downsample=(2, 2),
237
+ activation=activation,
238
+ momentum=momentum,
239
+ )
240
+ self.encoder_block6 = EncoderBlockRes4B(
241
+ in_channels=384,
242
+ out_channels=384,
243
+ kernel_size=(3, 3),
244
+ downsample=(1, 2),
245
+ activation=activation,
246
+ momentum=momentum,
247
+ )
248
+ self.conv_block7a = EncoderBlockRes4B(
249
+ in_channels=384,
250
+ out_channels=384,
251
+ kernel_size=(3, 3),
252
+ downsample=(1, 1),
253
+ activation=activation,
254
+ momentum=momentum,
255
+ )
256
+ self.conv_block7b = EncoderBlockRes4B(
257
+ in_channels=384,
258
+ out_channels=384,
259
+ kernel_size=(3, 3),
260
+ downsample=(1, 1),
261
+ activation=activation,
262
+ momentum=momentum,
263
+ )
264
+ self.conv_block7c = EncoderBlockRes4B(
265
+ in_channels=384,
266
+ out_channels=384,
267
+ kernel_size=(3, 3),
268
+ downsample=(1, 1),
269
+ activation=activation,
270
+ momentum=momentum,
271
+ )
272
+ self.conv_block7d = EncoderBlockRes4B(
273
+ in_channels=384,
274
+ out_channels=384,
275
+ kernel_size=(3, 3),
276
+ downsample=(1, 1),
277
+ activation=activation,
278
+ momentum=momentum,
279
+ )
280
+ self.decoder_block1 = DecoderBlockRes4B(
281
+ in_channels=384,
282
+ out_channels=384,
283
+ kernel_size=(3, 3),
284
+ upsample=(1, 2),
285
+ activation=activation,
286
+ momentum=momentum,
287
+ )
288
+ self.decoder_block2 = DecoderBlockRes4B(
289
+ in_channels=384,
290
+ out_channels=384,
291
+ kernel_size=(3, 3),
292
+ upsample=(2, 2),
293
+ activation=activation,
294
+ momentum=momentum,
295
+ )
296
+ self.decoder_block3 = DecoderBlockRes4B(
297
+ in_channels=384,
298
+ out_channels=256,
299
+ kernel_size=(3, 3),
300
+ upsample=(2, 2),
301
+ activation=activation,
302
+ momentum=momentum,
303
+ )
304
+ self.decoder_block4 = DecoderBlockRes4B(
305
+ in_channels=256,
306
+ out_channels=128,
307
+ kernel_size=(3, 3),
308
+ upsample=(2, 2),
309
+ activation=activation,
310
+ momentum=momentum,
311
+ )
312
+ self.decoder_block5 = DecoderBlockRes4B(
313
+ in_channels=128,
314
+ out_channels=64,
315
+ kernel_size=(3, 3),
316
+ upsample=(2, 2),
317
+ activation=activation,
318
+ momentum=momentum,
319
+ )
320
+ self.decoder_block6 = DecoderBlockRes4B(
321
+ in_channels=64,
322
+ out_channels=32,
323
+ kernel_size=(3, 3),
324
+ upsample=(2, 2),
325
+ activation=activation,
326
+ momentum=momentum,
327
+ )
328
+
329
+ self.after_conv_block1 = EncoderBlockRes4B(
330
+ in_channels=32,
331
+ out_channels=32,
332
+ kernel_size=(3, 3),
333
+ downsample=(1, 1),
334
+ activation=activation,
335
+ momentum=momentum,
336
+ )
337
+
338
+ self.after_conv2 = nn.Conv2d(
339
+ in_channels=32,
340
+ out_channels=input_channels
341
+ * self.subbands_num
342
+ * target_sources_num
343
+ * self.K,
344
+ kernel_size=(1, 1),
345
+ stride=(1, 1),
346
+ padding=(0, 0),
347
+ bias=True,
348
+ )
349
+
350
+ self.init_weights()
351
+
352
+ def init_weights(self):
353
+ init_bn(self.bn0)
354
+ init_layer(self.after_conv2)
355
+
356
+ def feature_maps_to_wav(
357
+ self,
358
+ input_tensor: torch.Tensor,
359
+ sp: torch.Tensor,
360
+ sin_in: torch.Tensor,
361
+ cos_in: torch.Tensor,
362
+ audio_length: int,
363
+ ) -> torch.Tensor:
364
+ r"""Convert feature maps to waveform.
365
+
366
+ Args:
367
+ input_tensor: (batch_size, feature_maps, time_steps, freq_bins)
368
+ sp: (batch_size, feature_maps, time_steps, freq_bins)
369
+ sin_in: (batch_size, feature_maps, time_steps, freq_bins)
370
+ cos_in: (batch_size, feature_maps, time_steps, freq_bins)
371
+
372
+ Outputs:
373
+ waveform: (batch_size, target_sources_num * input_channels, segment_samples)
374
+ """
375
+ batch_size, _, time_steps, freq_bins = input_tensor.shape
376
+
377
+ x = input_tensor.reshape(
378
+ batch_size,
379
+ self.target_sources_num,
380
+ self.input_channels,
381
+ self.K,
382
+ time_steps,
383
+ freq_bins,
384
+ )
385
+ # x: (batch_size, target_sources_num, input_channels, K, time_steps, freq_bins)
386
+
387
+ mask_mag = torch.sigmoid(x[:, :, :, 0, :, :])
388
+ _mask_real = torch.tanh(x[:, :, :, 1, :, :])
389
+ _mask_imag = torch.tanh(x[:, :, :, 2, :, :])
390
+ linear_mag = x[:, :, :, 3, :, :]
391
+ _, mask_cos, mask_sin = magphase(_mask_real, _mask_imag)
392
+ # mask_cos, mask_sin: (batch_size, target_sources_num, input_channels, time_steps, freq_bins)
393
+
394
+ # Y = |Y|cos∠Y + j|Y|sin∠Y
395
+ # = |Y|cos(∠X + ∠M) + j|Y|sin(∠X + ∠M)
396
+ # = |Y|(cos∠X cos∠M - sin∠X sin∠M) + j|Y|(sin∠X cos∠M + cos∠X sin∠M)
397
+ out_cos = (
398
+ cos_in[:, None, :, :, :] * mask_cos - sin_in[:, None, :, :, :] * mask_sin
399
+ )
400
+ out_sin = (
401
+ sin_in[:, None, :, :, :] * mask_cos + cos_in[:, None, :, :, :] * mask_sin
402
+ )
403
+ # out_cos: (batch_size, target_sources_num, input_channels, time_steps, freq_bins)
404
+ # out_sin: (batch_size, target_sources_num, input_channels, time_steps, freq_bins)
405
+
406
+ # Calculate |Y|.
407
+ out_mag = F.relu_(sp[:, None, :, :, :] * mask_mag + linear_mag)
408
+ # out_mag: (batch_size, target_sources_num, input_channels, time_steps, freq_bins)
409
+
410
+ # Calculate Y_{real} and Y_{imag} for ISTFT.
411
+ out_real = out_mag * out_cos
412
+ out_imag = out_mag * out_sin
413
+ # out_real, out_imag: (batch_size, target_sources_num, input_channels, time_steps, freq_bins)
414
+
415
+ # Reformat shape to (n, 1, time_steps, freq_bins) for ISTFT.
416
+ shape = (
417
+ batch_size * self.target_sources_num * self.input_channels,
418
+ 1,
419
+ time_steps,
420
+ freq_bins,
421
+ )
422
+ out_real = out_real.reshape(shape)
423
+ out_imag = out_imag.reshape(shape)
424
+
425
+ # ISTFT.
426
+ x = self.istft(out_real, out_imag, audio_length)
427
+ # (batch_size * target_sources_num * input_channels, segments_num)
428
+
429
+ # Reshape.
430
+ waveform = x.reshape(
431
+ batch_size, self.target_sources_num * self.input_channels, audio_length
432
+ )
433
+ # (batch_size, target_sources_num * input_channels, segments_num)
434
+
435
+ return waveform
436
+
437
+ def forward(self, input_dict):
438
+ r"""
439
+ Args:
440
+ input: (batch_size, channels_num, segment_samples)
441
+
442
+ Outputs:
443
+ output_dict: {
444
+ 'waveform': (batch_size, channels_num, segment_samples)
445
+ }
446
+ """
447
+ mixtures = input_dict['waveform']
448
+ # (batch_size, input_channels, segment_samples)
449
+
450
+ mag, cos_in, sin_in = self.wav_to_spectrogram_phase(mixtures)
451
+ # mag, cos_in, sin_in: (batch_size, input_channels, time_steps, freq_bins)
452
+
453
+ # Batch normalize on individual frequency bins.
454
+ x = mag.transpose(1, 3)
455
+ x = self.bn0(x)
456
+ x = x.transpose(1, 3)
457
+ """(batch_size, input_channels, time_steps, freq_bins)"""
458
+
459
+ # Pad spectrogram to be evenly divided by downsample ratio.
460
+ origin_len = x.shape[2]
461
+ pad_len = (
462
+ int(np.ceil(x.shape[2] / self.downsample_ratio)) * self.downsample_ratio
463
+ - origin_len
464
+ )
465
+ x = F.pad(x, pad=(0, 0, 0, pad_len))
466
+ """(batch_size, input_channels, padded_time_steps, freq_bins)"""
467
+
468
+ # Let frequency bins be evenly divided by 2, e.g., 1025 -> 1024
469
+ x = x[..., 0 : x.shape[-1] - 1] # (bs, input_channels, T, F)
470
+
471
+ x = self.subband.analysis(x)
472
+ # (bs, input_channels, T, F'), where F' = F // subbands_num
473
+
474
+ # UNet
475
+ (x1_pool, x1) = self.encoder_block1(x) # x1_pool: (bs, 32, T / 2, F / 2)
476
+ (x2_pool, x2) = self.encoder_block2(x1_pool) # x2_pool: (bs, 64, T / 4, F / 4)
477
+ (x3_pool, x3) = self.encoder_block3(x2_pool) # x3_pool: (bs, 128, T / 8, F / 8)
478
+ (x4_pool, x4) = self.encoder_block4(
479
+ x3_pool
480
+ ) # x4_pool: (bs, 256, T / 16, F / 16)
481
+ (x5_pool, x5) = self.encoder_block5(
482
+ x4_pool
483
+ ) # x5_pool: (bs, 384, T / 32, F / 32)
484
+ (x6_pool, x6) = self.encoder_block6(
485
+ x5_pool
486
+ ) # x6_pool: (bs, 384, T / 32, F / 64)
487
+ (x_center, _) = self.conv_block7a(x6_pool) # (bs, 384, T / 32, F / 64)
488
+ (x_center, _) = self.conv_block7b(x_center) # (bs, 384, T / 32, F / 64)
489
+ (x_center, _) = self.conv_block7c(x_center) # (bs, 384, T / 32, F / 64)
490
+ (x_center, _) = self.conv_block7d(x_center) # (bs, 384, T / 32, F / 64)
491
+ x7 = self.decoder_block1(x_center, x6) # (bs, 384, T / 32, F / 32)
492
+ x8 = self.decoder_block2(x7, x5) # (bs, 384, T / 16, F / 16)
493
+ x9 = self.decoder_block3(x8, x4) # (bs, 256, T / 8, F / 8)
494
+ x10 = self.decoder_block4(x9, x3) # (bs, 128, T / 4, F / 4)
495
+ x11 = self.decoder_block5(x10, x2) # (bs, 64, T / 2, F / 2)
496
+ x12 = self.decoder_block6(x11, x1) # (bs, 32, T, F)
497
+ (x, _) = self.after_conv_block1(x12) # (bs, 32, T, F)
498
+
499
+ x = self.after_conv2(x)
500
+ # (batch_size, input_channels * subbands_num * target_sources_num * K, T, F')
501
+
502
+ x = self.subband.synthesis(x)
503
+ # (batch_size, input_channels * target_sources_num * K, T, F)
504
+
505
+ # Recover shape
506
+ x = F.pad(x, pad=(0, 1)) # Pad frequency, e.g., 1024 -> 1025.
507
+ x = x[:, :, 0:origin_len, :] # (bs, feature_maps, time_steps, freq_bins)
508
+
509
+ audio_length = mixtures.shape[2]
510
+
511
+ separated_audio = self.feature_maps_to_wav(x, mag, sin_in, cos_in, audio_length)
512
+ # separated_audio: (batch_size, target_sources_num * input_channels, segments_num)
513
+
514
+ output_dict = {'waveform': separated_audio}
515
+
516
+ return output_dict
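For reference, the recombination performed inside feature_maps_to_wav above can be written as a small standalone function. The sketch below is not part of the repository; the function name and the random shape check are illustrative, and it only mirrors the |M|, cos∠M, sin∠M, Q formulas from the comments.

import torch
import torch.nn.functional as F

def apply_decoupled_mask(sp, cos_in, sin_in, mask_mag, mask_cos, mask_sin, linear_mag):
    # Combine the mixture phase with the estimated phase mask: angle(Y) = angle(X) + angle(M).
    out_cos = cos_in * mask_cos - sin_in * mask_sin
    out_sin = sin_in * mask_cos + cos_in * mask_sin
    # Magnitude: |Y| = relu(|X| * |M| + Q), where Q is the additive linear residual.
    out_mag = F.relu(sp * mask_mag + linear_mag)
    # Real and imaginary parts that feed the ISTFT.
    return out_mag * out_cos, out_mag * out_sin

# Tiny shape check with random tensors standing in for network outputs.
shape = (1, 2, 101, 1025)  # (batch, channels, time_steps, freq_bins)
real, imag = apply_decoupled_mask(*[torch.rand(shape) for _ in range(7)])
print(real.shape, imag.shape)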
bytesep/models/resunet_ismir2021.py ADDED
@@ -0,0 +1,534 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from inplace_abn.abn import InPlaceABNSync
6
+ from torchlibrosa.stft import ISTFT, STFT, magphase
7
+
8
+ from bytesep.models.pytorch_modules import Base, init_bn, init_layer
9
+
10
+
11
+ class ConvBlockRes(nn.Module):
12
+ def __init__(self, in_channels, out_channels, kernel_size, activation, momentum):
13
+ r"""Residual block."""
14
+ super(ConvBlockRes, self).__init__()
15
+
16
+ self.activation = activation
17
+ padding = [kernel_size[0] // 2, kernel_size[1] // 2]
18
+
19
+ # ABN is not used for bn1 because we found that applying ABN there degrades performance.
20
+ self.bn1 = nn.BatchNorm2d(in_channels, momentum=momentum)
21
+
22
+ self.abn2 = InPlaceABNSync(
23
+ num_features=out_channels, momentum=momentum, activation='leaky_relu'
24
+ )
25
+
26
+ self.conv1 = nn.Conv2d(
27
+ in_channels=in_channels,
28
+ out_channels=out_channels,
29
+ kernel_size=kernel_size,
30
+ stride=(1, 1),
31
+ dilation=(1, 1),
32
+ padding=padding,
33
+ bias=False,
34
+ )
35
+
36
+ self.conv2 = nn.Conv2d(
37
+ in_channels=out_channels,
38
+ out_channels=out_channels,
39
+ kernel_size=kernel_size,
40
+ stride=(1, 1),
41
+ dilation=(1, 1),
42
+ padding=padding,
43
+ bias=False,
44
+ )
45
+
46
+ if in_channels != out_channels:
47
+ self.shortcut = nn.Conv2d(
48
+ in_channels=in_channels,
49
+ out_channels=out_channels,
50
+ kernel_size=(1, 1),
51
+ stride=(1, 1),
52
+ padding=(0, 0),
53
+ )
54
+ self.is_shortcut = True
55
+ else:
56
+ self.is_shortcut = False
57
+
58
+ self.init_weights()
59
+
60
+ def init_weights(self):
61
+ init_bn(self.bn1)
62
+ init_layer(self.conv1)
63
+ init_layer(self.conv2)
64
+
65
+ if self.is_shortcut:
66
+ init_layer(self.shortcut)
67
+
68
+ def forward(self, x):
69
+ origin = x
70
+ x = self.conv1(F.leaky_relu_(self.bn1(x), negative_slope=0.01))
71
+ x = self.conv2(self.abn2(x))
72
+
73
+ if self.is_shortcut:
74
+ return self.shortcut(origin) + x
75
+ else:
76
+ return origin + x
77
+
78
+
79
+ class EncoderBlockRes4B(nn.Module):
80
+ def __init__(
81
+ self, in_channels, out_channels, kernel_size, downsample, activation, momentum
82
+ ):
83
+ r"""Encoder block, contains 8 convolutional layers."""
84
+ super(EncoderBlockRes4B, self).__init__()
85
+
86
+ self.conv_block1 = ConvBlockRes(
87
+ in_channels, out_channels, kernel_size, activation, momentum
88
+ )
89
+ self.conv_block2 = ConvBlockRes(
90
+ out_channels, out_channels, kernel_size, activation, momentum
91
+ )
92
+ self.conv_block3 = ConvBlockRes(
93
+ out_channels, out_channels, kernel_size, activation, momentum
94
+ )
95
+ self.conv_block4 = ConvBlockRes(
96
+ out_channels, out_channels, kernel_size, activation, momentum
97
+ )
98
+ self.downsample = downsample
99
+
100
+ def forward(self, x):
101
+ encoder = self.conv_block1(x)
102
+ encoder = self.conv_block2(encoder)
103
+ encoder = self.conv_block3(encoder)
104
+ encoder = self.conv_block4(encoder)
105
+ encoder_pool = F.avg_pool2d(encoder, kernel_size=self.downsample)
106
+ return encoder_pool, encoder
107
+
108
+
109
+ class DecoderBlockRes4B(nn.Module):
110
+ def __init__(
111
+ self, in_channels, out_channels, kernel_size, upsample, activation, momentum
112
+ ):
113
+ r"""Decoder block, contains 1 transpose convolutional and 8 convolutional layers."""
114
+ super(DecoderBlockRes4B, self).__init__()
115
+ self.kernel_size = kernel_size
116
+ self.stride = upsample
117
+ self.activation = activation
118
+
119
+ self.conv1 = torch.nn.ConvTranspose2d(
120
+ in_channels=in_channels,
121
+ out_channels=out_channels,
122
+ kernel_size=self.stride,
123
+ stride=self.stride,
124
+ padding=(0, 0),
125
+ bias=False,
126
+ dilation=(1, 1),
127
+ )
128
+
129
+ self.bn1 = nn.BatchNorm2d(in_channels, momentum=momentum)
130
+ self.conv_block2 = ConvBlockRes(
131
+ out_channels * 2, out_channels, kernel_size, activation, momentum
132
+ )
133
+ self.conv_block3 = ConvBlockRes(
134
+ out_channels, out_channels, kernel_size, activation, momentum
135
+ )
136
+ self.conv_block4 = ConvBlockRes(
137
+ out_channels, out_channels, kernel_size, activation, momentum
138
+ )
139
+ self.conv_block5 = ConvBlockRes(
140
+ out_channels, out_channels, kernel_size, activation, momentum
141
+ )
142
+
143
+ self.init_weights()
144
+
145
+ def init_weights(self):
146
+ init_bn(self.bn1)
147
+ init_layer(self.conv1)
148
+
149
+ def forward(self, input_tensor, concat_tensor):
150
+ x = self.conv1(F.relu_(self.bn1(input_tensor)))
151
+ x = torch.cat((x, concat_tensor), dim=1)
152
+ x = self.conv_block2(x)
153
+ x = self.conv_block3(x)
154
+ x = self.conv_block4(x)
155
+ x = self.conv_block5(x)
156
+ return x
157
+
158
+
159
+ class ResUNet143_DecouplePlusInplaceABN_ISMIR2021(nn.Module, Base):
160
+ def __init__(self, input_channels, target_sources_num):
161
+ super(ResUNet143_DecouplePlusInplaceABN_ISMIR2021, self).__init__()
162
+
163
+ self.input_channels = input_channels
164
+ self.target_sources_num = target_sources_num
165
+
166
+ window_size = 2048
167
+ hop_size = 441
168
+ center = True
169
+ pad_mode = 'reflect'
170
+ window = 'hann'
171
+ activation = 'leaky_relu'
172
+ momentum = 0.01
173
+
174
+ self.subbands_num = 1
175
+
176
+ assert (
177
+ self.subbands_num == 1
178
+ ), "Using subbands_num > 1 on spectrogram \
179
+ will lead to unexpected performance sometimes. Suggest to use \
180
+ subband method on waveform."
181
+
182
+ self.K = 4 # outputs: |M|, cos∠M, sin∠M, Q
183
+ # Downsample rate along the time axis.
184
+ self.time_downsample_ratio = 2 ** 5 # This number equals 2^{#encoder_blocks}
185
+
186
+ self.stft = STFT(
187
+ n_fft=window_size,
188
+ hop_length=hop_size,
189
+ win_length=window_size,
190
+ window=window,
191
+ center=center,
192
+ pad_mode=pad_mode,
193
+ freeze_parameters=True,
194
+ )
195
+
196
+ self.istft = ISTFT(
197
+ n_fft=window_size,
198
+ hop_length=hop_size,
199
+ win_length=window_size,
200
+ window=window,
201
+ center=center,
202
+ pad_mode=pad_mode,
203
+ freeze_parameters=True,
204
+ )
205
+
206
+ self.bn0 = nn.BatchNorm2d(window_size // 2 + 1, momentum=momentum)
207
+
208
+ self.encoder_block1 = EncoderBlockRes4B(
209
+ in_channels=input_channels * self.subbands_num,
210
+ out_channels=32,
211
+ kernel_size=(3, 3),
212
+ downsample=(2, 2),
213
+ activation=activation,
214
+ momentum=momentum,
215
+ )
216
+ self.encoder_block2 = EncoderBlockRes4B(
217
+ in_channels=32,
218
+ out_channels=64,
219
+ kernel_size=(3, 3),
220
+ downsample=(2, 2),
221
+ activation=activation,
222
+ momentum=momentum,
223
+ )
224
+ self.encoder_block3 = EncoderBlockRes4B(
225
+ in_channels=64,
226
+ out_channels=128,
227
+ kernel_size=(3, 3),
228
+ downsample=(2, 2),
229
+ activation=activation,
230
+ momentum=momentum,
231
+ )
232
+ self.encoder_block4 = EncoderBlockRes4B(
233
+ in_channels=128,
234
+ out_channels=256,
235
+ kernel_size=(3, 3),
236
+ downsample=(2, 2),
237
+ activation=activation,
238
+ momentum=momentum,
239
+ )
240
+ self.encoder_block5 = EncoderBlockRes4B(
241
+ in_channels=256,
242
+ out_channels=384,
243
+ kernel_size=(3, 3),
244
+ downsample=(2, 2),
245
+ activation=activation,
246
+ momentum=momentum,
247
+ )
248
+ self.encoder_block6 = EncoderBlockRes4B(
249
+ in_channels=384,
250
+ out_channels=384,
251
+ kernel_size=(3, 3),
252
+ downsample=(1, 2),
253
+ activation=activation,
254
+ momentum=momentum,
255
+ )
256
+ self.conv_block7a = EncoderBlockRes4B(
257
+ in_channels=384,
258
+ out_channels=384,
259
+ kernel_size=(3, 3),
260
+ downsample=(1, 1),
261
+ activation=activation,
262
+ momentum=momentum,
263
+ )
264
+ self.conv_block7b = EncoderBlockRes4B(
265
+ in_channels=384,
266
+ out_channels=384,
267
+ kernel_size=(3, 3),
268
+ downsample=(1, 1),
269
+ activation=activation,
270
+ momentum=momentum,
271
+ )
272
+ self.conv_block7c = EncoderBlockRes4B(
273
+ in_channels=384,
274
+ out_channels=384,
275
+ kernel_size=(3, 3),
276
+ downsample=(1, 1),
277
+ activation=activation,
278
+ momentum=momentum,
279
+ )
280
+ self.conv_block7d = EncoderBlockRes4B(
281
+ in_channels=384,
282
+ out_channels=384,
283
+ kernel_size=(3, 3),
284
+ downsample=(1, 1),
285
+ activation=activation,
286
+ momentum=momentum,
287
+ )
288
+ self.decoder_block1 = DecoderBlockRes4B(
289
+ in_channels=384,
290
+ out_channels=384,
291
+ kernel_size=(3, 3),
292
+ upsample=(1, 2),
293
+ activation=activation,
294
+ momentum=momentum,
295
+ )
296
+ self.decoder_block2 = DecoderBlockRes4B(
297
+ in_channels=384,
298
+ out_channels=384,
299
+ kernel_size=(3, 3),
300
+ upsample=(2, 2),
301
+ activation=activation,
302
+ momentum=momentum,
303
+ )
304
+ self.decoder_block3 = DecoderBlockRes4B(
305
+ in_channels=384,
306
+ out_channels=256,
307
+ kernel_size=(3, 3),
308
+ upsample=(2, 2),
309
+ activation=activation,
310
+ momentum=momentum,
311
+ )
312
+ self.decoder_block4 = DecoderBlockRes4B(
313
+ in_channels=256,
314
+ out_channels=128,
315
+ kernel_size=(3, 3),
316
+ upsample=(2, 2),
317
+ activation=activation,
318
+ momentum=momentum,
319
+ )
320
+ self.decoder_block5 = DecoderBlockRes4B(
321
+ in_channels=128,
322
+ out_channels=64,
323
+ kernel_size=(3, 3),
324
+ upsample=(2, 2),
325
+ activation=activation,
326
+ momentum=momentum,
327
+ )
328
+ self.decoder_block6 = DecoderBlockRes4B(
329
+ in_channels=64,
330
+ out_channels=32,
331
+ kernel_size=(3, 3),
332
+ upsample=(2, 2),
333
+ activation=activation,
334
+ momentum=momentum,
335
+ )
336
+
337
+ self.after_conv_block1 = EncoderBlockRes4B(
338
+ in_channels=32,
339
+ out_channels=32,
340
+ kernel_size=(3, 3),
341
+ downsample=(1, 1),
342
+ activation=activation,
343
+ momentum=momentum,
344
+ )
345
+
346
+ self.after_conv2 = nn.Conv2d(
347
+ in_channels=32,
348
+ out_channels=target_sources_num
349
+ * input_channels
350
+ * self.K
351
+ * self.subbands_num,
352
+ kernel_size=(1, 1),
353
+ stride=(1, 1),
354
+ padding=(0, 0),
355
+ bias=True,
356
+ )
357
+
358
+ self.init_weights()
359
+
360
+ def init_weights(self):
361
+ init_bn(self.bn0)
362
+ init_layer(self.after_conv2)
363
+
364
+ def feature_maps_to_wav(
365
+ self,
366
+ input_tensor: torch.Tensor,
367
+ sp: torch.Tensor,
368
+ sin_in: torch.Tensor,
369
+ cos_in: torch.Tensor,
370
+ audio_length: int,
371
+ ) -> torch.Tensor:
372
+ r"""Convert feature maps to waveform.
373
+
374
+ Args:
375
+ input_tensor: (batch_size, target_sources_num * input_channels * self.K, time_steps, freq_bins)
376
+ sp: (batch_size, target_sources_num * input_channels, time_steps, freq_bins)
377
+ sin_in: (batch_size, target_sources_num * input_channels, time_steps, freq_bins)
378
+ cos_in: (batch_size, target_sources_num * input_channels, time_steps, freq_bins)
379
+
380
+ Outputs:
381
+ waveform: (batch_size, target_sources_num * input_channels, segment_samples)
382
+ """
383
+ batch_size, _, time_steps, freq_bins = input_tensor.shape
384
+
385
+ x = input_tensor.reshape(
386
+ batch_size,
387
+ self.target_sources_num,
388
+ self.input_channels,
389
+ self.K,
390
+ time_steps,
391
+ freq_bins,
392
+ )
393
+ # x: (batch_size, target_sources_num, input_channles, K, time_steps, freq_bins)
394
+
395
+ mask_mag = torch.sigmoid(x[:, :, :, 0, :, :])
396
+ _mask_real = torch.tanh(x[:, :, :, 1, :, :])
397
+ _mask_imag = torch.tanh(x[:, :, :, 2, :, :])
398
+ _, mask_cos, mask_sin = magphase(_mask_real, _mask_imag)
399
+ linear_mag = x[:, :, :, 3, :, :]
400
+ # mask_cos, mask_sin: (batch_size, target_sources_num, input_channles, time_steps, freq_bins)
401
+
402
+ # Y = |Y|cos∠Y + j|Y|sin∠Y
403
+ # = |Y|cos(∠X + ∠M) + j|Y|sin(∠X + ∠M)
404
+ # = |Y|(cos∠X cos∠M - sin∠X sin∠M) + j|Y|(sin∠X cos∠M + cos∠X sin∠M)
405
+ out_cos = (
406
+ cos_in[:, None, :, :, :] * mask_cos - sin_in[:, None, :, :, :] * mask_sin
407
+ )
408
+ out_sin = (
409
+ sin_in[:, None, :, :, :] * mask_cos + cos_in[:, None, :, :, :] * mask_sin
410
+ )
411
+ # out_cos: (batch_size, target_sources_num, input_channles, time_steps, freq_bins)
412
+ # out_sin: (batch_size, target_sources_num, input_channles, time_steps, freq_bins)
413
+
414
+ # Calculate |Y|.
415
+ out_mag = F.relu_(sp[:, None, :, :, :] * mask_mag + linear_mag)
416
+ # out_mag: (batch_size, target_sources_num, input_channles, time_steps, freq_bins)
417
+
418
+ # Calculate Y_{real} and Y_{imag} for ISTFT.
419
+ out_real = out_mag * out_cos
420
+ out_imag = out_mag * out_sin
421
+ # out_real, out_imag: (batch_size, target_sources_num, input_channles, time_steps, freq_bins)
422
+
423
+ # Reformat shape to (n, 1, time_steps, freq_bins) for ISTFT.
424
+ shape = (
425
+ batch_size * self.target_sources_num * self.input_channels,
426
+ 1,
427
+ time_steps,
428
+ freq_bins,
429
+ )
430
+ out_real = out_real.reshape(shape)
431
+ out_imag = out_imag.reshape(shape)
432
+
433
+ # ISTFT.
434
+ x = self.istft(out_real, out_imag, audio_length)
435
+ # (batch_size * target_sources_num * input_channels, segments_num)
436
+
437
+ # Reshape.
438
+ waveform = x.reshape(
439
+ batch_size, self.target_sources_num * self.input_channels, audio_length
440
+ )
441
+ # (batch_size, target_sources_num * input_channels, segments_num)
442
+
443
+ return waveform
444
+
445
+ def forward(self, input_dict):
446
+ r"""Forward data into the module.
447
+
448
+ Args:
449
+ input_dict: dict, e.g., {
450
+ waveform: (batch_size, input_channels, segment_samples),
451
+ ...,
452
+ }
453
+
454
+ Outputs:
455
+ output_dict: dict, e.g., {
456
+ 'waveform': (batch_size, input_channels, segment_samples),
457
+ ...,
458
+ }
459
+ """
460
+ mixtures = input_dict['waveform']
461
+ # (batch_size, input_channels, segment_samples)
462
+
463
+ mag, cos_in, sin_in = self.wav_to_spectrogram_phase(mixtures)
464
+ # mag, cos_in, sin_in: (batch_size, input_channels, time_steps, freq_bins)
465
+
466
+ # Batch normalize on individual frequency bins.
467
+ x = mag.transpose(1, 3)
468
+ x = self.bn0(x)
469
+ x = x.transpose(1, 3)
470
+ # x: (batch_size, input_channels, time_steps, freq_bins)
471
+
472
+ # Pad spectrogram to be evenly divided by downsample ratio.
473
+ origin_len = x.shape[2]
474
+ pad_len = (
475
+ int(np.ceil(x.shape[2] / self.time_downsample_ratio))
476
+ * self.time_downsample_ratio
477
+ - origin_len
478
+ )
479
+ x = F.pad(x, pad=(0, 0, 0, pad_len))
480
+ # (batch_size, channels, padded_time_steps, freq_bins)
481
+
482
+ # Let frequency bins be evenly divided by 2, e.g., 1025 -> 1024.
483
+ x = x[..., 0 : x.shape[-1] - 1] # (bs, channels, T, F)
484
+
485
+ if self.subbands_num > 1:
486
+ x = self.subband.analysis(x)
487
+ # (bs, input_channels, T, F'), where F' = F // subbands_num
488
+
489
+ # UNet
490
+ (x1_pool, x1) = self.encoder_block1(x) # x1_pool: (bs, 32, T / 2, F / 2)
491
+ (x2_pool, x2) = self.encoder_block2(x1_pool) # x2_pool: (bs, 64, T / 4, F / 4)
492
+ (x3_pool, x3) = self.encoder_block3(x2_pool) # x3_pool: (bs, 128, T / 8, F / 8)
493
+ (x4_pool, x4) = self.encoder_block4(
494
+ x3_pool
495
+ ) # x4_pool: (bs, 256, T / 16, F / 16)
496
+ (x5_pool, x5) = self.encoder_block5(
497
+ x4_pool
498
+ ) # x5_pool: (bs, 384, T / 32, F / 32)
499
+ (x6_pool, x6) = self.encoder_block6(
500
+ x5_pool
501
+ ) # x6_pool: (bs, 384, T / 32, F / 64)
502
+ (x_center, _) = self.conv_block7a(x6_pool) # (bs, 384, T / 32, F / 64)
503
+ (x_center, _) = self.conv_block7b(x_center) # (bs, 384, T / 32, F / 64)
504
+ (x_center, _) = self.conv_block7c(x_center) # (bs, 384, T / 32, F / 64)
505
+ (x_center, _) = self.conv_block7d(x_center) # (bs, 384, T / 32, F / 64)
506
+ x7 = self.decoder_block1(x_center, x6) # (bs, 384, T / 32, F / 32)
507
+ x8 = self.decoder_block2(x7, x5) # (bs, 384, T / 16, F / 16)
508
+ x9 = self.decoder_block3(x8, x4) # (bs, 256, T / 8, F / 8)
509
+ x10 = self.decoder_block4(x9, x3) # (bs, 128, T / 4, F / 4)
510
+ x11 = self.decoder_block5(x10, x2) # (bs, 64, T / 2, F / 2)
511
+ x12 = self.decoder_block6(x11, x1) # (bs, 32, T, F)
512
+ (x, _) = self.after_conv_block1(x12) # (bs, 32, T, F)
513
+
514
+ x = self.after_conv2(x)
515
+ # (batch_size, target_sources_num * input_channles * self.K * subbands_num, T, F')
516
+
517
+ if self.subbands_num > 1:
518
+ x = self.subband.synthesis(x)
519
+ # (batch_size, target_sources_num * input_channles * self.K, T, F)
520
+
521
+ # Recover shape
522
+ x = F.pad(x, pad=(0, 1)) # Pad frequency, e.g., 1024 -> 1025.
523
+
524
+ x = x[:, :, 0:origin_len, :]
525
+ # (batch_size, target_sources_num * input_channles * self.K, T, F)
526
+
527
+ audio_length = mixtures.shape[2]
528
+
529
+ separated_audio = self.feature_maps_to_wav(x, mag, sin_in, cos_in, audio_length)
530
+ # separated_audio: (batch_size, target_sources_num * input_channels, segments_num)
531
+
532
+ output_dict = {'waveform': separated_audio}
533
+
534
+ return output_dict
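A minimal usage sketch for the model above, assuming inplace_abn and torchlibrosa are installed (InPlaceABNSync may additionally require a CUDA device); the batch size and the 3-second stereo segment are arbitrary example values, not values prescribed by the repository.

import torch
from bytesep.models.resunet_ismir2021 import ResUNet143_DecouplePlusInplaceABN_ISMIR2021

model = ResUNet143_DecouplePlusInplaceABN_ISMIR2021(input_channels=2, target_sources_num=1)
model.eval()
batch = {'waveform': torch.randn(1, 2, 3 * 44100)}  # 3 s of stereo audio at 44.1 kHz
with torch.no_grad():
    output = model(batch)
print(output['waveform'].shape)  # (1, target_sources_num * input_channels, 3 * 44100)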
bytesep/models/resunet_subbandtime.py ADDED
@@ -0,0 +1,545 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from torchlibrosa.stft import ISTFT, STFT, magphase
6
+
7
+ from bytesep.models.pytorch_modules import Base, init_bn, init_layer
8
+ from bytesep.models.subband_tools.pqmf import PQMF
9
+
10
+
11
+ class ConvBlockRes(nn.Module):
12
+ def __init__(self, in_channels, out_channels, kernel_size, activation, momentum):
13
+ r"""Residual block."""
14
+ super(ConvBlockRes, self).__init__()
15
+
16
+ self.activation = activation
17
+ padding = [kernel_size[0] // 2, kernel_size[1] // 2]
18
+
19
+ self.bn1 = nn.BatchNorm2d(in_channels, momentum=momentum)
20
+ self.bn2 = nn.BatchNorm2d(out_channels, momentum=momentum)
21
+
22
+ self.conv1 = nn.Conv2d(
23
+ in_channels=in_channels,
24
+ out_channels=out_channels,
25
+ kernel_size=kernel_size,
26
+ stride=(1, 1),
27
+ dilation=(1, 1),
28
+ padding=padding,
29
+ bias=False,
30
+ )
31
+
32
+ self.conv2 = nn.Conv2d(
33
+ in_channels=out_channels,
34
+ out_channels=out_channels,
35
+ kernel_size=kernel_size,
36
+ stride=(1, 1),
37
+ dilation=(1, 1),
38
+ padding=padding,
39
+ bias=False,
40
+ )
41
+
42
+ if in_channels != out_channels:
43
+ self.shortcut = nn.Conv2d(
44
+ in_channels=in_channels,
45
+ out_channels=out_channels,
46
+ kernel_size=(1, 1),
47
+ stride=(1, 1),
48
+ padding=(0, 0),
49
+ )
50
+ self.is_shortcut = True
51
+ else:
52
+ self.is_shortcut = False
53
+
54
+ self.init_weights()
55
+
56
+ def init_weights(self):
57
+ init_bn(self.bn1)
58
+ init_bn(self.bn2)
59
+ init_layer(self.conv1)
60
+ init_layer(self.conv2)
61
+
62
+ if self.is_shortcut:
63
+ init_layer(self.shortcut)
64
+
65
+ def forward(self, x):
66
+ origin = x
67
+ x = self.conv1(F.leaky_relu_(self.bn1(x), negative_slope=0.01))
68
+ x = self.conv2(F.leaky_relu_(self.bn2(x), negative_slope=0.01))
69
+
70
+ if self.is_shortcut:
71
+ return self.shortcut(origin) + x
72
+ else:
73
+ return origin + x
74
+
75
+
76
+ class EncoderBlockRes4B(nn.Module):
77
+ def __init__(
78
+ self, in_channels, out_channels, kernel_size, downsample, activation, momentum
79
+ ):
80
+ r"""Encoder block, contains 8 convolutional layers."""
81
+ super(EncoderBlockRes4B, self).__init__()
82
+
83
+ self.conv_block1 = ConvBlockRes(
84
+ in_channels, out_channels, kernel_size, activation, momentum
85
+ )
86
+ self.conv_block2 = ConvBlockRes(
87
+ out_channels, out_channels, kernel_size, activation, momentum
88
+ )
89
+ self.conv_block3 = ConvBlockRes(
90
+ out_channels, out_channels, kernel_size, activation, momentum
91
+ )
92
+ self.conv_block4 = ConvBlockRes(
93
+ out_channels, out_channels, kernel_size, activation, momentum
94
+ )
95
+ self.downsample = downsample
96
+
97
+ def forward(self, x):
98
+ encoder = self.conv_block1(x)
99
+ encoder = self.conv_block2(encoder)
100
+ encoder = self.conv_block3(encoder)
101
+ encoder = self.conv_block4(encoder)
102
+ encoder_pool = F.avg_pool2d(encoder, kernel_size=self.downsample)
103
+ return encoder_pool, encoder
104
+
105
+
106
+ class DecoderBlockRes4B(nn.Module):
107
+ def __init__(
108
+ self, in_channels, out_channels, kernel_size, upsample, activation, momentum
109
+ ):
110
+ r"""Decoder block, contains 1 transpose convolutional and 8 convolutional layers."""
111
+ super(DecoderBlockRes4B, self).__init__()
112
+ self.kernel_size = kernel_size
113
+ self.stride = upsample
114
+ self.activation = activation
115
+
116
+ self.conv1 = torch.nn.ConvTranspose2d(
117
+ in_channels=in_channels,
118
+ out_channels=out_channels,
119
+ kernel_size=self.stride,
120
+ stride=self.stride,
121
+ padding=(0, 0),
122
+ bias=False,
123
+ dilation=(1, 1),
124
+ )
125
+
126
+ self.bn1 = nn.BatchNorm2d(in_channels, momentum=momentum)
127
+ self.conv_block2 = ConvBlockRes(
128
+ out_channels * 2, out_channels, kernel_size, activation, momentum
129
+ )
130
+ self.conv_block3 = ConvBlockRes(
131
+ out_channels, out_channels, kernel_size, activation, momentum
132
+ )
133
+ self.conv_block4 = ConvBlockRes(
134
+ out_channels, out_channels, kernel_size, activation, momentum
135
+ )
136
+ self.conv_block5 = ConvBlockRes(
137
+ out_channels, out_channels, kernel_size, activation, momentum
138
+ )
139
+
140
+ self.init_weights()
141
+
142
+ def init_weights(self):
143
+ init_bn(self.bn1)
144
+ init_layer(self.conv1)
145
+
146
+ def forward(self, input_tensor, concat_tensor):
147
+ x = self.conv1(F.relu_(self.bn1(input_tensor)))
148
+ x = torch.cat((x, concat_tensor), dim=1)
149
+ x = self.conv_block2(x)
150
+ x = self.conv_block3(x)
151
+ x = self.conv_block4(x)
152
+ x = self.conv_block5(x)
153
+ return x
154
+
155
+
156
+ class ResUNet143_Subbandtime(nn.Module, Base):
157
+ def __init__(self, input_channels, target_sources_num):
158
+ super(ResUNet143_Subbandtime, self).__init__()
159
+
160
+ self.input_channels = input_channels
161
+ self.target_sources_num = target_sources_num
162
+
163
+ window_size = 512
164
+ hop_size = 110
165
+ center = True
166
+ pad_mode = "reflect"
167
+ window = "hann"
168
+ activation = "leaky_relu"
169
+ momentum = 0.01
170
+
171
+ self.subbands_num = 4
172
+ self.K = 4 # outputs: |M|, cos∠M, sin∠M, Q
173
+
174
+ self.downsample_ratio = 2 ** 5 # This number equals 2^{#encoder_blocks}
175
+
176
+ self.pqmf = PQMF(
177
+ N=self.subbands_num,
178
+ M=64,
179
+ project_root='bytesep/models/subband_tools/filters',
180
+ )
181
+
182
+ self.stft = STFT(
183
+ n_fft=window_size,
184
+ hop_length=hop_size,
185
+ win_length=window_size,
186
+ window=window,
187
+ center=center,
188
+ pad_mode=pad_mode,
189
+ freeze_parameters=True,
190
+ )
191
+
192
+ self.istft = ISTFT(
193
+ n_fft=window_size,
194
+ hop_length=hop_size,
195
+ win_length=window_size,
196
+ window=window,
197
+ center=center,
198
+ pad_mode=pad_mode,
199
+ freeze_parameters=True,
200
+ )
201
+
202
+ self.bn0 = nn.BatchNorm2d(window_size // 2 + 1, momentum=momentum)
203
+
204
+ self.encoder_block1 = EncoderBlockRes4B(
205
+ in_channels=input_channels * self.subbands_num,
206
+ out_channels=32,
207
+ kernel_size=(3, 3),
208
+ downsample=(2, 2),
209
+ activation=activation,
210
+ momentum=momentum,
211
+ )
212
+ self.encoder_block2 = EncoderBlockRes4B(
213
+ in_channels=32,
214
+ out_channels=64,
215
+ kernel_size=(3, 3),
216
+ downsample=(2, 2),
217
+ activation=activation,
218
+ momentum=momentum,
219
+ )
220
+ self.encoder_block3 = EncoderBlockRes4B(
221
+ in_channels=64,
222
+ out_channels=128,
223
+ kernel_size=(3, 3),
224
+ downsample=(2, 2),
225
+ activation=activation,
226
+ momentum=momentum,
227
+ )
228
+ self.encoder_block4 = EncoderBlockRes4B(
229
+ in_channels=128,
230
+ out_channels=256,
231
+ kernel_size=(3, 3),
232
+ downsample=(2, 2),
233
+ activation=activation,
234
+ momentum=momentum,
235
+ )
236
+ self.encoder_block5 = EncoderBlockRes4B(
237
+ in_channels=256,
238
+ out_channels=384,
239
+ kernel_size=(3, 3),
240
+ downsample=(2, 2),
241
+ activation=activation,
242
+ momentum=momentum,
243
+ )
244
+ self.encoder_block6 = EncoderBlockRes4B(
245
+ in_channels=384,
246
+ out_channels=384,
247
+ kernel_size=(3, 3),
248
+ downsample=(1, 2),
249
+ activation=activation,
250
+ momentum=momentum,
251
+ )
252
+ self.conv_block7a = EncoderBlockRes4B(
253
+ in_channels=384,
254
+ out_channels=384,
255
+ kernel_size=(3, 3),
256
+ downsample=(1, 1),
257
+ activation=activation,
258
+ momentum=momentum,
259
+ )
260
+ self.conv_block7b = EncoderBlockRes4B(
261
+ in_channels=384,
262
+ out_channels=384,
263
+ kernel_size=(3, 3),
264
+ downsample=(1, 1),
265
+ activation=activation,
266
+ momentum=momentum,
267
+ )
268
+ self.conv_block7c = EncoderBlockRes4B(
269
+ in_channels=384,
270
+ out_channels=384,
271
+ kernel_size=(3, 3),
272
+ downsample=(1, 1),
273
+ activation=activation,
274
+ momentum=momentum,
275
+ )
276
+ self.conv_block7d = EncoderBlockRes4B(
277
+ in_channels=384,
278
+ out_channels=384,
279
+ kernel_size=(3, 3),
280
+ downsample=(1, 1),
281
+ activation=activation,
282
+ momentum=momentum,
283
+ )
284
+ self.decoder_block1 = DecoderBlockRes4B(
285
+ in_channels=384,
286
+ out_channels=384,
287
+ kernel_size=(3, 3),
288
+ upsample=(1, 2),
289
+ activation=activation,
290
+ momentum=momentum,
291
+ )
292
+ self.decoder_block2 = DecoderBlockRes4B(
293
+ in_channels=384,
294
+ out_channels=384,
295
+ kernel_size=(3, 3),
296
+ upsample=(2, 2),
297
+ activation=activation,
298
+ momentum=momentum,
299
+ )
300
+ self.decoder_block3 = DecoderBlockRes4B(
301
+ in_channels=384,
302
+ out_channels=256,
303
+ kernel_size=(3, 3),
304
+ upsample=(2, 2),
305
+ activation=activation,
306
+ momentum=momentum,
307
+ )
308
+ self.decoder_block4 = DecoderBlockRes4B(
309
+ in_channels=256,
310
+ out_channels=128,
311
+ kernel_size=(3, 3),
312
+ upsample=(2, 2),
313
+ activation=activation,
314
+ momentum=momentum,
315
+ )
316
+ self.decoder_block5 = DecoderBlockRes4B(
317
+ in_channels=128,
318
+ out_channels=64,
319
+ kernel_size=(3, 3),
320
+ upsample=(2, 2),
321
+ activation=activation,
322
+ momentum=momentum,
323
+ )
324
+ self.decoder_block6 = DecoderBlockRes4B(
325
+ in_channels=64,
326
+ out_channels=32,
327
+ kernel_size=(3, 3),
328
+ upsample=(2, 2),
329
+ activation=activation,
330
+ momentum=momentum,
331
+ )
332
+
333
+ self.after_conv_block1 = EncoderBlockRes4B(
334
+ in_channels=32,
335
+ out_channels=32,
336
+ kernel_size=(3, 3),
337
+ downsample=(1, 1),
338
+ activation=activation,
339
+ momentum=momentum,
340
+ )
341
+
342
+ self.after_conv2 = nn.Conv2d(
343
+ in_channels=32,
344
+ out_channels=target_sources_num
345
+ * input_channels
346
+ * self.K
347
+ * self.subbands_num,
348
+ kernel_size=(1, 1),
349
+ stride=(1, 1),
350
+ padding=(0, 0),
351
+ bias=True,
352
+ )
353
+
354
+ self.init_weights()
355
+
356
+ def init_weights(self):
357
+ init_bn(self.bn0)
358
+ init_layer(self.after_conv2)
359
+
360
+ def feature_maps_to_wav(
361
+ self,
362
+ input_tensor: torch.Tensor,
363
+ sp: torch.Tensor,
364
+ sin_in: torch.Tensor,
365
+ cos_in: torch.Tensor,
366
+ audio_length: int,
367
+ ) -> torch.Tensor:
368
+ r"""Convert feature maps to waveform.
369
+
370
+ Args:
371
+ input_tensor: (batch_size, target_sources_num * input_channels * self.K, time_steps, freq_bins)
372
+ sp: (batch_size, target_sources_num * input_channels, time_steps, freq_bins)
373
+ sin_in: (batch_size, target_sources_num * input_channels, time_steps, freq_bins)
374
+ cos_in: (batch_size, target_sources_num * input_channels, time_steps, freq_bins)
375
+
376
+ Outputs:
377
+ waveform: (batch_size, target_sources_num * input_channels, segment_samples)
378
+ """
379
+ batch_size, _, time_steps, freq_bins = input_tensor.shape
380
+
381
+ x = input_tensor.reshape(
382
+ batch_size,
383
+ self.target_sources_num,
384
+ self.input_channels,
385
+ self.K,
386
+ time_steps,
387
+ freq_bins,
388
+ )
389
+ # x: (batch_size, target_sources_num, input_channles, K, time_steps, freq_bins)
390
+
391
+ mask_mag = torch.sigmoid(x[:, :, :, 0, :, :])
392
+ _mask_real = torch.tanh(x[:, :, :, 1, :, :])
393
+ _mask_imag = torch.tanh(x[:, :, :, 2, :, :])
394
+ linear_mag = torch.tanh(x[:, :, :, 3, :, :])
395
+ _, mask_cos, mask_sin = magphase(_mask_real, _mask_imag)
396
+ # mask_cos, mask_sin: (batch_size, target_sources_num, input_channles, time_steps, freq_bins)
397
+
398
+ # Y = |Y|cos∠Y + j|Y|sin∠Y
399
+ # = |Y|cos(∠X + ∠M) + j|Y|sin(∠X + ∠M)
400
+ # = |Y|(cos∠X cos∠M - sin∠X sin∠M) + j|Y|(sin∠X cos∠M + cos∠X sin∠M)
401
+ out_cos = (
402
+ cos_in[:, None, :, :, :] * mask_cos - sin_in[:, None, :, :, :] * mask_sin
403
+ )
404
+ out_sin = (
405
+ sin_in[:, None, :, :, :] * mask_cos + cos_in[:, None, :, :, :] * mask_sin
406
+ )
407
+ # out_cos: (batch_size, target_sources_num, input_channles, time_steps, freq_bins)
408
+ # out_sin: (batch_size, target_sources_num, input_channles, time_steps, freq_bins)
409
+
410
+ # Calculate |Y|.
411
+ out_mag = F.relu_(sp[:, None, :, :, :] * mask_mag + linear_mag)
412
+ # out_mag: (batch_size, target_sources_num, input_channles, time_steps, freq_bins)
413
+
414
+ # Calculate Y_{real} and Y_{imag} for ISTFT.
415
+ out_real = out_mag * out_cos
416
+ out_imag = out_mag * out_sin
417
+ # out_real, out_imag: (batch_size, target_sources_num, input_channles, time_steps, freq_bins)
418
+
419
+ # Reformat shape to (n, 1, time_steps, freq_bins) for ISTFT.
420
+ shape = (
421
+ batch_size * self.target_sources_num * self.input_channels,
422
+ 1,
423
+ time_steps,
424
+ freq_bins,
425
+ )
426
+ out_real = out_real.reshape(shape)
427
+ out_imag = out_imag.reshape(shape)
428
+
429
+ # ISTFT.
430
+ x = self.istft(out_real, out_imag, audio_length)
431
+ # (batch_size * target_sources_num * input_channels, segments_num)
432
+
433
+ # Reshape.
434
+ waveform = x.reshape(
435
+ batch_size, self.target_sources_num * self.input_channels, audio_length
436
+ )
437
+ # (batch_size, target_sources_num * input_channels, segments_num)
438
+
439
+ return waveform
440
+
441
+ def forward(self, input_dict):
442
+ r"""Forward data into the module.
443
+
444
+ Args:
445
+ input_dict: dict, e.g., {
446
+ waveform: (batch_size, input_channels, segment_samples),
447
+ ...,
448
+ }
449
+
450
+ Outputs:
451
+ output_dict: dict, e.g., {
452
+ 'waveform': (batch_size, input_channels, segment_samples),
453
+ ...,
454
+ }
455
+ """
456
+ mixtures = input_dict['waveform']
457
+ # (batch_size, input_channels, segment_samples)
458
+
459
+ subband_x = self.pqmf.analysis(mixtures)
460
+ # subband_x: (batch_size, input_channels * subbands_num, segment_samples)
461
+
462
+ mag, cos_in, sin_in = self.wav_to_spectrogram_phase(subband_x)
463
+ # mag, cos_in, sin_in: (batch_size, input_channels * subbands_num, time_steps, freq_bins)
464
+
465
+ # Batch normalize on individual frequency bins.
466
+ x = mag.transpose(1, 3)
467
+ x = self.bn0(x)
468
+ x = x.transpose(1, 3)
469
+ # (batch_size, input_channels * subbands_num, time_steps, freq_bins)
470
+
471
+ # Pad spectrogram to be evenly divided by downsample ratio.
472
+ origin_len = x.shape[2]
473
+ pad_len = (
474
+ int(np.ceil(x.shape[2] / self.downsample_ratio)) * self.downsample_ratio
475
+ - origin_len
476
+ )
477
+ x = F.pad(x, pad=(0, 0, 0, pad_len))
478
+ # x: (batch_size, input_channels * subbands_num, padded_time_steps, freq_bins)
479
+
480
+ # Let frequency bins be evenly divided by 2, e.g., 257 -> 256
481
+ x = x[..., 0 : x.shape[-1] - 1] # (bs, input_channels, T, F)
482
+ # x: (batch_size, input_channels * subbands_num, padded_time_steps, freq_bins)
483
+
484
+ # UNet
485
+ (x1_pool, x1) = self.encoder_block1(x) # x1_pool: (bs, 32, T / 2, F / 2)
486
+ (x2_pool, x2) = self.encoder_block2(x1_pool) # x2_pool: (bs, 64, T / 4, F / 4)
487
+ (x3_pool, x3) = self.encoder_block3(x2_pool) # x3_pool: (bs, 128, T / 8, F / 8)
488
+ (x4_pool, x4) = self.encoder_block4(
489
+ x3_pool
490
+ ) # x4_pool: (bs, 256, T / 16, F / 16)
491
+ (x5_pool, x5) = self.encoder_block5(
492
+ x4_pool
493
+ ) # x5_pool: (bs, 384, T / 32, F / 32)
494
+ (x6_pool, x6) = self.encoder_block6(
495
+ x5_pool
496
+ ) # x6_pool: (bs, 384, T / 32, F / 64)
497
+ (x_center, _) = self.conv_block7a(x6_pool) # (bs, 384, T / 32, F / 64)
498
+ (x_center, _) = self.conv_block7b(x_center) # (bs, 384, T / 32, F / 64)
499
+ (x_center, _) = self.conv_block7c(x_center) # (bs, 384, T / 32, F / 64)
500
+ (x_center, _) = self.conv_block7d(x_center) # (bs, 384, T / 32, F / 64)
501
+ x7 = self.decoder_block1(x_center, x6) # (bs, 384, T / 32, F / 32)
502
+ x8 = self.decoder_block2(x7, x5) # (bs, 384, T / 16, F / 16)
503
+ x9 = self.decoder_block3(x8, x4) # (bs, 256, T / 8, F / 8)
504
+ x10 = self.decoder_block4(x9, x3) # (bs, 128, T / 4, F / 4)
505
+ x11 = self.decoder_block5(x10, x2) # (bs, 64, T / 2, F / 2)
506
+ x12 = self.decoder_block6(x11, x1) # (bs, 32, T, F)
507
+ (x, _) = self.after_conv_block1(x12) # (bs, 32, T, F)
508
+
509
+ x = self.after_conv2(x)
510
+ # (batch_size, subbands_num * target_sources_num * input_channles * self.K, T, F')
511
+
512
+ # Recover shape
513
+ x = F.pad(x, pad=(0, 1)) # Pad frequency, e.g., 256 -> 257.
514
+
515
+ x = x[:, :, 0:origin_len, :]
516
+ # (batch_size, subbands_num * target_sources_num * input_channles * self.K, T, F')
517
+
518
+ audio_length = subband_x.shape[2]
519
+
520
+ # Recover each subband spectrograms to subband waveforms. Then synthesis
521
+ # the subband waveforms to a waveform.
522
+ C1 = x.shape[1] // self.subbands_num
523
+ C2 = mag.shape[1] // self.subbands_num
524
+
525
+ separated_subband_audio = torch.cat(
526
+ [
527
+ self.feature_maps_to_wav(
528
+ input_tensor=x[:, j * C1 : (j + 1) * C1, :, :],
529
+ sp=mag[:, j * C2 : (j + 1) * C2, :, :],
530
+ sin_in=sin_in[:, j * C2 : (j + 1) * C2, :, :],
531
+ cos_in=cos_in[:, j * C2 : (j + 1) * C2, :, :],
532
+ audio_length=audio_length,
533
+ )
534
+ for j in range(self.subbands_num)
535
+ ],
536
+ dim=1,
537
+ )
538
+ # (batch_size, subbands_num * target_sources_num * input_channles, segment_samples)
539
+
540
+ separated_audio = self.pqmf.synthesis(separated_subband_audio)
541
+ # (batch_size, input_channles, segment_samples)
542
+
543
+ output_dict = {'waveform': separated_audio}
544
+
545
+ return output_dict
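An analogous sketch for the subband-time variant. It assumes the script is launched from the repository root so that the relative project_root path used by PQMF resolves to the bundled .mat filters; the mono 1-second segment is only an example.

import torch
from bytesep.models.resunet_subbandtime import ResUNet143_Subbandtime

model = ResUNet143_Subbandtime(input_channels=1, target_sources_num=1)
model.eval()
batch = {'waveform': torch.randn(2, 1, 44100)}  # 1 s of mono audio at 44.1 kHz
with torch.no_grad():
    output = model(batch)
print(output['waveform'].shape)  # PQMF synthesis restores the original segment length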
bytesep/models/subband_tools/__init__.py ADDED
File without changes
bytesep/models/subband_tools/fDomainHelper.py ADDED
@@ -0,0 +1,255 @@
1
+ from torchlibrosa.stft import STFT, ISTFT, magphase
2
+ import torch
3
+ import torch.nn as nn
4
+ import numpy as np
5
+ from bytesep.models.subband_tools.pqmf import PQMF
6
+
7
+
8
+ class FDomainHelper(nn.Module):
9
+ def __init__(
10
+ self,
11
+ window_size=2048,
12
+ hop_size=441,
13
+ center=True,
14
+ pad_mode='reflect',
15
+ window='hann',
16
+ freeze_parameters=True,
17
+ subband=None,
18
+ root="/Users/admin/Documents/projects/",
19
+ ):
20
+ super(FDomainHelper, self).__init__()
21
+ self.subband = subband
22
+ if self.subband is None:
23
+ self.stft = STFT(
24
+ n_fft=window_size,
25
+ hop_length=hop_size,
26
+ win_length=window_size,
27
+ window=window,
28
+ center=center,
29
+ pad_mode=pad_mode,
30
+ freeze_parameters=freeze_parameters,
31
+ )
32
+
33
+ self.istft = ISTFT(
34
+ n_fft=window_size,
35
+ hop_length=hop_size,
36
+ win_length=window_size,
37
+ window=window,
38
+ center=center,
39
+ pad_mode=pad_mode,
40
+ freeze_parameters=freeze_parameters,
41
+ )
42
+ else:
43
+ self.stft = STFT(
44
+ n_fft=window_size // self.subband,
45
+ hop_length=hop_size // self.subband,
46
+ win_length=window_size // self.subband,
47
+ window=window,
48
+ center=center,
49
+ pad_mode=pad_mode,
50
+ freeze_parameters=freeze_parameters,
51
+ )
52
+
53
+ self.istft = ISTFT(
54
+ n_fft=window_size // self.subband,
55
+ hop_length=hop_size // self.subband,
56
+ win_length=window_size // self.subband,
57
+ window=window,
58
+ center=center,
59
+ pad_mode=pad_mode,
60
+ freeze_parameters=freeze_parameters,
61
+ )
62
+
63
+ if subband is not None and root is not None:
64
+ self.qmf = PQMF(subband, 64, root)
65
+
66
+ def complex_spectrogram(self, input, eps=0.0):
67
+ # [batchsize, samples]
68
+ # return [batchsize, 2, t-steps, f-bins]
69
+ real, imag = self.stft(input)
70
+ return torch.cat([real, imag], dim=1)
71
+
72
+ def reverse_complex_spectrogram(self, input, eps=0.0, length=None):
73
+ # [batchsize, 2[real,imag], t-steps, f-bins]
74
+ wav = self.istft(input[:, 0:1, ...], input[:, 1:2, ...], length=length)
75
+ return wav
76
+
77
+ def spectrogram(self, input, eps=0.0):
78
+ (real, imag) = self.stft(input.float())
79
+ return torch.clamp(real ** 2 + imag ** 2, eps, np.inf) ** 0.5
80
+
81
+ def spectrogram_phase(self, input, eps=0.0):
82
+ (real, imag) = self.stft(input.float())
83
+ mag = torch.clamp(real ** 2 + imag ** 2, eps, np.inf) ** 0.5
84
+ cos = real / mag
85
+ sin = imag / mag
86
+ return mag, cos, sin
87
+
88
+ def wav_to_spectrogram_phase(self, input, eps=1e-8):
89
+ """Waveform to spectrogram.
90
+
91
+ Args:
92
+ input: (batch_size, channels_num, segment_samples)
93
+
94
+ Outputs:
95
+ output: (batch_size, channels_num, time_steps, freq_bins)
96
+ """
97
+ sp_list = []
98
+ cos_list = []
99
+ sin_list = []
100
+ channels_num = input.shape[1]
101
+ for channel in range(channels_num):
102
+ mag, cos, sin = self.spectrogram_phase(input[:, channel, :], eps=eps)
103
+ sp_list.append(mag)
104
+ cos_list.append(cos)
105
+ sin_list.append(sin)
106
+
107
+ sps = torch.cat(sp_list, dim=1)
108
+ coss = torch.cat(cos_list, dim=1)
109
+ sins = torch.cat(sin_list, dim=1)
110
+ return sps, coss, sins
111
+
112
+ def spectrogram_phase_to_wav(self, sps, coss, sins, length):
113
+ channels_num = sps.size()[1]
114
+ res = []
115
+ for i in range(channels_num):
116
+ res.append(
117
+ self.istft(
118
+ sps[:, i : i + 1, ...] * coss[:, i : i + 1, ...],
119
+ sps[:, i : i + 1, ...] * sins[:, i : i + 1, ...],
120
+ length,
121
+ )
122
+ )
123
+ res[-1] = res[-1].unsqueeze(1)
124
+ return torch.cat(res, dim=1)
125
+
126
+ def wav_to_spectrogram(self, input, eps=1e-8):
127
+ """Waveform to spectrogram.
128
+
129
+ Args:
130
+ input: (batch_size, channels_num, segment_samples)
131
+
132
+ Outputs:
133
+ output: (batch_size, channels_num, time_steps, freq_bins)
134
+ """
135
+ sp_list = []
136
+ channels_num = input.shape[1]
137
+ for channel in range(channels_num):
138
+ sp_list.append(self.spectrogram(input[:, channel, :], eps=eps))
139
+ output = torch.cat(sp_list, dim=1)
140
+ return output
141
+
142
+ def spectrogram_to_wav(self, input, spectrogram, length=None):
143
+ """Spectrogram to waveform.
144
+ Args:
145
+ input: (batch_size, segment_samples, channels_num)
146
+ spectrogram: (batch_size, channels_num, time_steps, freq_bins)
147
+
148
+ Outputs:
149
+ output: (batch_size, segment_samples, channels_num)
150
+ """
151
+ channels_num = input.shape[1]
152
+ wav_list = []
153
+ for channel in range(channels_num):
154
+ (real, imag) = self.stft(input[:, channel, :])
155
+ (_, cos, sin) = magphase(real, imag)
156
+ wav_list.append(
157
+ self.istft(
158
+ spectrogram[:, channel : channel + 1, :, :] * cos,
159
+ spectrogram[:, channel : channel + 1, :, :] * sin,
160
+ length,
161
+ )
162
+ )
163
+
164
+ output = torch.stack(wav_list, dim=1)
165
+ return output
166
+
167
+ # todo the following code is not bug free!
168
+ def wav_to_complex_spectrogram(self, input, eps=0.0):
169
+ # [batchsize , channels, samples]
170
+ # [batchsize, 2[real,imag]*channels, t-steps, f-bins]
171
+ res = []
172
+ channels_num = input.shape[1]
173
+ for channel in range(channels_num):
174
+ res.append(self.complex_spectrogram(input[:, channel, :], eps=eps))
175
+ return torch.cat(res, dim=1)
176
+
177
+ def complex_spectrogram_to_wav(self, input, eps=0.0, length=None):
178
+ # [batchsize, 2[real,imag]*channels, t-steps, f-bins]
179
+ # return [batchsize, channels, samples]
180
+ channels = input.size()[1] // 2
181
+ wavs = []
182
+ for i in range(channels):
183
+ wavs.append(
184
+ self.reverse_complex_spectrogram(
185
+ input[:, 2 * i : 2 * i + 2, ...], eps=eps, length=length
186
+ )
187
+ )
188
+ wavs[-1] = wavs[-1].unsqueeze(1)
189
+ return torch.cat(wavs, dim=1)
190
+
191
+ def wav_to_complex_subband_spectrogram(self, input, eps=0.0):
192
+ # [batchsize, channels, samples]
193
+ # [batchsize, 2[real,imag]*subband*channels, t-steps, f-bins]
194
+ subwav = self.qmf.analysis(input) # [batchsize, subband*channels, samples]
195
+ subspec = self.wav_to_complex_spectrogram(subwav)
196
+ return subspec
197
+
198
+ def complex_subband_spectrogram_to_wav(self, input, eps=0.0):
199
+ # [batchsize, 2[real,imag]*subband*channels, t-steps, f-bins]
200
+ # [batchsize, channels, samples]
201
+ subwav = self.complex_spectrogram_to_wav(input)
202
+ data = self.qmf.synthesis(subwav)
203
+ return data
204
+
205
+ def wav_to_mag_phase_subband_spectrogram(self, input, eps=1e-8):
206
+ """
207
+ :param input:
208
+ :param eps:
209
+ :return:
210
+ loss = torch.nn.L1Loss()
211
+ model = FDomainHelper(subband=4)
212
+ data = torch.randn((3,1, 44100*3))
213
+
214
+ sps, coss, sins = model.wav_to_mag_phase_subband_spectrogram(data)
215
+ wav = model.mag_phase_subband_spectrogram_to_wav(sps,coss,sins,44100*3//4)
216
+
217
+ print(loss(data,wav))
218
+ print(torch.max(torch.abs(data-wav)))
219
+
220
+ """
221
+ # [batchsize, channels, samples]
222
+ # [batchsize, 2[real,imag]*subband*channels, t-steps, f-bins]
223
+ subwav = self.qmf.analysis(input) # [batchsize, subband*channels, samples]
224
+ sps, coss, sins = self.wav_to_spectrogram_phase(subwav, eps=eps)
225
+ return sps, coss, sins
226
+
227
+ def mag_phase_subband_spectrogram_to_wav(self, sps, coss, sins, length, eps=0.0):
228
+ # [batchsize, 2[real,imag]*subband*channels, t-steps, f-bins]
229
+ # [batchsize, channels, samples]
230
+ subwav = self.spectrogram_phase_to_wav(sps, coss, sins, length)
231
+ data = self.qmf.synthesis(subwav)
232
+ return data
233
+
234
+
235
+ if __name__ == "__main__":
236
+ # from thop import profile
237
+ # from thop import clever_format
238
+ # from tools.file.wav import *
239
+ # import time
240
+ #
241
+ # wav = torch.randn((1,2,44100))
242
+ # model = FDomainHelper()
243
+
244
+ # from tools.file.wav import *  # unused here; module is not part of this repository
245
+
246
+ loss = torch.nn.L1Loss()
247
+ model = FDomainHelper()
248
+ data = torch.randn((3, 1, 44100 * 5))
249
+
250
+ sps = model.wav_to_complex_spectrogram(data)
251
+ print(sps.size())
252
+ wav = model.complex_spectrogram_to_wav(sps, length=44100 * 5)
253
+
254
+ print(loss(data, wav))
255
+ print(torch.max(torch.abs(data - wav)))
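The magnitude/phase helpers above can be sanity-checked with a quick round trip. This sketch assumes torchlibrosa is installed and that the PQMF import at the top of the file resolves (it is pointed at the in-repo bytesep.models.subband_tools.pqmf module above); the helper uses its default 2048/441 STFT settings.

import torch
from bytesep.models.subband_tools.fDomainHelper import FDomainHelper

helper = FDomainHelper()                      # subband=None: plain STFT/ISTFT helpers
wav = torch.randn(2, 1, 44100)                # (batch, channels, samples)
sps, coss, sins = helper.wav_to_spectrogram_phase(wav)
recon = helper.spectrogram_phase_to_wav(sps, coss, sins, length=44100)
print((wav - recon).abs().max())              # round-trip error should be near zero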
bytesep/models/subband_tools/filters/f_4_64.mat ADDED
Binary file (2.19 kB). View file
 
bytesep/models/subband_tools/filters/h_4_64.mat ADDED
Binary file (2.19 kB). View file
 
bytesep/models/subband_tools/pqmf.py ADDED
@@ -0,0 +1,136 @@
1
+ '''
2
+ @File : subband_util.py
3
+ @Contact : liu.8948@buckeyemail.osu.edu
4
+ @License : (C)Copyright 2020-2021
5
+ @Modify Time @Author @Version @Description
6
+ ------------ ------- -------- -----------
7
+ 2020/4/3 4:54 PM Haohe Liu 1.0 None
8
+ '''
9
+
10
+ import torch
11
+ import torch.nn.functional as F
12
+ import torch.nn as nn
13
+ import numpy as np
14
+ import os.path as op
15
+ from scipy.io import loadmat
16
+
17
+
18
+ def load_mat2numpy(fname=""):
19
+ '''
20
+ Args:
21
+ fname: pth to mat
22
+ type:
23
+ Returns: dic object
24
+ '''
25
+ if len(fname) == 0:
26
+ return None
27
+ else:
28
+ return loadmat(fname)
29
+
30
+
31
+ class PQMF(nn.Module):
32
+ def __init__(self, N, M, project_root):
33
+ super().__init__()
34
+ self.N = N # nsubband
35
+ self.M = M # nfilter
36
+ try:
37
+ assert (N, M) in [(8, 64), (4, 64), (2, 64)]
38
+ except:
39
+ print("Warning:", N, "subbandand ", M, " filter is not supported")
40
+ self.pad_samples = 64
41
+ self.name = str(N) + "_" + str(M) + ".mat"
42
+ self.ana_conv_filter = nn.Conv1d(
43
+ 1, out_channels=N, kernel_size=M, stride=N, bias=False
44
+ )
45
+ data = load_mat2numpy(op.join(project_root, "f_" + self.name))
46
+ data = data['f'].astype(np.float32) / N
47
+ data = np.flipud(data.T).T
48
+ data = np.reshape(data, (N, 1, M)).copy()
49
+ dict_new = self.ana_conv_filter.state_dict().copy()
50
+ dict_new['weight'] = torch.from_numpy(data)
51
+ self.ana_pad = nn.ConstantPad1d((M - N, 0), 0)
52
+ self.ana_conv_filter.load_state_dict(dict_new)
53
+
54
+ self.syn_pad = nn.ConstantPad1d((0, M // N - 1), 0)
55
+ self.syn_conv_filter = nn.Conv1d(
56
+ N, out_channels=N, kernel_size=M // N, stride=1, bias=False
57
+ )
58
+ gk = load_mat2numpy(op.join(project_root, "h_" + self.name))
59
+ gk = gk['h'].astype(np.float32)
60
+ gk = np.transpose(np.reshape(gk, (N, M // N, N)), (1, 0, 2)) * N
61
+ gk = np.transpose(gk[::-1, :, :], (2, 1, 0)).copy()
62
+ dict_new = self.syn_conv_filter.state_dict().copy()
63
+ dict_new['weight'] = torch.from_numpy(gk)
64
+ self.syn_conv_filter.load_state_dict(dict_new)
65
+
66
+ for param in self.parameters():
67
+ param.requires_grad = False
68
+
69
+ def __analysis_channel(self, inputs):
70
+ return self.ana_conv_filter(self.ana_pad(inputs))
71
+
72
+ def __systhesis_channel(self, inputs):
73
+ ret = self.syn_conv_filter(self.syn_pad(inputs)).permute(0, 2, 1)
74
+ return torch.reshape(ret, (ret.shape[0], 1, -1))
75
+
76
+ def analysis(self, inputs):
77
+ '''
78
+ :param inputs: [batchsize,channel,raw_wav],value:[0,1]
79
+ :return:
80
+ '''
81
+ inputs = F.pad(inputs, ((0, self.pad_samples)))
82
+ ret = None
83
+ for i in range(inputs.size()[1]): # channels
84
+ if ret is None:
85
+ ret = self.__analysis_channel(inputs[:, i : i + 1, :])
86
+ else:
87
+ ret = torch.cat(
88
+ (ret, self.__analysis_channel(inputs[:, i : i + 1, :])), dim=1
89
+ )
90
+ return ret
91
+
92
+ def synthesis(self, data):
93
+ '''
94
+ :param data: [batchsize,self.N*K,raw_wav_sub],value:[0,1]
95
+ :return:
96
+ '''
97
+ ret = None
98
+ # data = F.pad(data,((0,self.pad_samples//self.N)))
99
+ for i in range(data.size()[1]): # channels
100
+ if i % self.N == 0:
101
+ if ret is None:
102
+ ret = self.__systhesis_channel(data[:, i : i + self.N, :])
103
+ else:
104
+ new = self.__systhesis_channel(data[:, i : i + self.N, :])
105
+ ret = torch.cat((ret, new), dim=1)
106
+ ret = ret[..., : -self.pad_samples]
107
+ return ret
108
+
109
+ def forward(self, inputs):
110
+ return self.ana_conv_filter(self.ana_pad(inputs))
111
+
112
+
113
+ if __name__ == "__main__":
114
+ import torch
115
+ import numpy as np
116
+ import matplotlib.pyplot as plt
117
+ # from tools.file.wav import *  # unused here; module is not part of this repository
118
+
119
+ pqmf = PQMF(N=4, M=64, project_root="bytesep/models/subband_tools/filters")
120
+
121
+ rs = np.random.RandomState(0)
122
+ x = torch.tensor(rs.rand(4, 2, 32000), dtype=torch.float32)
123
+
124
+ a1 = pqmf.analysis(x)
125
+ a2 = pqmf.synthesis(a1)
126
+
127
+ print(a2.size(), x.size())
128
+
129
+ plt.subplot(211)
130
+ plt.plot(x[0, 0, -500:])
131
+ plt.subplot(212)
132
+ plt.plot(a2[0, 0, -500:])
133
+ plt.plot(x[0, 0, -500:] - a2[0, 0, -500:])
134
+ plt.show()
135
+
136
+ print(torch.sum(torch.abs(x[...] - a2[...])))
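A self-contained check of the PQMF filter bank that avoids the matplotlib dependency of the __main__ block above. It assumes the script is run from the repository root so the f_4_64.mat / h_4_64.mat prototype filters resolve; the shapes in the comments follow the analysis and synthesis code.

import torch
from bytesep.models.subband_tools.pqmf import PQMF

pqmf = PQMF(N=4, M=64, project_root='bytesep/models/subband_tools/filters')
x = torch.randn(1, 2, 44100)          # (batch, channels, samples)
subbands = pqmf.analysis(x)           # (batch, channels * N, (samples + 64) / N)
recon = pqmf.synthesis(subbands)      # (batch, channels, samples)
print(subbands.shape, recon.shape, (x - recon).abs().max().item())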
bytesep/models/unet.py ADDED
@@ -0,0 +1,532 @@
1
+ import math
2
+ from typing import Dict, List, NoReturn, Tuple
3
+
4
+ import matplotlib.pyplot as plt
5
+ import numpy as np
6
+ import pytorch_lightning as pl
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ import torch.optim as optim
11
+ from torch.optim.lr_scheduler import LambdaLR
12
+ from torchlibrosa.stft import ISTFT, STFT, magphase
13
+
14
+ from bytesep.models.pytorch_modules import Base, Subband, act, init_bn, init_layer
15
+
16
+
17
+ class ConvBlock(nn.Module):
18
+ def __init__(
19
+ self,
20
+ in_channels: int,
21
+ out_channels: int,
22
+ kernel_size: Tuple,
23
+ activation: str,
24
+ momentum: float,
25
+ ):
26
+ r"""Convolutional block."""
27
+ super(ConvBlock, self).__init__()
28
+
29
+ self.activation = activation
30
+ padding = (kernel_size[0] // 2, kernel_size[1] // 2)
31
+
32
+ self.conv1 = nn.Conv2d(
33
+ in_channels=in_channels,
34
+ out_channels=out_channels,
35
+ kernel_size=kernel_size,
36
+ stride=(1, 1),
37
+ dilation=(1, 1),
38
+ padding=padding,
39
+ bias=False,
40
+ )
41
+
42
+ self.bn1 = nn.BatchNorm2d(out_channels, momentum=momentum)
43
+
44
+ self.conv2 = nn.Conv2d(
45
+ in_channels=out_channels,
46
+ out_channels=out_channels,
47
+ kernel_size=kernel_size,
48
+ stride=(1, 1),
49
+ dilation=(1, 1),
50
+ padding=padding,
51
+ bias=False,
52
+ )
53
+
54
+ self.bn2 = nn.BatchNorm2d(out_channels, momentum=momentum)
55
+
56
+ self.init_weights()
57
+
58
+ def init_weights(self) -> NoReturn:
59
+ r"""Initialize weights."""
60
+ init_layer(self.conv1)
61
+ init_layer(self.conv2)
62
+ init_bn(self.bn1)
63
+ init_bn(self.bn2)
64
+
65
+ def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
66
+ r"""Forward data into the module.
67
+
68
+ Args:
69
+ input_tensor: (batch_size, in_feature_maps, time_steps, freq_bins)
70
+
71
+ Returns:
72
+ output_tensor: (batch_size, out_feature_maps, time_steps, freq_bins)
73
+ """
74
+ x = act(self.bn1(self.conv1(input_tensor)), self.activation)
75
+ x = act(self.bn2(self.conv2(x)), self.activation)
76
+ output_tensor = x
77
+
78
+ return output_tensor
79
+
80
+
81
+ class EncoderBlock(nn.Module):
82
+ def __init__(
83
+ self,
84
+ in_channels: int,
85
+ out_channels: int,
86
+ kernel_size: Tuple,
87
+ downsample: Tuple,
88
+ activation: str,
89
+ momentum: float,
90
+ ):
91
+ r"""Encoder block."""
92
+ super(EncoderBlock, self).__init__()
93
+
94
+ self.conv_block = ConvBlock(
95
+ in_channels, out_channels, kernel_size, activation, momentum
96
+ )
97
+ self.downsample = downsample
98
+
99
+ def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
100
+ r"""Forward data into the module.
101
+
102
+ Args:
103
+ input_tensor: (batch_size, in_feature_maps, time_steps, freq_bins)
104
+
105
+ Returns:
106
+ encoder_pool: (batch_size, out_feature_maps, downsampled_time_steps, downsampled_freq_bins)
107
+ encoder: (batch_size, out_feature_maps, time_steps, freq_bins)
108
+ """
109
+ encoder_tensor = self.conv_block(input_tensor)
110
+ # encoder: (batch_size, out_feature_maps, time_steps, freq_bins)
111
+
112
+ encoder_pool = F.avg_pool2d(encoder_tensor, kernel_size=self.downsample)
113
+ # encoder_pool: (batch_size, out_feature_maps, downsampled_time_steps, downsampled_freq_bins)
114
+
115
+ return encoder_pool, encoder_tensor
116
+
117
+
118
+ class DecoderBlock(nn.Module):
119
+ def __init__(
120
+ self,
121
+ in_channels: int,
122
+ out_channels: int,
123
+ kernel_size: Tuple,
124
+ upsample: Tuple,
125
+ activation: str,
126
+ momentum: float,
127
+ ):
128
+ r"""Decoder block."""
129
+ super(DecoderBlock, self).__init__()
130
+
131
+ self.kernel_size = kernel_size
132
+ self.stride = upsample
133
+ self.activation = activation
134
+
135
+ self.conv1 = torch.nn.ConvTranspose2d(
136
+ in_channels=in_channels,
137
+ out_channels=out_channels,
138
+ kernel_size=self.stride,
139
+ stride=self.stride,
140
+ padding=(0, 0),
141
+ bias=False,
142
+ dilation=(1, 1),
143
+ )
144
+
145
+ self.bn1 = nn.BatchNorm2d(out_channels, momentum=momentum)
146
+
147
+ self.conv_block2 = ConvBlock(
148
+ out_channels * 2, out_channels, kernel_size, activation, momentum
149
+ )
150
+
151
+ self.init_weights()
152
+
153
+ def init_weights(self):
154
+ r"""Initialize weights."""
155
+ init_layer(self.conv1)
156
+ init_bn(self.bn1)
157
+
158
+ def forward(
159
+ self, input_tensor: torch.Tensor, concat_tensor: torch.Tensor
160
+ ) -> torch.Tensor:
161
+ r"""Forward data into the module.
162
+
163
+ Args:
164
+ input_tensor: (batch_size, in_feature_maps, downsampled_time_steps, downsampled_freq_bins)
165
+ concat_tensor: (batch_size, in_feature_maps, time_steps, freq_bins)
166
+
167
+ Returns:
168
+ output_tensor: (batch_size, out_feature_maps, time_steps, freq_bins)
169
+ """
170
+ x = act(self.bn1(self.conv1(input_tensor)), self.activation)
171
+ # (batch_size, in_feature_maps, time_steps, freq_bins)
172
+
173
+ x = torch.cat((x, concat_tensor), dim=1)
174
+ # (batch_size, in_feature_maps * 2, time_steps, freq_bins)
175
+
176
+ output_tensor = self.conv_block2(x)
177
+ # output_tensor: (batch_size, out_feature_maps, time_steps, freq_bins)
178
+
179
+ return output_tensor
180
+
181
+
182
+ class UNet(nn.Module, Base):
183
+ def __init__(self, input_channels: int, target_sources_num: int):
184
+ r"""UNet."""
185
+ super(UNet, self).__init__()
186
+
187
+ self.input_channels = input_channels
188
+ self.target_sources_num = target_sources_num
189
+
190
+ window_size = 2048
191
+ hop_size = 441
192
+ center = True
193
+ pad_mode = "reflect"
194
+ window = "hann"
195
+ activation = "leaky_relu"
196
+ momentum = 0.01
197
+
198
+ self.subbands_num = 1
199
+
200
+ assert (
201
+ self.subbands_num == 1
202
+ ), "Using subbands_num > 1 on spectrogram \
203
+ will lead to unexpected performance sometimes. Suggest to use \
204
+ subband method on waveform."
205
+
206
+ self.K = 3 # outputs: |M|, cos∠M, sin∠M
207
+ self.downsample_ratio = 2 ** 6 # This number equals 2^{#encoder_blocks}
208
+
209
+ self.stft = STFT(
210
+ n_fft=window_size,
211
+ hop_length=hop_size,
212
+ win_length=window_size,
213
+ window=window,
214
+ center=center,
215
+ pad_mode=pad_mode,
216
+ freeze_parameters=True,
217
+ )
218
+
219
+ self.istft = ISTFT(
220
+ n_fft=window_size,
221
+ hop_length=hop_size,
222
+ win_length=window_size,
223
+ window=window,
224
+ center=center,
225
+ pad_mode=pad_mode,
226
+ freeze_parameters=True,
227
+ )
228
+
229
+ self.bn0 = nn.BatchNorm2d(window_size // 2 + 1, momentum=momentum)
230
+
231
+ self.subband = Subband(subbands_num=self.subbands_num)
232
+
233
+ self.encoder_block1 = EncoderBlock(
234
+ in_channels=input_channels * self.subbands_num,
235
+ out_channels=32,
236
+ kernel_size=(3, 3),
237
+ downsample=(2, 2),
238
+ activation=activation,
239
+ momentum=momentum,
240
+ )
241
+ self.encoder_block2 = EncoderBlock(
242
+ in_channels=32,
243
+ out_channels=64,
244
+ kernel_size=(3, 3),
245
+ downsample=(2, 2),
246
+ activation=activation,
247
+ momentum=momentum,
248
+ )
249
+ self.encoder_block3 = EncoderBlock(
250
+ in_channels=64,
251
+ out_channels=128,
252
+ kernel_size=(3, 3),
253
+ downsample=(2, 2),
254
+ activation=activation,
255
+ momentum=momentum,
256
+ )
257
+ self.encoder_block4 = EncoderBlock(
258
+ in_channels=128,
259
+ out_channels=256,
260
+ kernel_size=(3, 3),
261
+ downsample=(2, 2),
262
+ activation=activation,
263
+ momentum=momentum,
264
+ )
265
+ self.encoder_block5 = EncoderBlock(
266
+ in_channels=256,
267
+ out_channels=384,
268
+ kernel_size=(3, 3),
269
+ downsample=(2, 2),
270
+ activation=activation,
271
+ momentum=momentum,
272
+ )
273
+ self.encoder_block6 = EncoderBlock(
274
+ in_channels=384,
275
+ out_channels=384,
276
+ kernel_size=(3, 3),
277
+ downsample=(2, 2),
278
+ activation=activation,
279
+ momentum=momentum,
280
+ )
281
+ self.conv_block7 = ConvBlock(
282
+ in_channels=384,
283
+ out_channels=384,
284
+ kernel_size=(3, 3),
285
+ activation=activation,
286
+ momentum=momentum,
287
+ )
288
+ self.decoder_block1 = DecoderBlock(
289
+ in_channels=384,
290
+ out_channels=384,
291
+ kernel_size=(3, 3),
292
+ upsample=(2, 2),
293
+ activation=activation,
294
+ momentum=momentum,
295
+ )
296
+ self.decoder_block2 = DecoderBlock(
297
+ in_channels=384,
298
+ out_channels=384,
299
+ kernel_size=(3, 3),
300
+ upsample=(2, 2),
301
+ activation=activation,
302
+ momentum=momentum,
303
+ )
304
+ self.decoder_block3 = DecoderBlock(
305
+ in_channels=384,
306
+ out_channels=256,
307
+ kernel_size=(3, 3),
308
+ upsample=(2, 2),
309
+ activation=activation,
310
+ momentum=momentum,
311
+ )
312
+ self.decoder_block4 = DecoderBlock(
313
+ in_channels=256,
314
+ out_channels=128,
315
+ kernel_size=(3, 3),
316
+ upsample=(2, 2),
317
+ activation=activation,
318
+ momentum=momentum,
319
+ )
320
+ self.decoder_block5 = DecoderBlock(
321
+ in_channels=128,
322
+ out_channels=64,
323
+ kernel_size=(3, 3),
324
+ upsample=(2, 2),
325
+ activation=activation,
326
+ momentum=momentum,
327
+ )
328
+
329
+ self.decoder_block6 = DecoderBlock(
330
+ in_channels=64,
331
+ out_channels=32,
332
+ kernel_size=(3, 3),
333
+ upsample=(2, 2),
334
+ activation=activation,
335
+ momentum=momentum,
336
+ )
337
+
338
+ self.after_conv_block1 = ConvBlock(
339
+ in_channels=32,
340
+ out_channels=32,
341
+ kernel_size=(3, 3),
342
+ activation=activation,
343
+ momentum=momentum,
344
+ )
345
+
346
+ self.after_conv2 = nn.Conv2d(
347
+ in_channels=32,
348
+ out_channels=target_sources_num
349
+ * input_channels
350
+ * self.K
351
+ * self.subbands_num,
352
+ kernel_size=(1, 1),
353
+ stride=(1, 1),
354
+ padding=(0, 0),
355
+ bias=True,
356
+ )
357
+
358
+ self.init_weights()
359
+
360
+ def init_weights(self):
361
+ r"""Initialize weights."""
362
+ init_bn(self.bn0)
363
+ init_layer(self.after_conv2)
364
+
365
+ def feature_maps_to_wav(
366
+ self,
367
+ input_tensor: torch.Tensor,
368
+ sp: torch.Tensor,
369
+ sin_in: torch.Tensor,
370
+ cos_in: torch.Tensor,
371
+ audio_length: int,
372
+ ) -> torch.Tensor:
373
+ r"""Convert feature maps to waveform.
374
+
375
+ Args:
376
+ input_tensor: (batch_size, target_sources_num * input_channels * self.K, time_steps, freq_bins)
377
+ sp: (batch_size, target_sources_num * input_channels, time_steps, freq_bins)
378
+ sin_in: (batch_size, target_sources_num * input_channels, time_steps, freq_bins)
379
+ cos_in: (batch_size, target_sources_num * input_channels, time_steps, freq_bins)
380
+
381
+ Outputs:
382
+ waveform: (batch_size, target_sources_num * input_channels, segment_samples)
383
+ """
384
+ batch_size, _, time_steps, freq_bins = input_tensor.shape
385
+
386
+ x = input_tensor.reshape(
387
+ batch_size,
388
+ self.target_sources_num,
389
+ self.input_channels,
390
+ self.K,
391
+ time_steps,
392
+ freq_bins,
393
+ )
394
+ # x: (batch_size, target_sources_num, input_channels, K, time_steps, freq_bins)
395
+
396
+ mask_mag = torch.sigmoid(x[:, :, :, 0, :, :])
397
+ _mask_real = torch.tanh(x[:, :, :, 1, :, :])
398
+ _mask_imag = torch.tanh(x[:, :, :, 2, :, :])
399
+ _, mask_cos, mask_sin = magphase(_mask_real, _mask_imag)
400
+ # mask_cos, mask_sin: (batch_size, target_sources_num, input_channels, time_steps, freq_bins)
401
+
402
+ # Y = |Y|cos∠Y + j|Y|sin∠Y
403
+ # = |Y|cos(∠X + ∠M) + j|Y|sin(∠X + ∠M)
404
+ # = |Y|(cos∠X cos∠M - sin∠X sin∠M) + j|Y|(sin∠X cos∠M + cos∠X sin∠M)
405
+ out_cos = (
406
+ cos_in[:, None, :, :, :] * mask_cos - sin_in[:, None, :, :, :] * mask_sin
407
+ )
408
+ out_sin = (
409
+ sin_in[:, None, :, :, :] * mask_cos + cos_in[:, None, :, :, :] * mask_sin
410
+ )
411
+ # out_cos: (batch_size, target_sources_num, input_channels, time_steps, freq_bins)
412
+ # out_sin: (batch_size, target_sources_num, input_channels, time_steps, freq_bins)
413
+
414
+ # Calculate |Y|.
415
+ out_mag = F.relu_(sp[:, None, :, :, :] * mask_mag)
416
+ # out_mag: (batch_size, target_sources_num, input_channels, time_steps, freq_bins)
417
+
418
+ # Calculate Y_{real} and Y_{imag} for ISTFT.
419
+ out_real = out_mag * out_cos
420
+ out_imag = out_mag * out_sin
421
+ # out_real, out_imag: (batch_size, target_sources_num, input_channels, time_steps, freq_bins)
422
+
423
+ # Reformat shape to (n, 1, time_steps, freq_bins) for ISTFT.
424
+ shape = (
425
+ batch_size * self.target_sources_num * self.input_channels,
426
+ 1,
427
+ time_steps,
428
+ freq_bins,
429
+ )
430
+ out_real = out_real.reshape(shape)
431
+ out_imag = out_imag.reshape(shape)
432
+
433
+ # ISTFT.
434
+ x = self.istft(out_real, out_imag, audio_length)
435
+ # (batch_size * target_sources_num * input_channels, segments_num)
436
+
437
+ # Reshape.
438
+ waveform = x.reshape(
439
+ batch_size, self.target_sources_num * self.input_channels, audio_length
440
+ )
441
+ # (batch_size, target_sources_num * input_channels, segments_num)
442
+
443
+ return waveform
444
+
445
+ def forward(self, input_dict: Dict) -> Dict:
446
+ r"""Forward data into the module.
447
+
448
+ Args:
449
+ input_dict: dict, e.g., {
450
+ waveform: (batch_size, input_channels, segment_samples),
451
+ ...,
452
+ }
453
+
454
+ Outputs:
455
+ output_dict: dict, e.g., {
456
+ 'waveform': (batch_size, target_sources_num * input_channels, segment_samples),
457
+ ...,
458
+ }
459
+ """
460
+ mixtures = input_dict['waveform']
461
+ # (batch_size, input_channels, segment_samples)
462
+
463
+ mag, cos_in, sin_in = self.wav_to_spectrogram_phase(mixtures)
464
+ # mag, cos_in, sin_in: (batch_size, input_channels, time_steps, freq_bins)
465
+
466
+ # Batch normalize on individual frequency bins.
467
+ x = mag.transpose(1, 3)
468
+ x = self.bn0(x)
469
+ x = x.transpose(1, 3)
470
+ # x: (batch_size, input_channels, time_steps, freq_bins)
471
+
472
+ # Pad spectrogram to be evenly divided by downsample ratio.
473
+ origin_len = x.shape[2]
474
+ pad_len = (
475
+ int(np.ceil(x.shape[2] / self.downsample_ratio)) * self.downsample_ratio
476
+ - origin_len
477
+ )
478
+ x = F.pad(x, pad=(0, 0, 0, pad_len))
479
+ # x: (batch_size, input_channels, padded_time_steps, freq_bins)
480
+
481
+ # Let frequency bins be evenly divided by 2, e.g., 1025 -> 1024
482
+ x = x[..., 0 : x.shape[-1] - 1] # (bs, input_channels, T, F)
483
+
484
+ if self.subbands_num > 1:
485
+ x = self.subband.analysis(x)
486
+ # (bs, input_channels, T, F'), where F' = F // subbands_num
487
+
488
+ # UNet
489
+ (x1_pool, x1) = self.encoder_block1(x) # x1_pool: (bs, 32, T / 2, F' / 2)
490
+ (x2_pool, x2) = self.encoder_block2(x1_pool) # x2_pool: (bs, 64, T / 4, F' / 4)
491
+ (x3_pool, x3) = self.encoder_block3(
492
+ x2_pool
493
+ ) # x3_pool: (bs, 128, T / 8, F' / 8)
494
+ (x4_pool, x4) = self.encoder_block4(
495
+ x3_pool
496
+ ) # x4_pool: (bs, 256, T / 16, F' / 16)
497
+ (x5_pool, x5) = self.encoder_block5(
498
+ x4_pool
499
+ ) # x5_pool: (bs, 384, T / 32, F' / 32)
500
+ (x6_pool, x6) = self.encoder_block6(
501
+ x5_pool
502
+ ) # x6_pool: (bs, 384, T / 64, F' / 64)
503
+ x_center = self.conv_block7(x6_pool) # (bs, 384, T / 64, F' / 64)
504
+ x7 = self.decoder_block1(x_center, x6) # (bs, 384, T / 32, F' / 32)
505
+ x8 = self.decoder_block2(x7, x5) # (bs, 384, T / 16, F' / 16)
506
+ x9 = self.decoder_block3(x8, x4) # (bs, 256, T / 8, F' / 8)
507
+ x10 = self.decoder_block4(x9, x3) # (bs, 128, T / 4, F' / 4)
508
+ x11 = self.decoder_block5(x10, x2) # (bs, 64, T / 2, F' / 2)
509
+ x12 = self.decoder_block6(x11, x1) # (bs, 32, T, F')
510
+ x = self.after_conv_block1(x12) # (bs, 32, T, F')
511
+
512
+ x = self.after_conv2(x)
513
+ # (batch_size, target_sources_num * input_channels * self.K * subbands_num, T, F')
514
+
515
+ if self.subbands_num > 1:
516
+ x = self.subband.synthesis(x)
517
+ # (batch_size, target_sources_num * input_channels * self.K, T, F)
518
+
519
+ # Recover shape
520
+ x = F.pad(x, pad=(0, 1)) # Pad frequency, e.g., 1024 -> 1025.
521
+
522
+ x = x[:, :, 0:origin_len, :]
523
+ # (batch_size, target_sources_num * input_channels * self.K, T, F)
524
+
525
+ audio_length = mixtures.shape[2]
526
+
527
+ separated_audio = self.feature_maps_to_wav(x, mag, sin_in, cos_in, audio_length)
528
+ # separated_audio: (batch_size, target_sources_num * input_channels, segments_num)
529
+
530
+ output_dict = {'waveform': separated_audio}
531
+
532
+ return output_dict
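For orientation, a hedged sketch (not part of the commit) of how this UNet is driven end to end; it assumes random stereo input, a single target source, and a 3-second segment, matching the dict-in / dict-out interface documented in forward() above.

import torch

from bytesep.models.unet import UNet

model = UNet(input_channels=2, target_sources_num=1)
model.eval()

# Three seconds of random stereo audio: (batch_size, input_channels, segment_samples).
input_dict = {'waveform': torch.rand(1, 2, 3 * 44100)}

with torch.no_grad():
    output_dict = model(input_dict)

# Expected: (1, target_sources_num * input_channels, 3 * 44100)
print(output_dict['waveform'].shape)

Because the spectrogram is padded up to a multiple of the 2^6 downsample ratio and cropped back afterwards, the separated waveform keeps the input segment length.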
bytesep/models/unet_subbandtime.py ADDED
@@ -0,0 +1,389 @@
1
+ from typing import Dict
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from torchlibrosa.stft import ISTFT, STFT, magphase
8
+
9
+ from bytesep.models.pytorch_modules import Base, init_bn, init_layer
10
+ from bytesep.models.subband_tools.pqmf import PQMF
11
+ from bytesep.models.unet import ConvBlock, DecoderBlock, EncoderBlock
12
+
13
+
14
+ class UNetSubbandTime(nn.Module, Base):
15
+ def __init__(self, input_channels: int, target_sources_num: int):
16
+ r"""Subband waveform UNet."""
17
+ super(UNetSubbandTime, self).__init__()
18
+
19
+ self.input_channels = input_channels
20
+ self.target_sources_num = target_sources_num
21
+
22
+ window_size = 512 # 2048 // 4
23
+ hop_size = 110 # 441 // 4
24
+ center = True
25
+ pad_mode = "reflect"
26
+ window = "hann"
27
+ activation = "leaky_relu"
28
+ momentum = 0.01
29
+
30
+ self.subbands_num = 4
31
+ self.K = 3 # outputs: |M|, cos∠M, sin∠M
32
+
33
+ self.downsample_ratio = 2 ** 6 # This number equals 2^{#encoder_blocks}
34
+
35
+ self.pqmf = PQMF(
36
+ N=self.subbands_num,
37
+ M=64,
38
+ project_root='bytesep/models/subband_tools/filters',
39
+ )
40
+
41
+ self.stft = STFT(
42
+ n_fft=window_size,
43
+ hop_length=hop_size,
44
+ win_length=window_size,
45
+ window=window,
46
+ center=center,
47
+ pad_mode=pad_mode,
48
+ freeze_parameters=True,
49
+ )
50
+
51
+ self.istft = ISTFT(
52
+ n_fft=window_size,
53
+ hop_length=hop_size,
54
+ win_length=window_size,
55
+ window=window,
56
+ center=center,
57
+ pad_mode=pad_mode,
58
+ freeze_parameters=True,
59
+ )
60
+
61
+ self.bn0 = nn.BatchNorm2d(window_size // 2 + 1, momentum=momentum)
62
+
63
+ self.encoder_block1 = EncoderBlock(
64
+ in_channels=input_channels * self.subbands_num,
65
+ out_channels=32,
66
+ kernel_size=(3, 3),
67
+ downsample=(2, 2),
68
+ activation=activation,
69
+ momentum=momentum,
70
+ )
71
+ self.encoder_block2 = EncoderBlock(
72
+ in_channels=32,
73
+ out_channels=64,
74
+ kernel_size=(3, 3),
75
+ downsample=(2, 2),
76
+ activation=activation,
77
+ momentum=momentum,
78
+ )
79
+ self.encoder_block3 = EncoderBlock(
80
+ in_channels=64,
81
+ out_channels=128,
82
+ kernel_size=(3, 3),
83
+ downsample=(2, 2),
84
+ activation=activation,
85
+ momentum=momentum,
86
+ )
87
+ self.encoder_block4 = EncoderBlock(
88
+ in_channels=128,
89
+ out_channels=256,
90
+ kernel_size=(3, 3),
91
+ downsample=(2, 2),
92
+ activation=activation,
93
+ momentum=momentum,
94
+ )
95
+ self.encoder_block5 = EncoderBlock(
96
+ in_channels=256,
97
+ out_channels=384,
98
+ kernel_size=(3, 3),
99
+ downsample=(2, 2),
100
+ activation=activation,
101
+ momentum=momentum,
102
+ )
103
+ self.encoder_block6 = EncoderBlock(
104
+ in_channels=384,
105
+ out_channels=384,
106
+ kernel_size=(3, 3),
107
+ downsample=(2, 2),
108
+ activation=activation,
109
+ momentum=momentum,
110
+ )
111
+ self.conv_block7 = ConvBlock(
112
+ in_channels=384,
113
+ out_channels=384,
114
+ kernel_size=(3, 3),
115
+ activation=activation,
116
+ momentum=momentum,
117
+ )
118
+ self.decoder_block1 = DecoderBlock(
119
+ in_channels=384,
120
+ out_channels=384,
121
+ kernel_size=(3, 3),
122
+ upsample=(2, 2),
123
+ activation=activation,
124
+ momentum=momentum,
125
+ )
126
+ self.decoder_block2 = DecoderBlock(
127
+ in_channels=384,
128
+ out_channels=384,
129
+ kernel_size=(3, 3),
130
+ upsample=(2, 2),
131
+ activation=activation,
132
+ momentum=momentum,
133
+ )
134
+ self.decoder_block3 = DecoderBlock(
135
+ in_channels=384,
136
+ out_channels=256,
137
+ kernel_size=(3, 3),
138
+ upsample=(2, 2),
139
+ activation=activation,
140
+ momentum=momentum,
141
+ )
142
+ self.decoder_block4 = DecoderBlock(
143
+ in_channels=256,
144
+ out_channels=128,
145
+ kernel_size=(3, 3),
146
+ upsample=(2, 2),
147
+ activation=activation,
148
+ momentum=momentum,
149
+ )
150
+ self.decoder_block5 = DecoderBlock(
151
+ in_channels=128,
152
+ out_channels=64,
153
+ kernel_size=(3, 3),
154
+ upsample=(2, 2),
155
+ activation=activation,
156
+ momentum=momentum,
157
+ )
158
+
159
+ self.decoder_block6 = DecoderBlock(
160
+ in_channels=64,
161
+ out_channels=32,
162
+ kernel_size=(3, 3),
163
+ upsample=(2, 2),
164
+ activation=activation,
165
+ momentum=momentum,
166
+ )
167
+
168
+ self.after_conv_block1 = ConvBlock(
169
+ in_channels=32,
170
+ out_channels=32,
171
+ kernel_size=(3, 3),
172
+ activation=activation,
173
+ momentum=momentum,
174
+ )
175
+
176
+ self.after_conv2 = nn.Conv2d(
177
+ in_channels=32,
178
+ out_channels=target_sources_num
179
+ * input_channels
180
+ * self.K
181
+ * self.subbands_num,
182
+ kernel_size=(1, 1),
183
+ stride=(1, 1),
184
+ padding=(0, 0),
185
+ bias=True,
186
+ )
187
+
188
+ self.init_weights()
189
+
190
+ def init_weights(self):
191
+ r"""Initialize weights."""
192
+ init_bn(self.bn0)
193
+ init_layer(self.after_conv2)
194
+
195
+ def feature_maps_to_wav(
196
+ self,
197
+ input_tensor: torch.Tensor,
198
+ sp: torch.Tensor,
199
+ sin_in: torch.Tensor,
200
+ cos_in: torch.Tensor,
201
+ audio_length: int,
202
+ ) -> torch.Tensor:
203
+ r"""Convert feature maps to waveform.
204
+
205
+ Args:
206
+ input_tensor: (batch_size, target_sources_num * input_channels * self.K, time_steps, freq_bins)
207
+ sp: (batch_size, target_sources_num * input_channels, time_steps, freq_bins)
208
+ sin_in: (batch_size, target_sources_num * input_channels, time_steps, freq_bins)
209
+ cos_in: (batch_size, target_sources_num * input_channels, time_steps, freq_bins)
210
+
211
+ Outputs:
212
+ waveform: (batch_size, target_sources_num * input_channels, segment_samples)
213
+ """
214
+ batch_size, _, time_steps, freq_bins = input_tensor.shape
215
+
216
+ x = input_tensor.reshape(
217
+ batch_size,
218
+ self.target_sources_num,
219
+ self.input_channels,
220
+ self.K,
221
+ time_steps,
222
+ freq_bins,
223
+ )
224
+ # x: (batch_size, target_sources_num, input_channels, K, time_steps, freq_bins)
225
+
226
+ mask_mag = torch.sigmoid(x[:, :, :, 0, :, :])
227
+ _mask_real = torch.tanh(x[:, :, :, 1, :, :])
228
+ _mask_imag = torch.tanh(x[:, :, :, 2, :, :])
229
+ _, mask_cos, mask_sin = magphase(_mask_real, _mask_imag)
230
+ # mask_cos, mask_sin: (batch_size, target_sources_num, input_channels, time_steps, freq_bins)
231
+
232
+ # Y = |Y|cos∠Y + j|Y|sin∠Y
233
+ # = |Y|cos(∠X + ∠M) + j|Y|sin(∠X + ∠M)
234
+ # = |Y|(cos∠X cos∠M - sin∠X sin∠M) + j|Y|(sin∠X cos∠M + cos∠X sin∠M)
235
+ out_cos = (
236
+ cos_in[:, None, :, :, :] * mask_cos - sin_in[:, None, :, :, :] * mask_sin
237
+ )
238
+ out_sin = (
239
+ sin_in[:, None, :, :, :] * mask_cos + cos_in[:, None, :, :, :] * mask_sin
240
+ )
241
+ # out_cos: (batch_size, target_sources_num, input_channels, time_steps, freq_bins)
242
+ # out_sin: (batch_size, target_sources_num, input_channels, time_steps, freq_bins)
243
+
244
+ # Calculate |Y|.
245
+ out_mag = F.relu_(sp[:, None, :, :, :] * mask_mag)
246
+ # out_mag: (batch_size, target_sources_num, input_channels, time_steps, freq_bins)
247
+
248
+ # Calculate Y_{real} and Y_{imag} for ISTFT.
249
+ out_real = out_mag * out_cos
250
+ out_imag = out_mag * out_sin
251
+ # out_real, out_imag: (batch_size, target_sources_num, input_channels, time_steps, freq_bins)
252
+
253
+ # Reformat shape to (n, 1, time_steps, freq_bins) for ISTFT.
254
+ shape = (
255
+ batch_size * self.target_sources_num * self.input_channels,
256
+ 1,
257
+ time_steps,
258
+ freq_bins,
259
+ )
260
+ out_real = out_real.reshape(shape)
261
+ out_imag = out_imag.reshape(shape)
262
+
263
+ # ISTFT.
264
+ x = self.istft(out_real, out_imag, audio_length)
265
+ # (batch_size * target_sources_num * input_channels, segments_num)
266
+
267
+ # Reshape.
268
+ waveform = x.reshape(
269
+ batch_size, self.target_sources_num * self.input_channels, audio_length
270
+ )
271
+ # (batch_size, target_sources_num * input_channels, segments_num)
272
+
273
+ return waveform
274
+
275
+ def forward(self, input_dict: Dict) -> Dict:
276
+ """Forward data into the module.
277
+
278
+ Args:
279
+ input_dict: dict, e.g., {
280
+ waveform: (batch_size, input_channels, segment_samples),
281
+ ...,
282
+ }
283
+
284
+ Outputs:
285
+ output_dict: dict, e.g., {
286
+ 'waveform': (batch_size, target_sources_num * input_channels, segment_samples),
287
+ ...,
288
+ }
289
+ """
290
+ mixtures = input_dict['waveform']
291
+ # (batch_size, input_channels, segment_samples)
292
+
293
+ if self.subbands_num > 1:
294
+ subband_x = self.pqmf.analysis(mixtures)
295
+ # -- subband_x: (batch_size, input_channels * subbands_num, segment_samples)
296
+ # -- subband_x: (batch_size, subbands_num * input_channels, segment_samples)
297
+ else:
298
+ subband_x = mixtures
299
+
300
+ # from IPython import embed; embed(using=False); os._exit(0)
301
+ # import soundfile
302
+ # soundfile.write(file='_zz.wav', data=subband_x.data.cpu().numpy()[0, 2], samplerate=11025)
303
+
304
+ mag, cos_in, sin_in = self.wav_to_spectrogram_phase(subband_x)
305
+ # mag, cos_in, sin_in: (batch_size, input_channels * subbands_num, time_steps, freq_bins)
306
+
307
+ # Batch normalize on individual frequency bins.
308
+ x = mag.transpose(1, 3)
309
+ x = self.bn0(x)
310
+ x = x.transpose(1, 3)
311
+ # (batch_size, input_channels * subbands_num, time_steps, freq_bins)
312
+
313
+ # Pad spectrogram to be evenly divided by downsample ratio.
314
+ origin_len = x.shape[2]
315
+ pad_len = (
316
+ int(np.ceil(x.shape[2] / self.downsample_ratio)) * self.downsample_ratio
317
+ - origin_len
318
+ )
319
+ x = F.pad(x, pad=(0, 0, 0, pad_len))
320
+ # x: (batch_size, input_channels * subbands_num, padded_time_steps, freq_bins)
321
+
322
+ # Let frequency bins be evenly divided by 2, e.g., 257 -> 256
323
+ x = x[..., 0 : x.shape[-1] - 1] # (bs, input_channels, T, F)
324
+ # x: (batch_size, input_channels * subbands_num, padded_time_steps, freq_bins)
325
+
326
+ # UNet
327
+ (x1_pool, x1) = self.encoder_block1(x) # x1_pool: (bs, 32, T / 2, F' / 2)
328
+ (x2_pool, x2) = self.encoder_block2(x1_pool) # x2_pool: (bs, 64, T / 4, F' / 4)
329
+ (x3_pool, x3) = self.encoder_block3(
330
+ x2_pool
331
+ ) # x3_pool: (bs, 128, T / 8, F' / 8)
332
+ (x4_pool, x4) = self.encoder_block4(
333
+ x3_pool
334
+ ) # x4_pool: (bs, 256, T / 16, F' / 16)
335
+ (x5_pool, x5) = self.encoder_block5(
336
+ x4_pool
337
+ ) # x5_pool: (bs, 384, T / 32, F' / 32)
338
+ (x6_pool, x6) = self.encoder_block6(
339
+ x5_pool
340
+ ) # x6_pool: (bs, 384, T / 64, F' / 64)
341
+ x_center = self.conv_block7(x6_pool) # (bs, 384, T / 64, F' / 64)
342
+ x7 = self.decoder_block1(x_center, x6) # (bs, 384, T / 32, F' / 32)
343
+ x8 = self.decoder_block2(x7, x5) # (bs, 384, T / 16, F' / 16)
344
+ x9 = self.decoder_block3(x8, x4) # (bs, 256, T / 8, F' / 8)
345
+ x10 = self.decoder_block4(x9, x3) # (bs, 128, T / 4, F' / 4)
346
+ x11 = self.decoder_block5(x10, x2) # (bs, 64, T / 2, F' / 2)
347
+ x12 = self.decoder_block6(x11, x1) # (bs, 32, T, F')
348
+ x = self.after_conv_block1(x12) # (bs, 32, T, F')
349
+
350
+ x = self.after_conv2(x)
351
+ # (batch_size, subbands_num * target_sources_num * input_channels * self.K, T, F')
352
+
353
+ # Recover shape
354
+ x = F.pad(x, pad=(0, 1)) # Pad frequency, e.g., 256 -> 257.
355
+
356
+ x = x[:, :, 0:origin_len, :]
357
+ # (batch_size, subbands_num * target_sources_num * input_channels * self.K, T, F')
358
+
359
+ audio_length = subband_x.shape[2]
360
+
361
+ # Recover each subband spectrograms to subband waveforms. Then synthesis
362
+ # the subband waveforms to a waveform.
363
+ C1 = x.shape[1] // self.subbands_num
364
+ C2 = mag.shape[1] // self.subbands_num
365
+
366
+ separated_subband_audio = torch.cat(
367
+ [
368
+ self.feature_maps_to_wav(
369
+ input_tensor=x[:, j * C1 : (j + 1) * C1, :, :],
370
+ sp=mag[:, j * C2 : (j + 1) * C2, :, :],
371
+ sin_in=sin_in[:, j * C2 : (j + 1) * C2, :, :],
372
+ cos_in=cos_in[:, j * C2 : (j + 1) * C2, :, :],
373
+ audio_length=audio_length,
374
+ )
375
+ for j in range(self.subbands_num)
376
+ ],
377
+ dim=1,
378
+ )
379
+ # (batch_size, subbands_num * target_sources_num * input_channels, segment_samples)
380
+
381
+ if self.subbands_num > 1:
382
+ separated_audio = self.pqmf.synthesis(separated_subband_audio)
383
+ # (batch_size, target_sources_num * input_channels, segment_samples)
384
+ else:
385
+ separated_audio = separated_subband_audio
386
+
387
+ output_dict = {'waveform': separated_audio}
388
+
389
+ return output_dict
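A similar hedged sketch (not part of the commit) for the subband-time variant. It assumes the script is run from the repository root so the relative PQMF filter path above resolves; the PQMF analysis turns the stereo input into input_channels * subbands_num decimated waveforms, each of which goes through its own STFT/ISTFT round trip before the final PQMF synthesis.

import torch

from bytesep.models.unet_subbandtime import UNetSubbandTime

model = UNetSubbandTime(input_channels=2, target_sources_num=1)
model.eval()

input_dict = {'waveform': torch.rand(1, 2, 3 * 44100)}  # (batch, channels, samples)

with torch.no_grad():
    output_dict = model(input_dict)

# Expected: (1, target_sources_num * input_channels, 3 * 44100), assuming the
# PQMF analysis/synthesis round trip restores the original segment length.
print(output_dict['waveform'].shape)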
bytesep/optimizers/__init__.py ADDED
File without changes
bytesep/optimizers/lr_schedulers.py ADDED
@@ -0,0 +1,20 @@
1
+ def get_lr_lambda(step, warm_up_steps: int, reduce_lr_steps: int):
2
+ r"""Get lr_lambda for LambdaLR. E.g.,
3
+
4
+ .. code-block:: python
5
+ lr_lambda = lambda step: get_lr_lambda(step, warm_up_steps=1000, reduce_lr_steps=10000)
6
+
7
+ from torch.optim.lr_scheduler import LambdaLR
8
+ LambdaLR(optimizer, lr_lambda)
9
+
10
+ Args:
11
+ warm_up_steps: int, steps for warm up
12
+ reduce_lr_steps: int, reduce learning rate by 0.9 every #reduce_lr_steps steps
13
+
14
+ Returns:
15
+ learning rate: float
16
+ """
17
+ if step <= warm_up_steps:
18
+ return step / warm_up_steps
19
+ else:
20
+ return 0.9 ** (step // reduce_lr_steps)
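A short sketch (not part of the commit) of the schedule this lambda produces when wired into LambdaLR, using the warm_up_steps=1000 and reduce_lr_steps=10000 values from the docstring: the multiplier ramps linearly to 1.0 during warm-up, then decays by a factor of 0.9 every 10000 steps.

import functools

import torch
from torch.optim.lr_scheduler import LambdaLR

from bytesep.optimizers.lr_schedulers import get_lr_lambda

lr_lambda = functools.partial(get_lr_lambda, warm_up_steps=1000, reduce_lr_steps=10000)

# Multipliers applied to the base learning rate at a few representative steps.
for step in (0, 500, 1000, 10000, 20000, 30000):
    print(step, lr_lambda(step))  # 0.0, 0.5, 1.0, 0.9, 0.81, 0.729

model = torch.nn.Linear(4, 4)                              # placeholder model for illustration
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = LambdaLR(optimizer, lr_lambda)                 # call scheduler.step() after each optimizer step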
bytesep/plot_results/__init__.py ADDED
File without changes
bytesep/plot_results/musdb18.py ADDED
@@ -0,0 +1,198 @@
1
+ import argparse
2
+ import os
3
+ import pickle
4
+
5
+ import matplotlib.pyplot as plt
6
+ import numpy as np
7
+
8
+
9
+ def load_sdrs(workspace, task_name, filename, config, gpus, source_type):
10
+
11
+ stat_path = os.path.join(
12
+ workspace,
13
+ "statistics",
14
+ task_name,
15
+ filename,
16
+ "config={},gpus={}".format(config, gpus),
17
+ "statistics.pkl",
18
+ )
19
+
20
+ stat_dict = pickle.load(open(stat_path, 'rb'))
21
+
22
+ median_sdrs = [e['median_sdr_dict'][source_type] for e in stat_dict['test']]
23
+
24
+ return median_sdrs
25
+
26
+
27
+ def plot_statistics(args):
28
+
29
+ # arguments & parameters
30
+ workspace = args.workspace
31
+ select = args.select
32
+ task_name = "musdb18"
33
+ filename = "train"
34
+
35
+ # paths
36
+ fig_path = os.path.join('results', task_name, "sdr_{}.pdf".format(select))
37
+ os.makedirs(os.path.dirname(fig_path), exist_ok=True)
38
+
39
+ linewidth = 1
40
+ lines = []
41
+ fig, ax = plt.subplots(1, 1, figsize=(8, 6))
42
+
43
+ if select == '1a':
44
+ sdrs = load_sdrs(
45
+ workspace,
46
+ task_name,
47
+ filename,
48
+ config='vocals-accompaniment,unet',
49
+ gpus=1,
50
+ source_type="vocals",
51
+ )
52
+ (line,) = ax.plot(sdrs, label='UNet,l1_wav', linewidth=linewidth)
53
+ lines.append(line)
54
+ ylim = 15
55
+
56
+ elif select == '1b':
57
+ sdrs = load_sdrs(
58
+ workspace,
59
+ task_name,
60
+ filename,
61
+ config='accompaniment-vocals,unet',
62
+ gpus=1,
63
+ source_type="accompaniment",
64
+ )
65
+ (line,) = ax.plot(sdrs, label='UNet,l1_wav', linewidth=linewidth)
66
+ lines.append(line)
67
+ ylim = 20
68
+
69
+ elif select == '1c':
70
+ sdrs = load_sdrs(
71
+ workspace,
72
+ task_name,
73
+ filename,
74
+ config='vocals-accompaniment,unet',
75
+ gpus=1,
76
+ source_type="vocals",
77
+ )
78
+ (line,) = ax.plot(sdrs, label='UNet,l1_wav', linewidth=linewidth)
79
+ lines.append(line)
80
+
81
+ sdrs = load_sdrs(
82
+ workspace,
83
+ task_name,
84
+ filename,
85
+ config='vocals-accompaniment,resunet',
86
+ gpus=2,
87
+ source_type="vocals",
88
+ )
89
+ (line,) = ax.plot(sdrs, label='ResUNet_ISMIR2021,l1_wav', linewidth=linewidth)
90
+ lines.append(line)
91
+
92
+ sdrs = load_sdrs(
93
+ workspace,
94
+ task_name,
95
+ filename,
96
+ config='vocals-accompaniment,unet_subbandtime',
97
+ gpus=1,
98
+ source_type="vocals",
99
+ )
100
+ (line,) = ax.plot(sdrs, label='unet_subband,l1_wav', linewidth=linewidth)
101
+ lines.append(line)
102
+
103
+ sdrs = load_sdrs(
104
+ workspace,
105
+ task_name,
106
+ filename,
107
+ config='vocals-accompaniment,resunet_subbandtime',
108
+ gpus=1,
109
+ source_type="vocals",
110
+ )
111
+ (line,) = ax.plot(sdrs, label='resunet_subband,l1_wav', linewidth=linewidth)
112
+ lines.append(line)
113
+
114
+ ylim = 15
115
+
116
+ elif select == '1d':
117
+ sdrs = load_sdrs(
118
+ workspace,
119
+ task_name,
120
+ filename,
121
+ config='accompaniment-vocals,unet',
122
+ gpus=1,
123
+ source_type="accompaniment",
124
+ )
125
+ (line,) = ax.plot(sdrs, label='UNet,l1_wav', linewidth=linewidth)
126
+ lines.append(line)
127
+
128
+ sdrs = load_sdrs(
129
+ workspace,
130
+ task_name,
131
+ filename,
132
+ config='accompaniment-vocals,resunet',
133
+ gpus=2,
134
+ source_type="accompaniment",
135
+ )
136
+ (line,) = ax.plot(sdrs, label='ResUNet_ISMIR2021,l1_wav', linewidth=linewidth)
137
+ lines.append(line)
138
+
139
+ # sdrs = load_sdrs(
140
+ # workspace,
141
+ # task_name,
142
+ # filename,
143
+ # config='accompaniment-vocals,unet_subbandtime',
144
+ # gpus=1,
145
+ # source_type="accompaniment",
146
+ # )
147
+ # (line,) = ax.plot(sdrs, label='UNet_subbandtime,l1_wav', linewidth=linewidth)
148
+ # lines.append(line)
149
+
150
+ sdrs = load_sdrs(
151
+ workspace,
152
+ task_name,
153
+ filename,
154
+ config='accompaniment-vocals,resunet_subbandtime',
155
+ gpus=1,
156
+ source_type="accompaniment",
157
+ )
158
+ (line,) = ax.plot(
159
+ sdrs, label='ResUNet_subbandtime,l1_wav', linewidth=linewidth
160
+ )
161
+ lines.append(line)
162
+
163
+ ylim = 20
164
+
165
+ else:
166
+ raise Exception('Unknown select: {}'.format(select))
167
+
168
+ eval_every_iterations = 10000
169
+ total_ticks = 50
170
+ ticks_freq = 10
171
+
172
+ ax.set_ylim(0, ylim)
173
+ ax.set_xlim(0, total_ticks)
174
+ ax.xaxis.set_ticks(np.arange(0, total_ticks + 1, ticks_freq))
175
+ ax.xaxis.set_ticklabels(
176
+ np.arange(
177
+ 0,
178
+ total_ticks * eval_every_iterations + 1,
179
+ ticks_freq * eval_every_iterations,
180
+ )
181
+ )
182
+ ax.yaxis.set_ticks(np.arange(ylim + 1))
183
+ ax.yaxis.set_ticklabels(np.arange(ylim + 1))
184
+ ax.grid(color='b', linestyle='solid', linewidth=0.3)
185
+ plt.legend(handles=lines, loc=4)
186
+
187
+ plt.savefig(fig_path)
188
+ print('Save figure to {}'.format(fig_path))
189
+
190
+
191
+ if __name__ == '__main__':
192
+ parser = argparse.ArgumentParser()
193
+ parser.add_argument('--workspace', type=str, required=True)
194
+ parser.add_argument('--select', type=str, required=True)
195
+
196
+ args = parser.parse_args()
197
+
198
+ plot_statistics(args)
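load_sdrs() above assumes a particular statistics.pkl layout written during training. As a rough, hedged illustration only (the field names are inferred from the loading code above and the numeric values are placeholders), the file is expected to look like:

import pickle

# One entry is appended to stat_dict['test'] per evaluation checkpoint.
stat_dict = {
    'test': [
        {'median_sdr_dict': {'vocals': 5.1, 'accompaniment': 12.3}},  # placeholder values
        {'median_sdr_dict': {'vocals': 6.4, 'accompaniment': 13.0}},  # placeholder values
    ]
}

with open('statistics.pkl', 'wb') as f:
    pickle.dump(stat_dict, f)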
bytesep/plot_results/plot_vctk-musdb18.py ADDED
@@ -0,0 +1,87 @@
1
+ import os
2
+ import sys
3
+ import numpy as np
4
+ import argparse
5
+ import h5py
6
+ import math
7
+ import time
8
+ import logging
9
+ import pickle
10
+ import matplotlib.pyplot as plt
11
+
12
+
13
+ def load_sdrs(workspace, task_name, filename, config, gpus):
14
+
15
+ stat_path = os.path.join(
16
+ workspace,
17
+ "statistics",
18
+ task_name,
19
+ filename,
20
+ "config={},gpus={}".format(config, gpus),
21
+ "statistics.pkl",
22
+ )
23
+
24
+ stat_dict = pickle.load(open(stat_path, 'rb'))
25
+
26
+ median_sdrs = [e['sdr'] for e in stat_dict['test']]
27
+
28
+ return median_sdrs
29
+
30
+
31
+ def plot_statistics(args):
32
+
33
+ # arguments & parameters
34
+ workspace = args.workspace
35
+ select = args.select
36
+ task_name = "vctk-musdb18"
37
+ filename = "train"
38
+
39
+ # paths
40
+ fig_path = os.path.join('results', task_name, "sdr_{}.pdf".format(select))
41
+ os.makedirs(os.path.dirname(fig_path), exist_ok=True)
42
+
43
+ linewidth = 1
44
+ lines = []
45
+ fig, ax = plt.subplots(1, 1, figsize=(8, 6))
46
+ ylim = 30
47
+ expand = 1
48
+
49
+ if select == '1a':
50
+ sdrs = load_sdrs(workspace, task_name, filename, config='unet', gpus=1)
51
+ (line,) = ax.plot(sdrs, label='UNet,l1_wav', linewidth=linewidth)
52
+ lines.append(line)
53
+
54
+ else:
55
+ raise Exception('Unknown select: {}'.format(select))
56
+
57
+ eval_every_iterations = 10000
58
+ total_ticks = 50
59
+ ticks_freq = 10
60
+
61
+ ax.set_ylim(0, ylim)
62
+ ax.set_xlim(0, total_ticks)
63
+ ax.xaxis.set_ticks(np.arange(0, total_ticks + 1, ticks_freq))
64
+ ax.xaxis.set_ticklabels(
65
+ np.arange(
66
+ 0,
67
+ total_ticks * eval_every_iterations + 1,
68
+ ticks_freq * eval_every_iterations,
69
+ )
70
+ )
71
+ ax.yaxis.set_ticks(np.arange(ylim + 1))
72
+ ax.yaxis.set_ticklabels(np.arange(ylim + 1))
73
+ ax.grid(color='b', linestyle='solid', linewidth=0.3)
74
+ plt.legend(handles=lines, loc=4)
75
+
76
+ plt.savefig(fig_path)
77
+ print('Save figure to {}'.format(fig_path))
78
+
79
+
80
+ if __name__ == '__main__':
81
+ parser = argparse.ArgumentParser()
82
+ parser.add_argument('--workspace', type=str, required=True)
83
+ parser.add_argument('--select', type=str, required=True)
84
+
85
+ args = parser.parse_args()
86
+
87
+ plot_statistics(args)