Spaces:
Runtime error
Runtime error
jone
commited on
Commit
•
75c6e9a
0
Parent(s):
init
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +29 -0
- LICENSE +13 -0
- README.md +33 -0
- app.py +40 -0
- bytesep/__init__.py +1 -0
- bytesep/callbacks/__init__.py +76 -0
- bytesep/callbacks/base_callbacks.py +44 -0
- bytesep/callbacks/instruments_callbacks.py +200 -0
- bytesep/callbacks/musdb18.py +485 -0
- bytesep/callbacks/voicebank_demand.py +231 -0
- bytesep/data/__init__.py +0 -0
- bytesep/data/augmentors.py +157 -0
- bytesep/data/batch_data_preprocessors.py +141 -0
- bytesep/data/data_modules.py +187 -0
- bytesep/data/samplers.py +188 -0
- bytesep/dataset_creation/__init__.py +0 -0
- bytesep/dataset_creation/create_evaluation_audios/__init__.py +0 -0
- bytesep/dataset_creation/create_evaluation_audios/piano-symphony.py +160 -0
- bytesep/dataset_creation/create_evaluation_audios/vctk-musdb18.py +164 -0
- bytesep/dataset_creation/create_evaluation_audios/violin-piano.py +162 -0
- bytesep/dataset_creation/create_indexes/__init__.py +0 -0
- bytesep/dataset_creation/create_indexes/create_indexes.py +142 -0
- bytesep/dataset_creation/pack_audios_to_hdf5s/__init__.py +0 -0
- bytesep/dataset_creation/pack_audios_to_hdf5s/instruments_solo.py +173 -0
- bytesep/dataset_creation/pack_audios_to_hdf5s/maestro.py +136 -0
- bytesep/dataset_creation/pack_audios_to_hdf5s/musdb18.py +207 -0
- bytesep/dataset_creation/pack_audios_to_hdf5s/vctk.py +114 -0
- bytesep/dataset_creation/pack_audios_to_hdf5s/voicebank-demand.py +143 -0
- bytesep/inference.py +404 -0
- bytesep/inference_many.py +163 -0
- bytesep/losses.py +106 -0
- bytesep/models/__init__.py +0 -0
- bytesep/models/conditional_unet.py +496 -0
- bytesep/models/lightning_modules.py +188 -0
- bytesep/models/pytorch_modules.py +204 -0
- bytesep/models/resunet.py +516 -0
- bytesep/models/resunet_ismir2021.py +534 -0
- bytesep/models/resunet_subbandtime.py +545 -0
- bytesep/models/subband_tools/__init__.py +0 -0
- bytesep/models/subband_tools/fDomainHelper.py +255 -0
- bytesep/models/subband_tools/filters/f_4_64.mat +0 -0
- bytesep/models/subband_tools/filters/h_4_64.mat +0 -0
- bytesep/models/subband_tools/pqmf.py +136 -0
- bytesep/models/unet.py +532 -0
- bytesep/models/unet_subbandtime.py +389 -0
- bytesep/optimizers/__init__.py +0 -0
- bytesep/optimizers/lr_schedulers.py +20 -0
- bytesep/plot_results/__init__.py +0 -0
- bytesep/plot_results/musdb18.py +198 -0
- bytesep/plot_results/plot_vctk-musdb18.py +87 -0
.gitattributes
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bin.* filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
20 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
26 |
+
*.zstandard filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.wav filter=lfs diff=lfs merge=lfs -text
|
29 |
+
example.wav filter=lfs diff=lfs merge=lfs -text
|
LICENSE
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Copyright 2021 ByteDance
|
2 |
+
|
3 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
you may not use this file except in compliance with the License.
|
5 |
+
You may obtain a copy of the License at
|
6 |
+
|
7 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
|
9 |
+
Unless required by applicable law or agreed to in writing, software
|
10 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
See the License for the specific language governing permissions and
|
13 |
+
limitations under the License.
|
README.md
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: Music_Source_Separation
|
3 |
+
emoji: ⚡
|
4 |
+
colorFrom: green
|
5 |
+
colorTo: yellow
|
6 |
+
sdk: gradio
|
7 |
+
app_file: app.py
|
8 |
+
pinned: false
|
9 |
+
---
|
10 |
+
|
11 |
+
# Configuration
|
12 |
+
|
13 |
+
`title`: _string_
|
14 |
+
Display title for the Space
|
15 |
+
|
16 |
+
`emoji`: _string_
|
17 |
+
Space emoji (emoji-only character allowed)
|
18 |
+
|
19 |
+
`colorFrom`: _string_
|
20 |
+
Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
|
21 |
+
|
22 |
+
`colorTo`: _string_
|
23 |
+
Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
|
24 |
+
|
25 |
+
`sdk`: _string_
|
26 |
+
Can be either `gradio` or `streamlit`
|
27 |
+
|
28 |
+
`app_file`: _string_
|
29 |
+
Path to your main application file (which contains either `gradio` or `streamlit` Python code).
|
30 |
+
Path is relative to the root of the repository.
|
31 |
+
|
32 |
+
`pinned`: _boolean_
|
33 |
+
Whether the Space stays on top of your list.
|
app.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
os.system('pip install gradio==2.3.0a0')
|
3 |
+
os.system('pip freeze')
|
4 |
+
import sys
|
5 |
+
sys.path.append('.')
|
6 |
+
import gradio as gr
|
7 |
+
os.system('pip install -U torchtext==0.8.0')
|
8 |
+
#os.system('python setup.py install --install-dir .')
|
9 |
+
from scipy.io import wavfile
|
10 |
+
|
11 |
+
os.system('./separate_scripts/download_checkpoints.sh')
|
12 |
+
|
13 |
+
def inference(audio):
|
14 |
+
# read the file and get the sample rate and data
|
15 |
+
rate, data = wavfile.read(audio.name)
|
16 |
+
|
17 |
+
# save the result
|
18 |
+
wavfile.write('foo_left.wav', rate, data)
|
19 |
+
os.system("""python bytesep/inference.py --config_yaml=./scripts/4_train/musdb18/configs/vocals-accompaniment,resunet_subbandtime.yaml --checkpoint_path=./downloaded_checkpoints/resunet143_subbtandtime_vocals_8.8dB_350k_steps.pth --audio_path=foo_left.wav --output_path=sep_vocals.mp3""")
|
20 |
+
#os.system('./separate_scripts/separate_vocals.sh ' + audio.name + ' "sep_vocals.mp3"')
|
21 |
+
os.system("""python bytesep/inference.py --config_yaml=./scripts/4_train/musdb18/configs/accompaniment-vocals,resunet_subbandtime.yaml --checkpoint_path=./downloaded_checkpoints/resunet143_subbtandtime_accompaniment_16.4dB_350k_steps.pth --audio_path=foo_left.wav --output_path=sep_accompaniment.mp3""")
|
22 |
+
#os.system('./separate_scripts/separate_accompaniment.sh ' + audio.name + ' "sep_accompaniment.mp3"')
|
23 |
+
#os.system('python separate_scripts/separate.py --audio_path=' +audio.name+' --source_type="accompaniment"')
|
24 |
+
#os.system('python separate_scripts/separate.py --audio_path=' +audio.name+' --source_type="vocals"')
|
25 |
+
return 'sep_vocals.mp3', 'sep_accompaniment.mp3'
|
26 |
+
title = "Music Source Separation"
|
27 |
+
description = "Gradio demo for Music Source Separation. To use it, simply add your audio, or click one of the examples to load them. Currently supports .wav files. Read more at the links below."
|
28 |
+
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2109.05418'>Decoupling Magnitude and Phase Estimation with Deep ResUNet for Music Source Separation</a> | <a href='https://github.com/bytedance/music_source_separation'>Github Repo</a></p>"
|
29 |
+
|
30 |
+
examples = [['example.wav']]
|
31 |
+
gr.Interface(
|
32 |
+
inference,
|
33 |
+
gr.inputs.Audio(type="file", label="Input"),
|
34 |
+
[gr.outputs.Audio(type="file", label="Vocals"),gr.outputs.Audio(type="file", label="Accompaniment")],
|
35 |
+
title=title,
|
36 |
+
description=description,
|
37 |
+
article=article,
|
38 |
+
enable_queue=True,
|
39 |
+
examples=examples
|
40 |
+
).launch(debug=True)
|
bytesep/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from bytesep.inference import Separator
|
bytesep/callbacks/__init__.py
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List
|
2 |
+
|
3 |
+
import pytorch_lightning as pl
|
4 |
+
import torch.nn as nn
|
5 |
+
|
6 |
+
|
7 |
+
def get_callbacks(
|
8 |
+
task_name: str,
|
9 |
+
config_yaml: str,
|
10 |
+
workspace: str,
|
11 |
+
checkpoints_dir: str,
|
12 |
+
statistics_path: str,
|
13 |
+
logger: pl.loggers.TensorBoardLogger,
|
14 |
+
model: nn.Module,
|
15 |
+
evaluate_device: str,
|
16 |
+
) -> List[pl.Callback]:
|
17 |
+
r"""Get callbacks of a task and config yaml file.
|
18 |
+
|
19 |
+
Args:
|
20 |
+
task_name: str
|
21 |
+
config_yaml: str
|
22 |
+
dataset_dir: str
|
23 |
+
workspace: str, containing useful files such as audios for evaluation
|
24 |
+
checkpoints_dir: str, directory to save checkpoints
|
25 |
+
statistics_dir: str, directory to save statistics
|
26 |
+
logger: pl.loggers.TensorBoardLogger
|
27 |
+
model: nn.Module
|
28 |
+
evaluate_device: str
|
29 |
+
|
30 |
+
Return:
|
31 |
+
callbacks: List[pl.Callback]
|
32 |
+
"""
|
33 |
+
if task_name == 'musdb18':
|
34 |
+
|
35 |
+
from bytesep.callbacks.musdb18 import get_musdb18_callbacks
|
36 |
+
|
37 |
+
return get_musdb18_callbacks(
|
38 |
+
config_yaml=config_yaml,
|
39 |
+
workspace=workspace,
|
40 |
+
checkpoints_dir=checkpoints_dir,
|
41 |
+
statistics_path=statistics_path,
|
42 |
+
logger=logger,
|
43 |
+
model=model,
|
44 |
+
evaluate_device=evaluate_device,
|
45 |
+
)
|
46 |
+
|
47 |
+
elif task_name == 'voicebank-demand':
|
48 |
+
|
49 |
+
from bytesep.callbacks.voicebank_demand import get_voicebank_demand_callbacks
|
50 |
+
|
51 |
+
return get_voicebank_demand_callbacks(
|
52 |
+
config_yaml=config_yaml,
|
53 |
+
workspace=workspace,
|
54 |
+
checkpoints_dir=checkpoints_dir,
|
55 |
+
statistics_path=statistics_path,
|
56 |
+
logger=logger,
|
57 |
+
model=model,
|
58 |
+
evaluate_device=evaluate_device,
|
59 |
+
)
|
60 |
+
|
61 |
+
elif task_name in ['vctk-musdb18', 'violin-piano', 'piano-symphony']:
|
62 |
+
|
63 |
+
from bytesep.callbacks.instruments_callbacks import get_instruments_callbacks
|
64 |
+
|
65 |
+
return get_instruments_callbacks(
|
66 |
+
config_yaml=config_yaml,
|
67 |
+
workspace=workspace,
|
68 |
+
checkpoints_dir=checkpoints_dir,
|
69 |
+
statistics_path=statistics_path,
|
70 |
+
logger=logger,
|
71 |
+
model=model,
|
72 |
+
evaluate_device=evaluate_device,
|
73 |
+
)
|
74 |
+
|
75 |
+
else:
|
76 |
+
raise NotImplementedError
|
bytesep/callbacks/base_callbacks.py
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import os
|
3 |
+
from typing import NoReturn
|
4 |
+
|
5 |
+
import pytorch_lightning as pl
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
from pytorch_lightning.utilities import rank_zero_only
|
9 |
+
|
10 |
+
|
11 |
+
class SaveCheckpointsCallback(pl.Callback):
|
12 |
+
def __init__(
|
13 |
+
self,
|
14 |
+
model: nn.Module,
|
15 |
+
checkpoints_dir: str,
|
16 |
+
save_step_frequency: int,
|
17 |
+
):
|
18 |
+
r"""Callback to save checkpoints every #save_step_frequency steps.
|
19 |
+
|
20 |
+
Args:
|
21 |
+
model: nn.Module
|
22 |
+
checkpoints_dir: str, directory to save checkpoints
|
23 |
+
save_step_frequency: int
|
24 |
+
"""
|
25 |
+
self.model = model
|
26 |
+
self.checkpoints_dir = checkpoints_dir
|
27 |
+
self.save_step_frequency = save_step_frequency
|
28 |
+
os.makedirs(self.checkpoints_dir, exist_ok=True)
|
29 |
+
|
30 |
+
@rank_zero_only
|
31 |
+
def on_batch_end(self, trainer: pl.Trainer, _) -> NoReturn:
|
32 |
+
r"""Save checkpoint."""
|
33 |
+
global_step = trainer.global_step
|
34 |
+
|
35 |
+
if global_step % self.save_step_frequency == 0:
|
36 |
+
|
37 |
+
checkpoint_path = os.path.join(
|
38 |
+
self.checkpoints_dir, "step={}.pth".format(global_step)
|
39 |
+
)
|
40 |
+
|
41 |
+
checkpoint = {'step': global_step, 'model': self.model.state_dict()}
|
42 |
+
|
43 |
+
torch.save(checkpoint, checkpoint_path)
|
44 |
+
logging.info("Save checkpoint to {}".format(checkpoint_path))
|
bytesep/callbacks/instruments_callbacks.py
ADDED
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import os
|
3 |
+
import time
|
4 |
+
from typing import List, NoReturn
|
5 |
+
|
6 |
+
import librosa
|
7 |
+
import numpy as np
|
8 |
+
import pytorch_lightning as pl
|
9 |
+
import torch.nn as nn
|
10 |
+
from pytorch_lightning.utilities import rank_zero_only
|
11 |
+
|
12 |
+
from bytesep.callbacks.base_callbacks import SaveCheckpointsCallback
|
13 |
+
from bytesep.inference import Separator
|
14 |
+
from bytesep.utils import StatisticsContainer, calculate_sdr, read_yaml
|
15 |
+
|
16 |
+
|
17 |
+
def get_instruments_callbacks(
|
18 |
+
config_yaml: str,
|
19 |
+
workspace: str,
|
20 |
+
checkpoints_dir: str,
|
21 |
+
statistics_path: str,
|
22 |
+
logger: pl.loggers.TensorBoardLogger,
|
23 |
+
model: nn.Module,
|
24 |
+
evaluate_device: str,
|
25 |
+
) -> List[pl.Callback]:
|
26 |
+
"""Get Voicebank-Demand callbacks of a config yaml.
|
27 |
+
|
28 |
+
Args:
|
29 |
+
config_yaml: str
|
30 |
+
workspace: str
|
31 |
+
checkpoints_dir: str, directory to save checkpoints
|
32 |
+
statistics_dir: str, directory to save statistics
|
33 |
+
logger: pl.loggers.TensorBoardLogger
|
34 |
+
model: nn.Module
|
35 |
+
evaluate_device: str
|
36 |
+
|
37 |
+
Return:
|
38 |
+
callbacks: List[pl.Callback]
|
39 |
+
"""
|
40 |
+
configs = read_yaml(config_yaml)
|
41 |
+
task_name = configs['task_name']
|
42 |
+
target_source_types = configs['train']['target_source_types']
|
43 |
+
input_channels = configs['train']['channels']
|
44 |
+
mono = True if input_channels == 1 else False
|
45 |
+
test_audios_dir = os.path.join(workspace, "evaluation_audios", task_name, "test")
|
46 |
+
sample_rate = configs['train']['sample_rate']
|
47 |
+
evaluate_step_frequency = configs['train']['evaluate_step_frequency']
|
48 |
+
save_step_frequency = configs['train']['save_step_frequency']
|
49 |
+
test_batch_size = configs['evaluate']['batch_size']
|
50 |
+
test_segment_seconds = configs['evaluate']['segment_seconds']
|
51 |
+
|
52 |
+
test_segment_samples = int(test_segment_seconds * sample_rate)
|
53 |
+
assert len(target_source_types) == 1
|
54 |
+
target_source_type = target_source_types[0]
|
55 |
+
|
56 |
+
# save checkpoint callback
|
57 |
+
save_checkpoints_callback = SaveCheckpointsCallback(
|
58 |
+
model=model,
|
59 |
+
checkpoints_dir=checkpoints_dir,
|
60 |
+
save_step_frequency=save_step_frequency,
|
61 |
+
)
|
62 |
+
|
63 |
+
# statistics container
|
64 |
+
statistics_container = StatisticsContainer(statistics_path)
|
65 |
+
|
66 |
+
# evaluation callback
|
67 |
+
evaluate_test_callback = EvaluationCallback(
|
68 |
+
model=model,
|
69 |
+
target_source_type=target_source_type,
|
70 |
+
input_channels=input_channels,
|
71 |
+
sample_rate=sample_rate,
|
72 |
+
mono=mono,
|
73 |
+
evaluation_audios_dir=test_audios_dir,
|
74 |
+
segment_samples=test_segment_samples,
|
75 |
+
batch_size=test_batch_size,
|
76 |
+
device=evaluate_device,
|
77 |
+
evaluate_step_frequency=evaluate_step_frequency,
|
78 |
+
logger=logger,
|
79 |
+
statistics_container=statistics_container,
|
80 |
+
)
|
81 |
+
|
82 |
+
callbacks = [save_checkpoints_callback, evaluate_test_callback]
|
83 |
+
# callbacks = [save_checkpoints_callback]
|
84 |
+
|
85 |
+
return callbacks
|
86 |
+
|
87 |
+
|
88 |
+
class EvaluationCallback(pl.Callback):
|
89 |
+
def __init__(
|
90 |
+
self,
|
91 |
+
model: nn.Module,
|
92 |
+
input_channels: int,
|
93 |
+
evaluation_audios_dir: str,
|
94 |
+
target_source_type: str,
|
95 |
+
sample_rate: int,
|
96 |
+
mono: bool,
|
97 |
+
segment_samples: int,
|
98 |
+
batch_size: int,
|
99 |
+
device: str,
|
100 |
+
evaluate_step_frequency: int,
|
101 |
+
logger: pl.loggers.TensorBoardLogger,
|
102 |
+
statistics_container: StatisticsContainer,
|
103 |
+
):
|
104 |
+
r"""Callback to evaluate every #save_step_frequency steps.
|
105 |
+
|
106 |
+
Args:
|
107 |
+
model: nn.Module
|
108 |
+
input_channels: int
|
109 |
+
evaluation_audios_dir: str, directory containing audios for evaluation
|
110 |
+
target_source_type: str, e.g., 'violin'
|
111 |
+
sample_rate: int
|
112 |
+
mono: bool
|
113 |
+
segment_samples: int, length of segments to be input to a model, e.g., 44100*30
|
114 |
+
batch_size, int, e.g., 12
|
115 |
+
device: str, e.g., 'cuda'
|
116 |
+
evaluate_step_frequency: int, evaluate every #save_step_frequency steps
|
117 |
+
logger: pl.loggers.TensorBoardLogger
|
118 |
+
statistics_container: StatisticsContainer
|
119 |
+
"""
|
120 |
+
self.model = model
|
121 |
+
self.target_source_type = target_source_type
|
122 |
+
self.sample_rate = sample_rate
|
123 |
+
self.mono = mono
|
124 |
+
self.segment_samples = segment_samples
|
125 |
+
self.evaluate_step_frequency = evaluate_step_frequency
|
126 |
+
self.logger = logger
|
127 |
+
self.statistics_container = statistics_container
|
128 |
+
|
129 |
+
self.evaluation_audios_dir = evaluation_audios_dir
|
130 |
+
|
131 |
+
# separator
|
132 |
+
self.separator = Separator(model, self.segment_samples, batch_size, device)
|
133 |
+
|
134 |
+
@rank_zero_only
|
135 |
+
def on_batch_end(self, trainer: pl.Trainer, _) -> NoReturn:
|
136 |
+
r"""Evaluate losses on a few mini-batches. Losses are only used for
|
137 |
+
observing training, and are not final F1 metrics.
|
138 |
+
"""
|
139 |
+
|
140 |
+
global_step = trainer.global_step
|
141 |
+
|
142 |
+
if global_step % self.evaluate_step_frequency == 0:
|
143 |
+
|
144 |
+
mixture_audios_dir = os.path.join(self.evaluation_audios_dir, 'mixture')
|
145 |
+
clean_audios_dir = os.path.join(
|
146 |
+
self.evaluation_audios_dir, self.target_source_type
|
147 |
+
)
|
148 |
+
|
149 |
+
audio_names = sorted(os.listdir(mixture_audios_dir))
|
150 |
+
|
151 |
+
error_str = "Directory {} does not contain audios for evaluation!".format(
|
152 |
+
self.evaluation_audios_dir
|
153 |
+
)
|
154 |
+
assert len(audio_names) > 0, error_str
|
155 |
+
|
156 |
+
logging.info("--- Step {} ---".format(global_step))
|
157 |
+
logging.info("Total {} pieces for evaluation:".format(len(audio_names)))
|
158 |
+
|
159 |
+
eval_time = time.time()
|
160 |
+
|
161 |
+
sdrs = []
|
162 |
+
|
163 |
+
for n, audio_name in enumerate(audio_names):
|
164 |
+
|
165 |
+
# Load audio.
|
166 |
+
mixture_path = os.path.join(mixture_audios_dir, audio_name)
|
167 |
+
clean_path = os.path.join(clean_audios_dir, audio_name)
|
168 |
+
|
169 |
+
mixture, origin_fs = librosa.core.load(
|
170 |
+
mixture_path, sr=self.sample_rate, mono=self.mono
|
171 |
+
)
|
172 |
+
|
173 |
+
# Target
|
174 |
+
clean, origin_fs = librosa.core.load(
|
175 |
+
clean_path, sr=self.sample_rate, mono=self.mono
|
176 |
+
)
|
177 |
+
|
178 |
+
if mixture.ndim == 1:
|
179 |
+
mixture = mixture[None, :]
|
180 |
+
# (channels_num, audio_length)
|
181 |
+
|
182 |
+
input_dict = {'waveform': mixture}
|
183 |
+
|
184 |
+
# separate
|
185 |
+
sep_wav = self.separator.separate(input_dict)
|
186 |
+
# (channels_num, audio_length)
|
187 |
+
|
188 |
+
sdr = calculate_sdr(ref=clean, est=sep_wav)
|
189 |
+
|
190 |
+
print("{} SDR: {:.3f}".format(audio_name, sdr))
|
191 |
+
sdrs.append(sdr)
|
192 |
+
|
193 |
+
logging.info("-----------------------------")
|
194 |
+
logging.info('Avg SDR: {:.3f}'.format(np.mean(sdrs)))
|
195 |
+
|
196 |
+
logging.info("Evlauation time: {:.3f}".format(time.time() - eval_time))
|
197 |
+
|
198 |
+
statistics = {"sdr": np.mean(sdrs)}
|
199 |
+
self.statistics_container.append(global_step, statistics, 'test')
|
200 |
+
self.statistics_container.dump()
|
bytesep/callbacks/musdb18.py
ADDED
@@ -0,0 +1,485 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import os
|
3 |
+
import time
|
4 |
+
from typing import Dict, List, NoReturn
|
5 |
+
|
6 |
+
import librosa
|
7 |
+
import musdb
|
8 |
+
import museval
|
9 |
+
import numpy as np
|
10 |
+
import pytorch_lightning as pl
|
11 |
+
import torch.nn as nn
|
12 |
+
from pytorch_lightning.utilities import rank_zero_only
|
13 |
+
|
14 |
+
from bytesep.callbacks.base_callbacks import SaveCheckpointsCallback
|
15 |
+
from bytesep.dataset_creation.pack_audios_to_hdf5s.musdb18 import preprocess_audio
|
16 |
+
from bytesep.inference import Separator
|
17 |
+
from bytesep.utils import StatisticsContainer, read_yaml
|
18 |
+
|
19 |
+
|
20 |
+
def get_musdb18_callbacks(
|
21 |
+
config_yaml: str,
|
22 |
+
workspace: str,
|
23 |
+
checkpoints_dir: str,
|
24 |
+
statistics_path: str,
|
25 |
+
logger: pl.loggers.TensorBoardLogger,
|
26 |
+
model: nn.Module,
|
27 |
+
evaluate_device: str,
|
28 |
+
) -> List[pl.Callback]:
|
29 |
+
r"""Get MUSDB18 callbacks of a config yaml.
|
30 |
+
|
31 |
+
Args:
|
32 |
+
config_yaml: str
|
33 |
+
workspace: str
|
34 |
+
checkpoints_dir: str, directory to save checkpoints
|
35 |
+
statistics_dir: str, directory to save statistics
|
36 |
+
logger: pl.loggers.TensorBoardLogger
|
37 |
+
model: nn.Module
|
38 |
+
evaluate_device: str
|
39 |
+
|
40 |
+
Return:
|
41 |
+
callbacks: List[pl.Callback]
|
42 |
+
"""
|
43 |
+
configs = read_yaml(config_yaml)
|
44 |
+
task_name = configs['task_name']
|
45 |
+
evaluation_callback = configs['train']['evaluation_callback']
|
46 |
+
target_source_types = configs['train']['target_source_types']
|
47 |
+
input_channels = configs['train']['channels']
|
48 |
+
evaluation_audios_dir = os.path.join(workspace, "evaluation_audios", task_name)
|
49 |
+
test_segment_seconds = configs['evaluate']['segment_seconds']
|
50 |
+
sample_rate = configs['train']['sample_rate']
|
51 |
+
test_segment_samples = int(test_segment_seconds * sample_rate)
|
52 |
+
test_batch_size = configs['evaluate']['batch_size']
|
53 |
+
|
54 |
+
evaluate_step_frequency = configs['train']['evaluate_step_frequency']
|
55 |
+
save_step_frequency = configs['train']['save_step_frequency']
|
56 |
+
|
57 |
+
# save checkpoint callback
|
58 |
+
save_checkpoints_callback = SaveCheckpointsCallback(
|
59 |
+
model=model,
|
60 |
+
checkpoints_dir=checkpoints_dir,
|
61 |
+
save_step_frequency=save_step_frequency,
|
62 |
+
)
|
63 |
+
|
64 |
+
# evaluation callback
|
65 |
+
EvaluationCallback = _get_evaluation_callback_class(evaluation_callback)
|
66 |
+
|
67 |
+
# statistics container
|
68 |
+
statistics_container = StatisticsContainer(statistics_path)
|
69 |
+
|
70 |
+
# evaluation callback
|
71 |
+
evaluate_train_callback = EvaluationCallback(
|
72 |
+
dataset_dir=evaluation_audios_dir,
|
73 |
+
model=model,
|
74 |
+
target_source_types=target_source_types,
|
75 |
+
input_channels=input_channels,
|
76 |
+
sample_rate=sample_rate,
|
77 |
+
split='train',
|
78 |
+
segment_samples=test_segment_samples,
|
79 |
+
batch_size=test_batch_size,
|
80 |
+
device=evaluate_device,
|
81 |
+
evaluate_step_frequency=evaluate_step_frequency,
|
82 |
+
logger=logger,
|
83 |
+
statistics_container=statistics_container,
|
84 |
+
)
|
85 |
+
|
86 |
+
evaluate_test_callback = EvaluationCallback(
|
87 |
+
dataset_dir=evaluation_audios_dir,
|
88 |
+
model=model,
|
89 |
+
target_source_types=target_source_types,
|
90 |
+
input_channels=input_channels,
|
91 |
+
sample_rate=sample_rate,
|
92 |
+
split='test',
|
93 |
+
segment_samples=test_segment_samples,
|
94 |
+
batch_size=test_batch_size,
|
95 |
+
device=evaluate_device,
|
96 |
+
evaluate_step_frequency=evaluate_step_frequency,
|
97 |
+
logger=logger,
|
98 |
+
statistics_container=statistics_container,
|
99 |
+
)
|
100 |
+
|
101 |
+
# callbacks = [save_checkpoints_callback, evaluate_train_callback, evaluate_test_callback]
|
102 |
+
callbacks = [save_checkpoints_callback, evaluate_test_callback]
|
103 |
+
|
104 |
+
return callbacks
|
105 |
+
|
106 |
+
|
107 |
+
def _get_evaluation_callback_class(evaluation_callback) -> pl.Callback:
|
108 |
+
r"""Get evaluation callback class."""
|
109 |
+
if evaluation_callback == "Musdb18EvaluationCallback":
|
110 |
+
return Musdb18EvaluationCallback
|
111 |
+
|
112 |
+
if evaluation_callback == 'Musdb18ConditionalEvaluationCallback':
|
113 |
+
return Musdb18ConditionalEvaluationCallback
|
114 |
+
|
115 |
+
else:
|
116 |
+
raise NotImplementedError
|
117 |
+
|
118 |
+
|
119 |
+
class Musdb18EvaluationCallback(pl.Callback):
|
120 |
+
def __init__(
|
121 |
+
self,
|
122 |
+
dataset_dir: str,
|
123 |
+
model: nn.Module,
|
124 |
+
target_source_types: str,
|
125 |
+
input_channels: int,
|
126 |
+
split: str,
|
127 |
+
sample_rate: int,
|
128 |
+
segment_samples: int,
|
129 |
+
batch_size: int,
|
130 |
+
device: str,
|
131 |
+
evaluate_step_frequency: int,
|
132 |
+
logger: pl.loggers.TensorBoardLogger,
|
133 |
+
statistics_container: StatisticsContainer,
|
134 |
+
):
|
135 |
+
r"""Callback to evaluate every #save_step_frequency steps.
|
136 |
+
|
137 |
+
Args:
|
138 |
+
dataset_dir: str
|
139 |
+
model: nn.Module
|
140 |
+
target_source_types: List[str], e.g., ['vocals', 'bass', ...]
|
141 |
+
input_channels: int
|
142 |
+
split: 'train' | 'test'
|
143 |
+
sample_rate: int
|
144 |
+
segment_samples: int, length of segments to be input to a model, e.g., 44100*30
|
145 |
+
batch_size, int, e.g., 12
|
146 |
+
device: str, e.g., 'cuda'
|
147 |
+
evaluate_step_frequency: int, evaluate every #save_step_frequency steps
|
148 |
+
logger: object
|
149 |
+
statistics_container: StatisticsContainer
|
150 |
+
"""
|
151 |
+
self.model = model
|
152 |
+
self.target_source_types = target_source_types
|
153 |
+
self.input_channels = input_channels
|
154 |
+
self.sample_rate = sample_rate
|
155 |
+
self.split = split
|
156 |
+
self.segment_samples = segment_samples
|
157 |
+
self.evaluate_step_frequency = evaluate_step_frequency
|
158 |
+
self.logger = logger
|
159 |
+
self.statistics_container = statistics_container
|
160 |
+
self.mono = input_channels == 1
|
161 |
+
self.resample_type = "kaiser_fast"
|
162 |
+
|
163 |
+
self.mus = musdb.DB(root=dataset_dir, subsets=[split])
|
164 |
+
|
165 |
+
error_msg = "The directory {} is empty!".format(dataset_dir)
|
166 |
+
assert len(self.mus) > 0, error_msg
|
167 |
+
|
168 |
+
# separator
|
169 |
+
self.separator = Separator(model, self.segment_samples, batch_size, device)
|
170 |
+
|
171 |
+
@rank_zero_only
|
172 |
+
def on_batch_end(self, trainer: pl.Trainer, _) -> NoReturn:
|
173 |
+
r"""Evaluate separation SDRs of audio recordings."""
|
174 |
+
global_step = trainer.global_step
|
175 |
+
|
176 |
+
if global_step % self.evaluate_step_frequency == 0:
|
177 |
+
|
178 |
+
sdr_dict = {}
|
179 |
+
|
180 |
+
logging.info("--- Step {} ---".format(global_step))
|
181 |
+
logging.info("Total {} pieces for evaluation:".format(len(self.mus.tracks)))
|
182 |
+
|
183 |
+
eval_time = time.time()
|
184 |
+
|
185 |
+
for track in self.mus.tracks:
|
186 |
+
|
187 |
+
audio_name = track.name
|
188 |
+
|
189 |
+
# Get waveform of mixture.
|
190 |
+
mixture = track.audio.T
|
191 |
+
# (channels_num, audio_samples)
|
192 |
+
|
193 |
+
mixture = preprocess_audio(
|
194 |
+
audio=mixture,
|
195 |
+
mono=self.mono,
|
196 |
+
origin_sr=track.rate,
|
197 |
+
sr=self.sample_rate,
|
198 |
+
resample_type=self.resample_type,
|
199 |
+
)
|
200 |
+
# (channels_num, audio_samples)
|
201 |
+
|
202 |
+
target_dict = {}
|
203 |
+
sdr_dict[audio_name] = {}
|
204 |
+
|
205 |
+
# Get waveform of all target source types.
|
206 |
+
for j, source_type in enumerate(self.target_source_types):
|
207 |
+
# E.g., ['vocals', 'bass', ...]
|
208 |
+
|
209 |
+
audio = track.targets[source_type].audio.T
|
210 |
+
|
211 |
+
audio = preprocess_audio(
|
212 |
+
audio=audio,
|
213 |
+
mono=self.mono,
|
214 |
+
origin_sr=track.rate,
|
215 |
+
sr=self.sample_rate,
|
216 |
+
resample_type=self.resample_type,
|
217 |
+
)
|
218 |
+
# (channels_num, audio_samples)
|
219 |
+
|
220 |
+
target_dict[source_type] = audio
|
221 |
+
# (channels_num, audio_samples)
|
222 |
+
|
223 |
+
# Separate.
|
224 |
+
input_dict = {'waveform': mixture}
|
225 |
+
|
226 |
+
sep_wavs = self.separator.separate(input_dict)
|
227 |
+
# sep_wavs: (target_sources_num * channels_num, audio_samples)
|
228 |
+
|
229 |
+
# Post process separation results.
|
230 |
+
sep_wavs = preprocess_audio(
|
231 |
+
audio=sep_wavs,
|
232 |
+
mono=self.mono,
|
233 |
+
origin_sr=self.sample_rate,
|
234 |
+
sr=track.rate,
|
235 |
+
resample_type=self.resample_type,
|
236 |
+
)
|
237 |
+
# sep_wavs: (target_sources_num * channels_num, audio_samples)
|
238 |
+
|
239 |
+
sep_wavs = librosa.util.fix_length(
|
240 |
+
sep_wavs, size=mixture.shape[1], axis=1
|
241 |
+
)
|
242 |
+
# sep_wavs: (target_sources_num * channels_num, audio_samples)
|
243 |
+
|
244 |
+
sep_wav_dict = get_separated_wavs_from_simo_output(
|
245 |
+
sep_wavs, self.input_channels, self.target_source_types
|
246 |
+
)
|
247 |
+
# output_dict: dict, e.g., {
|
248 |
+
# 'vocals': (channels_num, audio_samples),
|
249 |
+
# 'bass': (channels_num, audio_samples),
|
250 |
+
# ...,
|
251 |
+
# }
|
252 |
+
|
253 |
+
# Evaluate for all target source types.
|
254 |
+
for source_type in self.target_source_types:
|
255 |
+
# E.g., ['vocals', 'bass', ...]
|
256 |
+
|
257 |
+
# Calculate SDR using museval, input shape should be: (nsrc, nsampl, nchan).
|
258 |
+
(sdrs, _, _, _) = museval.evaluate(
|
259 |
+
[target_dict[source_type].T], [sep_wav_dict[source_type].T]
|
260 |
+
)
|
261 |
+
|
262 |
+
sdr = np.nanmedian(sdrs)
|
263 |
+
sdr_dict[audio_name][source_type] = sdr
|
264 |
+
|
265 |
+
logging.info(
|
266 |
+
"{}, {}, sdr: {:.3f}".format(audio_name, source_type, sdr)
|
267 |
+
)
|
268 |
+
|
269 |
+
logging.info("-----------------------------")
|
270 |
+
median_sdr_dict = {}
|
271 |
+
|
272 |
+
# Calculate median SDRs of all songs.
|
273 |
+
for source_type in self.target_source_types:
|
274 |
+
# E.g., ['vocals', 'bass', ...]
|
275 |
+
|
276 |
+
median_sdr = np.median(
|
277 |
+
[
|
278 |
+
sdr_dict[audio_name][source_type]
|
279 |
+
for audio_name in sdr_dict.keys()
|
280 |
+
]
|
281 |
+
)
|
282 |
+
|
283 |
+
median_sdr_dict[source_type] = median_sdr
|
284 |
+
|
285 |
+
logging.info(
|
286 |
+
"Step: {}, {}, Median SDR: {:.3f}".format(
|
287 |
+
global_step, source_type, median_sdr
|
288 |
+
)
|
289 |
+
)
|
290 |
+
|
291 |
+
logging.info("Evlauation time: {:.3f}".format(time.time() - eval_time))
|
292 |
+
|
293 |
+
statistics = {"sdr_dict": sdr_dict, "median_sdr_dict": median_sdr_dict}
|
294 |
+
self.statistics_container.append(global_step, statistics, self.split)
|
295 |
+
self.statistics_container.dump()
|
296 |
+
|
297 |
+
|
298 |
+
def get_separated_wavs_from_simo_output(x, input_channels, target_source_types) -> Dict:
|
299 |
+
r"""Get separated waveforms of target sources from a single input multiple
|
300 |
+
output (SIMO) system.
|
301 |
+
|
302 |
+
Args:
|
303 |
+
x: (target_sources_num * channels_num, audio_samples)
|
304 |
+
input_channels: int
|
305 |
+
target_source_types: List[str], e.g., ['vocals', 'bass', ...]
|
306 |
+
|
307 |
+
Returns:
|
308 |
+
output_dict: dict, e.g., {
|
309 |
+
'vocals': (channels_num, audio_samples),
|
310 |
+
'bass': (channels_num, audio_samples),
|
311 |
+
...,
|
312 |
+
}
|
313 |
+
"""
|
314 |
+
output_dict = {}
|
315 |
+
|
316 |
+
for j, source_type in enumerate(target_source_types):
|
317 |
+
output_dict[source_type] = x[j * input_channels : (j + 1) * input_channels]
|
318 |
+
|
319 |
+
return output_dict
|
320 |
+
|
321 |
+
|
322 |
+
class Musdb18ConditionalEvaluationCallback(pl.Callback):
|
323 |
+
def __init__(
|
324 |
+
self,
|
325 |
+
dataset_dir: str,
|
326 |
+
model: nn.Module,
|
327 |
+
target_source_types: str,
|
328 |
+
input_channels: int,
|
329 |
+
split: str,
|
330 |
+
sample_rate: int,
|
331 |
+
segment_samples: int,
|
332 |
+
batch_size: int,
|
333 |
+
device: str,
|
334 |
+
evaluate_step_frequency: int,
|
335 |
+
logger: pl.loggers.TensorBoardLogger,
|
336 |
+
statistics_container: StatisticsContainer,
|
337 |
+
):
|
338 |
+
r"""Callback to evaluate every #save_step_frequency steps.
|
339 |
+
|
340 |
+
Args:
|
341 |
+
dataset_dir: str
|
342 |
+
model: nn.Module
|
343 |
+
target_source_types: List[str], e.g., ['vocals', 'bass', ...]
|
344 |
+
input_channels: int
|
345 |
+
split: 'train' | 'test'
|
346 |
+
sample_rate: int
|
347 |
+
segment_samples: int, length of segments to be input to a model, e.g., 44100*30
|
348 |
+
batch_size, int, e.g., 12
|
349 |
+
device: str, e.g., 'cuda'
|
350 |
+
evaluate_step_frequency: int, evaluate every #save_step_frequency steps
|
351 |
+
logger: object
|
352 |
+
statistics_container: StatisticsContainer
|
353 |
+
"""
|
354 |
+
self.model = model
|
355 |
+
self.target_source_types = target_source_types
|
356 |
+
self.input_channels = input_channels
|
357 |
+
self.sample_rate = sample_rate
|
358 |
+
self.split = split
|
359 |
+
self.segment_samples = segment_samples
|
360 |
+
self.evaluate_step_frequency = evaluate_step_frequency
|
361 |
+
self.logger = logger
|
362 |
+
self.statistics_container = statistics_container
|
363 |
+
self.mono = input_channels == 1
|
364 |
+
self.resample_type = "kaiser_fast"
|
365 |
+
|
366 |
+
self.mus = musdb.DB(root=dataset_dir, subsets=[split])
|
367 |
+
|
368 |
+
error_msg = "The directory {} is empty!".format(dataset_dir)
|
369 |
+
assert len(self.mus) > 0, error_msg
|
370 |
+
|
371 |
+
# separator
|
372 |
+
self.separator = Separator(model, self.segment_samples, batch_size, device)
|
373 |
+
|
374 |
+
@rank_zero_only
|
375 |
+
def on_batch_end(self, trainer: pl.Trainer, _) -> NoReturn:
|
376 |
+
r"""Evaluate separation SDRs of audio recordings."""
|
377 |
+
global_step = trainer.global_step
|
378 |
+
|
379 |
+
if global_step % self.evaluate_step_frequency == 0:
|
380 |
+
|
381 |
+
sdr_dict = {}
|
382 |
+
|
383 |
+
logging.info("--- Step {} ---".format(global_step))
|
384 |
+
logging.info("Total {} pieces for evaluation:".format(len(self.mus.tracks)))
|
385 |
+
|
386 |
+
eval_time = time.time()
|
387 |
+
|
388 |
+
for track in self.mus.tracks:
|
389 |
+
|
390 |
+
audio_name = track.name
|
391 |
+
|
392 |
+
# Get waveform of mixture.
|
393 |
+
mixture = track.audio.T
|
394 |
+
# (channels_num, audio_samples)
|
395 |
+
|
396 |
+
mixture = preprocess_audio(
|
397 |
+
audio=mixture,
|
398 |
+
mono=self.mono,
|
399 |
+
origin_sr=track.rate,
|
400 |
+
sr=self.sample_rate,
|
401 |
+
resample_type=self.resample_type,
|
402 |
+
)
|
403 |
+
# (channels_num, audio_samples)
|
404 |
+
|
405 |
+
target_dict = {}
|
406 |
+
sdr_dict[audio_name] = {}
|
407 |
+
|
408 |
+
# Get waveform of all target source types.
|
409 |
+
for j, source_type in enumerate(self.target_source_types):
|
410 |
+
# E.g., ['vocals', 'bass', ...]
|
411 |
+
|
412 |
+
audio = track.targets[source_type].audio.T
|
413 |
+
|
414 |
+
audio = preprocess_audio(
|
415 |
+
audio=audio,
|
416 |
+
mono=self.mono,
|
417 |
+
origin_sr=track.rate,
|
418 |
+
sr=self.sample_rate,
|
419 |
+
resample_type=self.resample_type,
|
420 |
+
)
|
421 |
+
# (channels_num, audio_samples)
|
422 |
+
|
423 |
+
target_dict[source_type] = audio
|
424 |
+
# (channels_num, audio_samples)
|
425 |
+
|
426 |
+
condition = np.zeros(len(self.target_source_types))
|
427 |
+
condition[j] = 1
|
428 |
+
|
429 |
+
input_dict = {'waveform': mixture, 'condition': condition}
|
430 |
+
|
431 |
+
sep_wav = self.separator.separate(input_dict)
|
432 |
+
# sep_wav: (channels_num, audio_samples)
|
433 |
+
|
434 |
+
sep_wav = preprocess_audio(
|
435 |
+
audio=sep_wav,
|
436 |
+
mono=self.mono,
|
437 |
+
origin_sr=self.sample_rate,
|
438 |
+
sr=track.rate,
|
439 |
+
resample_type=self.resample_type,
|
440 |
+
)
|
441 |
+
# sep_wav: (channels_num, audio_samples)
|
442 |
+
|
443 |
+
sep_wav = librosa.util.fix_length(
|
444 |
+
sep_wav, size=mixture.shape[1], axis=1
|
445 |
+
)
|
446 |
+
# sep_wav: (target_sources_num * channels_num, audio_samples)
|
447 |
+
|
448 |
+
# Calculate SDR using museval, input shape should be: (nsrc, nsampl, nchan)
|
449 |
+
(sdrs, _, _, _) = museval.evaluate(
|
450 |
+
[target_dict[source_type].T], [sep_wav.T]
|
451 |
+
)
|
452 |
+
|
453 |
+
sdr = np.nanmedian(sdrs)
|
454 |
+
sdr_dict[audio_name][source_type] = sdr
|
455 |
+
|
456 |
+
logging.info(
|
457 |
+
"{}, {}, sdr: {:.3f}".format(audio_name, source_type, sdr)
|
458 |
+
)
|
459 |
+
|
460 |
+
logging.info("-----------------------------")
|
461 |
+
median_sdr_dict = {}
|
462 |
+
|
463 |
+
# Calculate median SDRs of all songs.
|
464 |
+
for source_type in self.target_source_types:
|
465 |
+
|
466 |
+
median_sdr = np.median(
|
467 |
+
[
|
468 |
+
sdr_dict[audio_name][source_type]
|
469 |
+
for audio_name in sdr_dict.keys()
|
470 |
+
]
|
471 |
+
)
|
472 |
+
|
473 |
+
median_sdr_dict[source_type] = median_sdr
|
474 |
+
|
475 |
+
logging.info(
|
476 |
+
"Step: {}, {}, Median SDR: {:.3f}".format(
|
477 |
+
global_step, source_type, median_sdr
|
478 |
+
)
|
479 |
+
)
|
480 |
+
|
481 |
+
logging.info("Evlauation time: {:.3f}".format(time.time() - eval_time))
|
482 |
+
|
483 |
+
statistics = {"sdr_dict": sdr_dict, "median_sdr_dict": median_sdr_dict}
|
484 |
+
self.statistics_container.append(global_step, statistics, self.split)
|
485 |
+
self.statistics_container.dump()
|
bytesep/callbacks/voicebank_demand.py
ADDED
@@ -0,0 +1,231 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import os
|
3 |
+
import time
|
4 |
+
from typing import List, NoReturn
|
5 |
+
|
6 |
+
import librosa
|
7 |
+
import numpy as np
|
8 |
+
import pysepm
|
9 |
+
import pytorch_lightning as pl
|
10 |
+
import torch.nn as nn
|
11 |
+
from pesq import pesq
|
12 |
+
from pytorch_lightning.utilities import rank_zero_only
|
13 |
+
|
14 |
+
from bytesep.callbacks.base_callbacks import SaveCheckpointsCallback
|
15 |
+
from bytesep.inference import Separator
|
16 |
+
from bytesep.utils import StatisticsContainer, read_yaml
|
17 |
+
|
18 |
+
|
19 |
+
def get_voicebank_demand_callbacks(
|
20 |
+
config_yaml: str,
|
21 |
+
workspace: str,
|
22 |
+
checkpoints_dir: str,
|
23 |
+
statistics_path: str,
|
24 |
+
logger: pl.loggers.TensorBoardLogger,
|
25 |
+
model: nn.Module,
|
26 |
+
evaluate_device: str,
|
27 |
+
) -> List[pl.Callback]:
|
28 |
+
"""Get Voicebank-Demand callbacks of a config yaml.
|
29 |
+
|
30 |
+
Args:
|
31 |
+
config_yaml: str
|
32 |
+
workspace: str
|
33 |
+
checkpoints_dir: str, directory to save checkpoints
|
34 |
+
statistics_dir: str, directory to save statistics
|
35 |
+
logger: pl.loggers.TensorBoardLogger
|
36 |
+
model: nn.Module
|
37 |
+
evaluate_device: str
|
38 |
+
|
39 |
+
Return:
|
40 |
+
callbacks: List[pl.Callback]
|
41 |
+
"""
|
42 |
+
configs = read_yaml(config_yaml)
|
43 |
+
task_name = configs['task_name']
|
44 |
+
target_source_types = configs['train']['target_source_types']
|
45 |
+
input_channels = configs['train']['channels']
|
46 |
+
evaluation_audios_dir = os.path.join(workspace, "evaluation_audios", task_name)
|
47 |
+
sample_rate = configs['train']['sample_rate']
|
48 |
+
evaluate_step_frequency = configs['train']['evaluate_step_frequency']
|
49 |
+
save_step_frequency = configs['train']['save_step_frequency']
|
50 |
+
test_batch_size = configs['evaluate']['batch_size']
|
51 |
+
test_segment_seconds = configs['evaluate']['segment_seconds']
|
52 |
+
|
53 |
+
test_segment_samples = int(test_segment_seconds * sample_rate)
|
54 |
+
assert len(target_source_types) == 1
|
55 |
+
target_source_type = target_source_types[0]
|
56 |
+
assert target_source_type == 'speech'
|
57 |
+
|
58 |
+
# save checkpoint callback
|
59 |
+
save_checkpoints_callback = SaveCheckpointsCallback(
|
60 |
+
model=model,
|
61 |
+
checkpoints_dir=checkpoints_dir,
|
62 |
+
save_step_frequency=save_step_frequency,
|
63 |
+
)
|
64 |
+
|
65 |
+
# statistics container
|
66 |
+
statistics_container = StatisticsContainer(statistics_path)
|
67 |
+
|
68 |
+
# evaluation callback
|
69 |
+
evaluate_test_callback = EvaluationCallback(
|
70 |
+
model=model,
|
71 |
+
input_channels=input_channels,
|
72 |
+
sample_rate=sample_rate,
|
73 |
+
evaluation_audios_dir=evaluation_audios_dir,
|
74 |
+
segment_samples=test_segment_samples,
|
75 |
+
batch_size=test_batch_size,
|
76 |
+
device=evaluate_device,
|
77 |
+
evaluate_step_frequency=evaluate_step_frequency,
|
78 |
+
logger=logger,
|
79 |
+
statistics_container=statistics_container,
|
80 |
+
)
|
81 |
+
|
82 |
+
callbacks = [save_checkpoints_callback, evaluate_test_callback]
|
83 |
+
|
84 |
+
return callbacks
|
85 |
+
|
86 |
+
|
87 |
+
class EvaluationCallback(pl.Callback):
|
88 |
+
def __init__(
|
89 |
+
self,
|
90 |
+
model: nn.Module,
|
91 |
+
input_channels: int,
|
92 |
+
evaluation_audios_dir,
|
93 |
+
sample_rate: int,
|
94 |
+
segment_samples: int,
|
95 |
+
batch_size: int,
|
96 |
+
device: str,
|
97 |
+
evaluate_step_frequency: int,
|
98 |
+
logger: pl.loggers.TensorBoardLogger,
|
99 |
+
statistics_container: StatisticsContainer,
|
100 |
+
):
|
101 |
+
r"""Callback to evaluate every #save_step_frequency steps.
|
102 |
+
|
103 |
+
Args:
|
104 |
+
model: nn.Module
|
105 |
+
input_channels: int
|
106 |
+
evaluation_audios_dir: str, directory containing audios for evaluation
|
107 |
+
sample_rate: int
|
108 |
+
segment_samples: int, length of segments to be input to a model, e.g., 44100*30
|
109 |
+
batch_size, int, e.g., 12
|
110 |
+
device: str, e.g., 'cuda'
|
111 |
+
evaluate_step_frequency: int, evaluate every #save_step_frequency steps
|
112 |
+
logger: pl.loggers.TensorBoardLogger
|
113 |
+
statistics_container: StatisticsContainer
|
114 |
+
"""
|
115 |
+
self.model = model
|
116 |
+
self.mono = True
|
117 |
+
self.sample_rate = sample_rate
|
118 |
+
self.segment_samples = segment_samples
|
119 |
+
self.evaluate_step_frequency = evaluate_step_frequency
|
120 |
+
self.logger = logger
|
121 |
+
self.statistics_container = statistics_container
|
122 |
+
|
123 |
+
self.clean_dir = os.path.join(evaluation_audios_dir, "clean_testset_wav")
|
124 |
+
self.noisy_dir = os.path.join(evaluation_audios_dir, "noisy_testset_wav")
|
125 |
+
|
126 |
+
self.EVALUATION_SAMPLE_RATE = 16000 # Evaluation sample rate of the
|
127 |
+
# Voicebank-Demand task.
|
128 |
+
|
129 |
+
# separator
|
130 |
+
self.separator = Separator(model, self.segment_samples, batch_size, device)
|
131 |
+
|
132 |
+
@rank_zero_only
|
133 |
+
def on_batch_end(self, trainer: pl.Trainer, _) -> NoReturn:
|
134 |
+
r"""Evaluate losses on a few mini-batches. Losses are only used for
|
135 |
+
observing training, and are not final F1 metrics.
|
136 |
+
"""
|
137 |
+
|
138 |
+
global_step = trainer.global_step
|
139 |
+
|
140 |
+
if global_step % self.evaluate_step_frequency == 0:
|
141 |
+
|
142 |
+
audio_names = sorted(
|
143 |
+
[
|
144 |
+
audio_name
|
145 |
+
for audio_name in sorted(os.listdir(self.clean_dir))
|
146 |
+
if audio_name.endswith('.wav')
|
147 |
+
]
|
148 |
+
)
|
149 |
+
|
150 |
+
error_str = "Directory {} does not contain audios for evaluation!".format(
|
151 |
+
self.clean_dir
|
152 |
+
)
|
153 |
+
assert len(audio_names) > 0, error_str
|
154 |
+
|
155 |
+
pesqs, csigs, cbaks, covls, ssnrs = [], [], [], [], []
|
156 |
+
|
157 |
+
logging.info("--- Step {} ---".format(global_step))
|
158 |
+
logging.info("Total {} pieces for evaluation:".format(len(audio_names)))
|
159 |
+
|
160 |
+
eval_time = time.time()
|
161 |
+
|
162 |
+
for n, audio_name in enumerate(audio_names):
|
163 |
+
|
164 |
+
# Load audio.
|
165 |
+
clean_path = os.path.join(self.clean_dir, audio_name)
|
166 |
+
mixture_path = os.path.join(self.noisy_dir, audio_name)
|
167 |
+
|
168 |
+
mixture, _ = librosa.core.load(
|
169 |
+
mixture_path, sr=self.sample_rate, mono=self.mono
|
170 |
+
)
|
171 |
+
|
172 |
+
if mixture.ndim == 1:
|
173 |
+
mixture = mixture[None, :]
|
174 |
+
# (channels_num, audio_length)
|
175 |
+
|
176 |
+
# Separate.
|
177 |
+
input_dict = {'waveform': mixture}
|
178 |
+
|
179 |
+
sep_wav = self.separator.separate(input_dict)
|
180 |
+
# (channels_num, audio_length)
|
181 |
+
|
182 |
+
# Target
|
183 |
+
clean, _ = librosa.core.load(
|
184 |
+
clean_path, sr=self.EVALUATION_SAMPLE_RATE, mono=self.mono
|
185 |
+
)
|
186 |
+
|
187 |
+
# to mono
|
188 |
+
sep_wav = np.squeeze(sep_wav)
|
189 |
+
|
190 |
+
# Resample for evaluation.
|
191 |
+
sep_wav = librosa.resample(
|
192 |
+
sep_wav,
|
193 |
+
orig_sr=self.sample_rate,
|
194 |
+
target_sr=self.EVALUATION_SAMPLE_RATE,
|
195 |
+
)
|
196 |
+
|
197 |
+
sep_wav = librosa.util.fix_length(sep_wav, size=len(clean), axis=0)
|
198 |
+
# (channels, audio_length)
|
199 |
+
|
200 |
+
# Evaluate metrics
|
201 |
+
pesq_ = pesq(self.EVALUATION_SAMPLE_RATE, clean, sep_wav, 'wb')
|
202 |
+
|
203 |
+
(csig, cbak, covl) = pysepm.composite(
|
204 |
+
clean, sep_wav, self.EVALUATION_SAMPLE_RATE
|
205 |
+
)
|
206 |
+
|
207 |
+
ssnr = pysepm.SNRseg(clean, sep_wav, self.EVALUATION_SAMPLE_RATE)
|
208 |
+
|
209 |
+
pesqs.append(pesq_)
|
210 |
+
csigs.append(csig)
|
211 |
+
cbaks.append(cbak)
|
212 |
+
covls.append(covl)
|
213 |
+
ssnrs.append(ssnr)
|
214 |
+
print(
|
215 |
+
'{}, {}, PESQ: {:.3f}, CSIG: {:.3f}, CBAK: {:.3f}, COVL: {:.3f}, SSNR: {:.3f}'.format(
|
216 |
+
n, audio_name, pesq_, csig, cbak, covl, ssnr
|
217 |
+
)
|
218 |
+
)
|
219 |
+
|
220 |
+
logging.info("-----------------------------")
|
221 |
+
logging.info('Avg PESQ: {:.3f}'.format(np.mean(pesqs)))
|
222 |
+
logging.info('Avg CSIG: {:.3f}'.format(np.mean(csigs)))
|
223 |
+
logging.info('Avg CBAK: {:.3f}'.format(np.mean(cbaks)))
|
224 |
+
logging.info('Avg COVL: {:.3f}'.format(np.mean(covls)))
|
225 |
+
logging.info('Avg SSNR: {:.3f}'.format(np.mean(ssnrs)))
|
226 |
+
|
227 |
+
logging.info("Evlauation time: {:.3f}".format(time.time() - eval_time))
|
228 |
+
|
229 |
+
statistics = {"pesq": np.mean(pesqs)}
|
230 |
+
self.statistics_container.append(global_step, statistics, 'test')
|
231 |
+
self.statistics_container.dump()
|
bytesep/data/__init__.py
ADDED
File without changes
|
bytesep/data/augmentors.py
ADDED
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict
|
2 |
+
|
3 |
+
import librosa
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
from bytesep.utils import db_to_magnitude, get_pitch_shift_factor, magnitude_to_db
|
7 |
+
|
8 |
+
|
9 |
+
class Augmentor:
|
10 |
+
def __init__(self, augmentations: Dict, random_seed=1234):
|
11 |
+
r"""Augmentor for data augmentation of a waveform.
|
12 |
+
|
13 |
+
Args:
|
14 |
+
augmentations: Dict, e.g, {
|
15 |
+
'mixaudio': {'vocals': 2, 'accompaniment': 2}
|
16 |
+
'pitch_shift': {'vocals': 4, 'accompaniment': 4},
|
17 |
+
...,
|
18 |
+
}
|
19 |
+
random_seed: int
|
20 |
+
"""
|
21 |
+
self.augmentations = augmentations
|
22 |
+
self.random_state = np.random.RandomState(random_seed)
|
23 |
+
|
24 |
+
def __call__(self, waveform: np.array, source_type: str) -> np.array:
|
25 |
+
r"""Augment a waveform.
|
26 |
+
|
27 |
+
Args:
|
28 |
+
waveform: (channels_num, audio_samples)
|
29 |
+
source_type: str
|
30 |
+
|
31 |
+
Returns:
|
32 |
+
new_waveform: (channels_num, new_audio_samples)
|
33 |
+
"""
|
34 |
+
if 'pitch_shift' in self.augmentations.keys():
|
35 |
+
waveform = self.pitch_shift(waveform, source_type)
|
36 |
+
|
37 |
+
if 'magnitude_scale' in self.augmentations.keys():
|
38 |
+
waveform = self.magnitude_scale(waveform, source_type)
|
39 |
+
|
40 |
+
if 'swap_channel' in self.augmentations.keys():
|
41 |
+
waveform = self.swap_channel(waveform, source_type)
|
42 |
+
|
43 |
+
if 'flip_axis' in self.augmentations.keys():
|
44 |
+
waveform = self.flip_axis(waveform, source_type)
|
45 |
+
|
46 |
+
return waveform
|
47 |
+
|
48 |
+
def pitch_shift(self, waveform: np.array, source_type: str) -> np.array:
|
49 |
+
r"""Shift the pitch of a waveform. We use resampling for fast pitch
|
50 |
+
shifting, so the speed will also be chaneged. The length of the returned
|
51 |
+
waveform will be changed.
|
52 |
+
|
53 |
+
Args:
|
54 |
+
waveform: (channels_num, audio_samples)
|
55 |
+
source_type: str
|
56 |
+
|
57 |
+
Returns:
|
58 |
+
new_waveform: (channels_num, new_audio_samples)
|
59 |
+
"""
|
60 |
+
|
61 |
+
# maximum pitch shift in semitones
|
62 |
+
max_pitch_shift = self.augmentations['pitch_shift'][source_type]
|
63 |
+
|
64 |
+
if max_pitch_shift == 0: # No pitch shift augmentations.
|
65 |
+
return waveform
|
66 |
+
|
67 |
+
# random pitch shift
|
68 |
+
rand_pitch = self.random_state.uniform(
|
69 |
+
low=-max_pitch_shift, high=max_pitch_shift
|
70 |
+
)
|
71 |
+
|
72 |
+
# We use librosa.resample instead of librosa.effects.pitch_shift
|
73 |
+
# because it is 10x times faster.
|
74 |
+
pitch_shift_factor = get_pitch_shift_factor(rand_pitch)
|
75 |
+
dummy_sample_rate = 10000 # Dummy constant.
|
76 |
+
|
77 |
+
channels_num = waveform.shape[0]
|
78 |
+
|
79 |
+
if channels_num == 1:
|
80 |
+
waveform = np.squeeze(waveform)
|
81 |
+
|
82 |
+
new_waveform = librosa.resample(
|
83 |
+
y=waveform,
|
84 |
+
orig_sr=dummy_sample_rate,
|
85 |
+
target_sr=dummy_sample_rate / pitch_shift_factor,
|
86 |
+
res_type='linear',
|
87 |
+
axis=-1,
|
88 |
+
)
|
89 |
+
|
90 |
+
if channels_num == 1:
|
91 |
+
new_waveform = new_waveform[None, :]
|
92 |
+
|
93 |
+
return new_waveform
|
94 |
+
|
95 |
+
def magnitude_scale(self, waveform: np.array, source_type: str) -> np.array:
|
96 |
+
r"""Scale the magnitude of a waveform.
|
97 |
+
|
98 |
+
Args:
|
99 |
+
waveform: (channels_num, audio_samples)
|
100 |
+
source_type: str
|
101 |
+
|
102 |
+
Returns:
|
103 |
+
new_waveform: (channels_num, audio_samples)
|
104 |
+
"""
|
105 |
+
lower_db = self.augmentations['magnitude_scale'][source_type]['lower_db']
|
106 |
+
higher_db = self.augmentations['magnitude_scale'][source_type]['higher_db']
|
107 |
+
|
108 |
+
if lower_db == 0 and higher_db == 0: # No magnitude scale augmentation.
|
109 |
+
return waveform
|
110 |
+
|
111 |
+
# The magnitude (in dB) of the sample with the maximum value.
|
112 |
+
waveform_db = magnitude_to_db(np.max(np.abs(waveform)))
|
113 |
+
|
114 |
+
new_waveform_db = self.random_state.uniform(
|
115 |
+
waveform_db + lower_db, min(waveform_db + higher_db, 0)
|
116 |
+
)
|
117 |
+
|
118 |
+
relative_db = new_waveform_db - waveform_db
|
119 |
+
|
120 |
+
relative_scale = db_to_magnitude(relative_db)
|
121 |
+
|
122 |
+
new_waveform = waveform * relative_scale
|
123 |
+
|
124 |
+
return new_waveform
|
125 |
+
|
126 |
+
def swap_channel(self, waveform: np.array, source_type: str) -> np.array:
|
127 |
+
r"""Randomly swap channels.
|
128 |
+
|
129 |
+
Args:
|
130 |
+
waveform: (channels_num, audio_samples)
|
131 |
+
source_type: str
|
132 |
+
|
133 |
+
Returns:
|
134 |
+
new_waveform: (channels_num, audio_samples)
|
135 |
+
"""
|
136 |
+
ndim = waveform.shape[0]
|
137 |
+
|
138 |
+
if ndim == 1:
|
139 |
+
return waveform
|
140 |
+
else:
|
141 |
+
random_axes = self.random_state.permutation(ndim)
|
142 |
+
return waveform[random_axes, :]
|
143 |
+
|
144 |
+
def flip_axis(self, waveform: np.array, source_type: str) -> np.array:
|
145 |
+
r"""Randomly flip the waveform along x-axis.
|
146 |
+
|
147 |
+
Args:
|
148 |
+
waveform: (channels_num, audio_samples)
|
149 |
+
source_type: str
|
150 |
+
|
151 |
+
Returns:
|
152 |
+
new_waveform: (channels_num, audio_samples)
|
153 |
+
"""
|
154 |
+
ndim = waveform.shape[0]
|
155 |
+
random_values = self.random_state.choice([-1, 1], size=ndim)
|
156 |
+
|
157 |
+
return waveform * random_values[:, None]
|
bytesep/data/batch_data_preprocessors.py
ADDED
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict, List
|
2 |
+
|
3 |
+
import torch
|
4 |
+
|
5 |
+
|
6 |
+
class BasicBatchDataPreprocessor:
|
7 |
+
def __init__(self, target_source_types: List[str]):
|
8 |
+
r"""Batch data preprocessor. Used for preparing mixtures and targets for
|
9 |
+
training. If there are multiple target source types, the waveforms of
|
10 |
+
those sources will be stacked along the channel dimension.
|
11 |
+
|
12 |
+
Args:
|
13 |
+
target_source_types: List[str], e.g., ['vocals', 'bass', ...]
|
14 |
+
"""
|
15 |
+
self.target_source_types = target_source_types
|
16 |
+
|
17 |
+
def __call__(self, batch_data_dict: Dict) -> List[Dict]:
|
18 |
+
r"""Format waveforms and targets for training.
|
19 |
+
|
20 |
+
Args:
|
21 |
+
batch_data_dict: dict, e.g., {
|
22 |
+
'mixture': (batch_size, channels_num, segment_samples),
|
23 |
+
'vocals': (batch_size, channels_num, segment_samples),
|
24 |
+
'bass': (batch_size, channels_num, segment_samples),
|
25 |
+
...,
|
26 |
+
}
|
27 |
+
|
28 |
+
Returns:
|
29 |
+
input_dict: dict, e.g., {
|
30 |
+
'waveform': (batch_size, channels_num, segment_samples),
|
31 |
+
}
|
32 |
+
output_dict: dict, e.g., {
|
33 |
+
'target': (batch_size, target_sources_num * channels_num, segment_samples)
|
34 |
+
}
|
35 |
+
"""
|
36 |
+
mixtures = batch_data_dict['mixture']
|
37 |
+
# mixtures: (batch_size, channels_num, segment_samples)
|
38 |
+
|
39 |
+
# Concatenate waveforms of multiple targets along the channel axis.
|
40 |
+
targets = torch.cat(
|
41 |
+
[batch_data_dict[source_type] for source_type in self.target_source_types],
|
42 |
+
dim=1,
|
43 |
+
)
|
44 |
+
# targets: (batch_size, target_sources_num * channels_num, segment_samples)
|
45 |
+
|
46 |
+
input_dict = {'waveform': mixtures}
|
47 |
+
target_dict = {'waveform': targets}
|
48 |
+
|
49 |
+
return input_dict, target_dict
|
50 |
+
|
51 |
+
|
52 |
+
class ConditionalSisoBatchDataPreprocessor:
|
53 |
+
def __init__(self, target_source_types: List[str]):
|
54 |
+
r"""Conditional single input single output (SISO) batch data
|
55 |
+
preprocessor. Select one target source from several target sources as
|
56 |
+
training target and prepare the corresponding conditional vector.
|
57 |
+
|
58 |
+
Args:
|
59 |
+
target_source_types: List[str], e.g., ['vocals', 'bass', ...]
|
60 |
+
"""
|
61 |
+
self.target_source_types = target_source_types
|
62 |
+
|
63 |
+
def __call__(self, batch_data_dict: Dict) -> List[Dict]:
|
64 |
+
r"""Format waveforms and targets for training.
|
65 |
+
|
66 |
+
Args:
|
67 |
+
batch_data_dict: dict, e.g., {
|
68 |
+
'mixture': (batch_size, channels_num, segment_samples),
|
69 |
+
'vocals': (batch_size, channels_num, segment_samples),
|
70 |
+
'bass': (batch_size, channels_num, segment_samples),
|
71 |
+
...,
|
72 |
+
}
|
73 |
+
|
74 |
+
Returns:
|
75 |
+
input_dict: dict, e.g., {
|
76 |
+
'waveform': (batch_size, channels_num, segment_samples),
|
77 |
+
'condition': (batch_size, target_sources_num),
|
78 |
+
}
|
79 |
+
output_dict: dict, e.g., {
|
80 |
+
'target': (batch_size, channels_num, segment_samples)
|
81 |
+
}
|
82 |
+
"""
|
83 |
+
|
84 |
+
batch_size = len(batch_data_dict['mixture'])
|
85 |
+
target_sources_num = len(self.target_source_types)
|
86 |
+
|
87 |
+
assert (
|
88 |
+
batch_size % target_sources_num == 0
|
89 |
+
), "Batch size should be \
|
90 |
+
evenly divided by target sources number."
|
91 |
+
|
92 |
+
mixtures = batch_data_dict['mixture']
|
93 |
+
# mixtures: (batch_size, channels_num, segment_samples)
|
94 |
+
|
95 |
+
conditions = torch.zeros(batch_size, target_sources_num).to(mixtures.device)
|
96 |
+
# conditions: (batch_size, target_sources_num)
|
97 |
+
|
98 |
+
targets = []
|
99 |
+
|
100 |
+
for n in range(batch_size):
|
101 |
+
|
102 |
+
k = n % target_sources_num # source class index
|
103 |
+
source_type = self.target_source_types[k]
|
104 |
+
|
105 |
+
targets.append(batch_data_dict[source_type][n])
|
106 |
+
|
107 |
+
conditions[n, k] = 1
|
108 |
+
|
109 |
+
# conditions will looks like:
|
110 |
+
# [[1, 0, 0, 0],
|
111 |
+
# [0, 1, 0, 0],
|
112 |
+
# [0, 0, 1, 0],
|
113 |
+
# [0, 0, 0, 1],
|
114 |
+
# [1, 0, 0, 0],
|
115 |
+
# [0, 1, 0, 0],
|
116 |
+
# ...,
|
117 |
+
# ]
|
118 |
+
|
119 |
+
targets = torch.stack(targets, dim=0)
|
120 |
+
# targets: (batch_size, channels_num, segment_samples)
|
121 |
+
|
122 |
+
input_dict = {
|
123 |
+
'waveform': mixtures,
|
124 |
+
'condition': conditions,
|
125 |
+
}
|
126 |
+
|
127 |
+
target_dict = {'waveform': targets}
|
128 |
+
|
129 |
+
return input_dict, target_dict
|
130 |
+
|
131 |
+
|
132 |
+
def get_batch_data_preprocessor_class(batch_data_preprocessor_type: str) -> object:
|
133 |
+
r"""Get batch data preprocessor class."""
|
134 |
+
if batch_data_preprocessor_type == 'BasicBatchDataPreprocessor':
|
135 |
+
return BasicBatchDataPreprocessor
|
136 |
+
|
137 |
+
elif batch_data_preprocessor_type == 'ConditionalSisoBatchDataPreprocessor':
|
138 |
+
return ConditionalSisoBatchDataPreprocessor
|
139 |
+
|
140 |
+
else:
|
141 |
+
raise NotImplementedError
|
bytesep/data/data_modules.py
ADDED
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict, List, NoReturn, Optional
|
2 |
+
|
3 |
+
import h5py
|
4 |
+
import librosa
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
from pytorch_lightning.core.datamodule import LightningDataModule
|
8 |
+
|
9 |
+
from bytesep.data.samplers import DistributedSamplerWrapper
|
10 |
+
from bytesep.utils import int16_to_float32
|
11 |
+
|
12 |
+
|
13 |
+
class DataModule(LightningDataModule):
|
14 |
+
def __init__(
|
15 |
+
self,
|
16 |
+
train_sampler: object,
|
17 |
+
train_dataset: object,
|
18 |
+
num_workers: int,
|
19 |
+
distributed: bool,
|
20 |
+
):
|
21 |
+
r"""Data module.
|
22 |
+
|
23 |
+
Args:
|
24 |
+
train_sampler: Sampler object
|
25 |
+
train_dataset: Dataset object
|
26 |
+
num_workers: int
|
27 |
+
distributed: bool
|
28 |
+
"""
|
29 |
+
super().__init__()
|
30 |
+
self._train_sampler = train_sampler
|
31 |
+
self.train_dataset = train_dataset
|
32 |
+
self.num_workers = num_workers
|
33 |
+
self.distributed = distributed
|
34 |
+
|
35 |
+
def setup(self, stage: Optional[str] = None) -> NoReturn:
|
36 |
+
r"""called on every device."""
|
37 |
+
|
38 |
+
# SegmentSampler is used for selecting segments for training.
|
39 |
+
# On multiple devices, each SegmentSampler samples a part of mini-batch
|
40 |
+
# data.
|
41 |
+
if self.distributed:
|
42 |
+
self.train_sampler = DistributedSamplerWrapper(self._train_sampler)
|
43 |
+
|
44 |
+
else:
|
45 |
+
self.train_sampler = self._train_sampler
|
46 |
+
|
47 |
+
def train_dataloader(self) -> torch.utils.data.DataLoader:
|
48 |
+
r"""Get train loader."""
|
49 |
+
train_loader = torch.utils.data.DataLoader(
|
50 |
+
dataset=self.train_dataset,
|
51 |
+
batch_sampler=self.train_sampler,
|
52 |
+
collate_fn=collate_fn,
|
53 |
+
num_workers=self.num_workers,
|
54 |
+
pin_memory=True,
|
55 |
+
)
|
56 |
+
|
57 |
+
return train_loader
|
58 |
+
|
59 |
+
|
60 |
+
class Dataset:
|
61 |
+
def __init__(self, augmentor: object, segment_samples: int):
|
62 |
+
r"""Used for getting data according to a meta.
|
63 |
+
|
64 |
+
Args:
|
65 |
+
augmentor: Augmentor class
|
66 |
+
segment_samples: int
|
67 |
+
"""
|
68 |
+
self.augmentor = augmentor
|
69 |
+
self.segment_samples = segment_samples
|
70 |
+
|
71 |
+
def __getitem__(self, meta: Dict) -> Dict:
|
72 |
+
r"""Return data according to a meta. E.g., an input meta looks like: {
|
73 |
+
'vocals': [['song_A.h5', 6332760, 6465060], ['song_B.h5', 198450, 330750]],
|
74 |
+
'accompaniment': [['song_C.h5', 24232920, 24365250], ['song_D.h5', 1569960, 1702260]]}.
|
75 |
+
}
|
76 |
+
|
77 |
+
Then, vocals segments of song_A and song_B will be mixed (mix-audio augmentation).
|
78 |
+
Accompaniment segments of song_C and song_B will be mixed (mix-audio augmentation).
|
79 |
+
Finally, mixture is created by summing vocals and accompaniment.
|
80 |
+
|
81 |
+
Args:
|
82 |
+
meta: dict, e.g., {
|
83 |
+
'vocals': [['song_A.h5', 6332760, 6465060], ['song_B.h5', 198450, 330750]],
|
84 |
+
'accompaniment': [['song_C.h5', 24232920, 24365250], ['song_D.h5', 1569960, 1702260]]}
|
85 |
+
}
|
86 |
+
|
87 |
+
Returns:
|
88 |
+
data_dict: dict, e.g., {
|
89 |
+
'vocals': (channels, segments_num),
|
90 |
+
'accompaniment': (channels, segments_num),
|
91 |
+
'mixture': (channels, segments_num),
|
92 |
+
}
|
93 |
+
"""
|
94 |
+
source_types = meta.keys()
|
95 |
+
data_dict = {}
|
96 |
+
|
97 |
+
for source_type in source_types:
|
98 |
+
# E.g., ['vocals', 'bass', ...]
|
99 |
+
|
100 |
+
waveforms = [] # Audio segments to be mix-audio augmented.
|
101 |
+
|
102 |
+
for m in meta[source_type]:
|
103 |
+
# E.g., {
|
104 |
+
# 'hdf5_path': '.../song_A.h5',
|
105 |
+
# 'key_in_hdf5': 'vocals',
|
106 |
+
# 'begin_sample': '13406400',
|
107 |
+
# 'end_sample': 13538700,
|
108 |
+
# }
|
109 |
+
|
110 |
+
hdf5_path = m['hdf5_path']
|
111 |
+
key_in_hdf5 = m['key_in_hdf5']
|
112 |
+
bgn_sample = m['begin_sample']
|
113 |
+
end_sample = m['end_sample']
|
114 |
+
|
115 |
+
with h5py.File(hdf5_path, 'r') as hf:
|
116 |
+
|
117 |
+
if source_type == 'audioset':
|
118 |
+
index_in_hdf5 = m['index_in_hdf5']
|
119 |
+
waveform = int16_to_float32(
|
120 |
+
hf['waveform'][index_in_hdf5][bgn_sample:end_sample]
|
121 |
+
)
|
122 |
+
waveform = waveform[None, :]
|
123 |
+
else:
|
124 |
+
waveform = int16_to_float32(
|
125 |
+
hf[key_in_hdf5][:, bgn_sample:end_sample]
|
126 |
+
)
|
127 |
+
|
128 |
+
if self.augmentor:
|
129 |
+
waveform = self.augmentor(waveform, source_type)
|
130 |
+
|
131 |
+
waveform = librosa.util.fix_length(
|
132 |
+
waveform, size=self.segment_samples, axis=1
|
133 |
+
)
|
134 |
+
# (channels_num, segments_num)
|
135 |
+
|
136 |
+
waveforms.append(waveform)
|
137 |
+
# E.g., waveforms: [(channels_num, audio_samples), (channels_num, audio_samples)]
|
138 |
+
|
139 |
+
# mix-audio augmentation
|
140 |
+
data_dict[source_type] = np.sum(waveforms, axis=0)
|
141 |
+
# data_dict[source_type]: (channels_num, audio_samples)
|
142 |
+
|
143 |
+
# data_dict looks like: {
|
144 |
+
# 'voclas': (channels_num, audio_samples),
|
145 |
+
# 'accompaniment': (channels_num, audio_samples)
|
146 |
+
# }
|
147 |
+
|
148 |
+
# Mix segments from different sources.
|
149 |
+
mixture = np.sum(
|
150 |
+
[data_dict[source_type] for source_type in source_types], axis=0
|
151 |
+
)
|
152 |
+
data_dict['mixture'] = mixture
|
153 |
+
# shape: (channels_num, audio_samples)
|
154 |
+
|
155 |
+
return data_dict
|
156 |
+
|
157 |
+
|
158 |
+
def collate_fn(list_data_dict: List[Dict]) -> Dict:
|
159 |
+
r"""Collate mini-batch data to inputs and targets for training.
|
160 |
+
|
161 |
+
Args:
|
162 |
+
list_data_dict: e.g., [
|
163 |
+
{'vocals': (channels_num, segment_samples),
|
164 |
+
'accompaniment': (channels_num, segment_samples),
|
165 |
+
'mixture': (channels_num, segment_samples)
|
166 |
+
},
|
167 |
+
{'vocals': (channels_num, segment_samples),
|
168 |
+
'accompaniment': (channels_num, segment_samples),
|
169 |
+
'mixture': (channels_num, segment_samples)
|
170 |
+
},
|
171 |
+
...]
|
172 |
+
|
173 |
+
Returns:
|
174 |
+
data_dict: e.g. {
|
175 |
+
'vocals': (batch_size, channels_num, segment_samples),
|
176 |
+
'accompaniment': (batch_size, channels_num, segment_samples),
|
177 |
+
'mixture': (batch_size, channels_num, segment_samples)
|
178 |
+
}
|
179 |
+
"""
|
180 |
+
data_dict = {}
|
181 |
+
|
182 |
+
for key in list_data_dict[0].keys():
|
183 |
+
data_dict[key] = torch.Tensor(
|
184 |
+
np.array([data_dict[key] for data_dict in list_data_dict])
|
185 |
+
)
|
186 |
+
|
187 |
+
return data_dict
|
bytesep/data/samplers.py
ADDED
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pickle
|
2 |
+
from typing import Dict, List, NoReturn
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import torch.distributed as dist
|
6 |
+
|
7 |
+
|
8 |
+
class SegmentSampler:
|
9 |
+
def __init__(
|
10 |
+
self,
|
11 |
+
indexes_path: str,
|
12 |
+
segment_samples: int,
|
13 |
+
mixaudio_dict: Dict,
|
14 |
+
batch_size: int,
|
15 |
+
steps_per_epoch: int,
|
16 |
+
random_seed=1234,
|
17 |
+
):
|
18 |
+
r"""Sample training indexes of sources.
|
19 |
+
|
20 |
+
Args:
|
21 |
+
indexes_path: str, path of indexes dict
|
22 |
+
segment_samplers: int
|
23 |
+
mixaudio_dict, dict, including hyper-parameters for mix-audio data
|
24 |
+
augmentation, e.g., {'voclas': 2, 'accompaniment': 2}
|
25 |
+
batch_size: int
|
26 |
+
steps_per_epoch: int, #steps_per_epoch is called an `epoch`
|
27 |
+
random_seed: int
|
28 |
+
"""
|
29 |
+
self.segment_samples = segment_samples
|
30 |
+
self.mixaudio_dict = mixaudio_dict
|
31 |
+
self.batch_size = batch_size
|
32 |
+
self.steps_per_epoch = steps_per_epoch
|
33 |
+
|
34 |
+
self.meta_dict = pickle.load(open(indexes_path, "rb"))
|
35 |
+
# E.g., {
|
36 |
+
# 'vocals': [
|
37 |
+
# {'hdf5_path': 'songA.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 0, 'end_sample': 132300},
|
38 |
+
# {'hdf5_path': 'songB.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 4410, 'end_sample': 445410},
|
39 |
+
# ...
|
40 |
+
# ],
|
41 |
+
# 'accompaniment': [
|
42 |
+
# {'hdf5_path': 'songA.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 0, 'end_sample': 132300},
|
43 |
+
# {'hdf5_path': 'songB.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 4410, 'end_sample': 445410},
|
44 |
+
# ...
|
45 |
+
# ]
|
46 |
+
# }
|
47 |
+
|
48 |
+
self.source_types = self.meta_dict.keys()
|
49 |
+
# E.g., ['vocals', 'accompaniment']
|
50 |
+
|
51 |
+
self.pointers_dict = {source_type: 0 for source_type in self.source_types}
|
52 |
+
# E.g., {'vocals': 0, 'accompaniment': 0}
|
53 |
+
|
54 |
+
self.indexes_dict = {
|
55 |
+
source_type: np.arange(len(self.meta_dict[source_type]))
|
56 |
+
for source_type in self.source_types
|
57 |
+
}
|
58 |
+
# E.g. {
|
59 |
+
# 'vocals': [0, 1, ..., 225751],
|
60 |
+
# 'accompaniment': [0, 1, ..., 225751]
|
61 |
+
# }
|
62 |
+
|
63 |
+
self.random_state = np.random.RandomState(random_seed)
|
64 |
+
|
65 |
+
# Shuffle indexes.
|
66 |
+
for source_type in self.source_types:
|
67 |
+
self.random_state.shuffle(self.indexes_dict[source_type])
|
68 |
+
print("{}: {}".format(source_type, len(self.indexes_dict[source_type])))
|
69 |
+
|
70 |
+
def __iter__(self) -> List[Dict]:
|
71 |
+
r"""Yield a batch of meta info.
|
72 |
+
|
73 |
+
Returns:
|
74 |
+
batch_meta_list: (batch_size,) e.g., when mix-audio is 2, looks like [
|
75 |
+
{'vocals': [
|
76 |
+
{'hdf5_path': 'songA.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 13406400, 'end_sample': 13538700},
|
77 |
+
{'hdf5_path': 'songB.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 4440870, 'end_sample': 4573170}]
|
78 |
+
'accompaniment': [
|
79 |
+
{'hdf5_path': 'songE.h5', 'key_in_hdf5': 'accompaniment', 'begin_sample': 14579460, 'end_sample': 14711760},
|
80 |
+
{'hdf5_path': 'songF.h5', 'key_in_hdf5': 'accompaniment', 'begin_sample': 3995460, 'end_sample': 4127760}]
|
81 |
+
}
|
82 |
+
...
|
83 |
+
]
|
84 |
+
"""
|
85 |
+
batch_size = self.batch_size
|
86 |
+
|
87 |
+
while True:
|
88 |
+
batch_meta_dict = {source_type: [] for source_type in self.source_types}
|
89 |
+
|
90 |
+
for source_type in self.source_types:
|
91 |
+
# E.g., ['vocals', 'accompaniment']
|
92 |
+
|
93 |
+
# Loop until get a mini-batch.
|
94 |
+
while len(batch_meta_dict[source_type]) != batch_size:
|
95 |
+
|
96 |
+
largest_index = (
|
97 |
+
len(self.indexes_dict[source_type])
|
98 |
+
- self.mixaudio_dict[source_type]
|
99 |
+
)
|
100 |
+
# E.g., 225750 = 225752 - 2
|
101 |
+
|
102 |
+
if self.pointers_dict[source_type] > largest_index:
|
103 |
+
|
104 |
+
# Reset pointer, and shuffle indexes.
|
105 |
+
self.pointers_dict[source_type] = 0
|
106 |
+
self.random_state.shuffle(self.indexes_dict[source_type])
|
107 |
+
|
108 |
+
source_metas = []
|
109 |
+
mix_audios_num = self.mixaudio_dict[source_type]
|
110 |
+
|
111 |
+
for _ in range(mix_audios_num):
|
112 |
+
|
113 |
+
pointer = self.pointers_dict[source_type]
|
114 |
+
# E.g., 1
|
115 |
+
|
116 |
+
index = self.indexes_dict[source_type][pointer]
|
117 |
+
# E.g., 12231
|
118 |
+
|
119 |
+
self.pointers_dict[source_type] += 1
|
120 |
+
|
121 |
+
source_meta = self.meta_dict[source_type][index]
|
122 |
+
# E.g., ['song_A.h5', 198450, 330750]
|
123 |
+
|
124 |
+
# source_metas.append(new_source_meta)
|
125 |
+
source_metas.append(source_meta)
|
126 |
+
|
127 |
+
batch_meta_dict[source_type].append(source_metas)
|
128 |
+
# When mix-audio is 2, batch_meta_dict looks like: {
|
129 |
+
# 'vocals': [
|
130 |
+
# [{'hdf5_path': 'songA.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 13406400, 'end_sample': 13538700},
|
131 |
+
# {'hdf5_path': 'songB.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 4440870, 'end_sample': 4573170}],
|
132 |
+
# [{'hdf5_path': 'songC.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 1186290, 'end_sample': 1318590},
|
133 |
+
# {'hdf5_path': 'songD.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 8462790, 'end_sample': 8595090}]
|
134 |
+
# ]
|
135 |
+
# 'accompaniment': [
|
136 |
+
# [{'hdf5_path': 'songE.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 24232950, 'end_sample': 24365250},
|
137 |
+
# {'hdf5_path': 'songF.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 1569960, 'end_sample': 1702260}],
|
138 |
+
# [{'hdf5_path': 'songG.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 2795940, 'end_sample': 2928240},
|
139 |
+
# {'hdf5_path': 'songH.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 10923570, 'end_sample': 11055870}]
|
140 |
+
# ]
|
141 |
+
# }
|
142 |
+
|
143 |
+
batch_meta_list = [
|
144 |
+
{
|
145 |
+
source_type: batch_meta_dict[source_type][i]
|
146 |
+
for source_type in self.source_types
|
147 |
+
}
|
148 |
+
for i in range(batch_size)
|
149 |
+
]
|
150 |
+
# When mix-audio is 2, batch_meta_list looks like: [
|
151 |
+
# {'vocals': [
|
152 |
+
# {'hdf5_path': 'songA.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 13406400, 'end_sample': 13538700},
|
153 |
+
# {'hdf5_path': 'songB.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 4440870, 'end_sample': 4573170}]
|
154 |
+
# 'accompaniment': [
|
155 |
+
# {'hdf5_path': 'songE.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 14579460, 'end_sample': 14711760},
|
156 |
+
# {'hdf5_path': 'songF.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 3995460, 'end_sample': 4127760}]
|
157 |
+
# }
|
158 |
+
# ...
|
159 |
+
# ]
|
160 |
+
|
161 |
+
yield batch_meta_list
|
162 |
+
|
163 |
+
def __len__(self) -> int:
|
164 |
+
return self.steps_per_epoch
|
165 |
+
|
166 |
+
def state_dict(self) -> Dict:
|
167 |
+
state = {'pointers_dict': self.pointers_dict, 'indexes_dict': self.indexes_dict}
|
168 |
+
return state
|
169 |
+
|
170 |
+
def load_state_dict(self, state) -> NoReturn:
|
171 |
+
self.pointers_dict = state['pointers_dict']
|
172 |
+
self.indexes_dict = state['indexes_dict']
|
173 |
+
|
174 |
+
|
175 |
+
class DistributedSamplerWrapper:
|
176 |
+
def __init__(self, sampler):
|
177 |
+
r"""Distributed wrapper of sampler."""
|
178 |
+
self.sampler = sampler
|
179 |
+
|
180 |
+
def __iter__(self):
|
181 |
+
num_replicas = dist.get_world_size()
|
182 |
+
rank = dist.get_rank()
|
183 |
+
|
184 |
+
for indices in self.sampler:
|
185 |
+
yield indices[rank::num_replicas]
|
186 |
+
|
187 |
+
def __len__(self) -> int:
|
188 |
+
return len(self.sampler)
|
bytesep/dataset_creation/__init__.py
ADDED
File without changes
|
bytesep/dataset_creation/create_evaluation_audios/__init__.py
ADDED
File without changes
|
bytesep/dataset_creation/create_evaluation_audios/piano-symphony.py
ADDED
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
from typing import NoReturn
|
4 |
+
|
5 |
+
import librosa
|
6 |
+
import numpy as np
|
7 |
+
import soundfile
|
8 |
+
|
9 |
+
from bytesep.dataset_creation.pack_audios_to_hdf5s.instruments_solo import (
|
10 |
+
read_csv as read_instruments_solo_csv,
|
11 |
+
)
|
12 |
+
from bytesep.dataset_creation.pack_audios_to_hdf5s.maestro import (
|
13 |
+
read_csv as read_maestro_csv,
|
14 |
+
)
|
15 |
+
from bytesep.utils import load_random_segment
|
16 |
+
|
17 |
+
|
18 |
+
def create_evaluation(args) -> NoReturn:
|
19 |
+
r"""Random mix and write out audios for evaluation.
|
20 |
+
|
21 |
+
Args:
|
22 |
+
piano_dataset_dir: str, the directory of the piano dataset
|
23 |
+
symphony_dataset_dir: str, the directory of the symphony dataset
|
24 |
+
evaluation_audios_dir: str, the directory to write out randomly selected and mixed audio segments
|
25 |
+
sample_rate: int
|
26 |
+
channels: int, e.g., 1 | 2
|
27 |
+
evaluation_segments_num: int
|
28 |
+
mono: bool
|
29 |
+
|
30 |
+
Returns:
|
31 |
+
NoReturn
|
32 |
+
"""
|
33 |
+
|
34 |
+
# arguments & parameters
|
35 |
+
piano_dataset_dir = args.piano_dataset_dir
|
36 |
+
symphony_dataset_dir = args.symphony_dataset_dir
|
37 |
+
evaluation_audios_dir = args.evaluation_audios_dir
|
38 |
+
sample_rate = args.sample_rate
|
39 |
+
channels = args.channels
|
40 |
+
evaluation_segments_num = args.evaluation_segments_num
|
41 |
+
mono = True if channels == 1 else False
|
42 |
+
|
43 |
+
split = 'test'
|
44 |
+
segment_seconds = 10.0
|
45 |
+
|
46 |
+
random_state = np.random.RandomState(1234)
|
47 |
+
|
48 |
+
piano_meta_csv = os.path.join(piano_dataset_dir, 'maestro-v2.0.0.csv')
|
49 |
+
piano_names_dict = read_maestro_csv(piano_meta_csv)
|
50 |
+
piano_audio_names = piano_names_dict[split]
|
51 |
+
|
52 |
+
symphony_meta_csv = os.path.join(symphony_dataset_dir, 'validation.csv')
|
53 |
+
symphony_names_dict = read_instruments_solo_csv(symphony_meta_csv)
|
54 |
+
symphony_audio_names = symphony_names_dict[split]
|
55 |
+
|
56 |
+
for source_type in ['piano', 'symphony', 'mixture']:
|
57 |
+
output_dir = os.path.join(evaluation_audios_dir, split, source_type)
|
58 |
+
os.makedirs(output_dir, exist_ok=True)
|
59 |
+
|
60 |
+
for n in range(evaluation_segments_num):
|
61 |
+
|
62 |
+
print('{} / {}'.format(n, evaluation_segments_num))
|
63 |
+
|
64 |
+
# Randomly select and write out a clean piano segment.
|
65 |
+
piano_audio_name = random_state.choice(piano_audio_names)
|
66 |
+
piano_audio_path = os.path.join(piano_dataset_dir, piano_audio_name)
|
67 |
+
|
68 |
+
piano_audio = load_random_segment(
|
69 |
+
audio_path=piano_audio_path,
|
70 |
+
random_state=random_state,
|
71 |
+
segment_seconds=segment_seconds,
|
72 |
+
mono=mono,
|
73 |
+
sample_rate=sample_rate,
|
74 |
+
)
|
75 |
+
|
76 |
+
output_piano_path = os.path.join(
|
77 |
+
evaluation_audios_dir, split, 'piano', '{:04d}.wav'.format(n)
|
78 |
+
)
|
79 |
+
soundfile.write(
|
80 |
+
file=output_piano_path, data=piano_audio.T, samplerate=sample_rate
|
81 |
+
)
|
82 |
+
print("Write out to {}".format(output_piano_path))
|
83 |
+
|
84 |
+
# Randomly select and write out a clean symphony segment.
|
85 |
+
symphony_audio_name = random_state.choice(symphony_audio_names)
|
86 |
+
symphony_audio_path = os.path.join(
|
87 |
+
symphony_dataset_dir, "mp3s", symphony_audio_name
|
88 |
+
)
|
89 |
+
|
90 |
+
symphony_audio = load_random_segment(
|
91 |
+
audio_path=symphony_audio_path,
|
92 |
+
random_state=random_state,
|
93 |
+
segment_seconds=segment_seconds,
|
94 |
+
mono=mono,
|
95 |
+
sample_rate=sample_rate,
|
96 |
+
)
|
97 |
+
|
98 |
+
output_symphony_path = os.path.join(
|
99 |
+
evaluation_audios_dir, split, 'symphony', '{:04d}.wav'.format(n)
|
100 |
+
)
|
101 |
+
soundfile.write(
|
102 |
+
file=output_symphony_path, data=symphony_audio.T, samplerate=sample_rate
|
103 |
+
)
|
104 |
+
print("Write out to {}".format(output_symphony_path))
|
105 |
+
|
106 |
+
# Mix piano and symphony segments and write out a mixture segment.
|
107 |
+
mixture_audio = symphony_audio + piano_audio
|
108 |
+
output_mixture_path = os.path.join(
|
109 |
+
evaluation_audios_dir, split, 'mixture', '{:04d}.wav'.format(n)
|
110 |
+
)
|
111 |
+
soundfile.write(
|
112 |
+
file=output_mixture_path, data=mixture_audio.T, samplerate=sample_rate
|
113 |
+
)
|
114 |
+
print("Write out to {}".format(output_mixture_path))
|
115 |
+
|
116 |
+
|
117 |
+
if __name__ == "__main__":
|
118 |
+
parser = argparse.ArgumentParser()
|
119 |
+
|
120 |
+
parser.add_argument(
|
121 |
+
"--piano_dataset_dir",
|
122 |
+
type=str,
|
123 |
+
required=True,
|
124 |
+
help="The directory of the piano dataset.",
|
125 |
+
)
|
126 |
+
parser.add_argument(
|
127 |
+
"--symphony_dataset_dir",
|
128 |
+
type=str,
|
129 |
+
required=True,
|
130 |
+
help="The directory of the symphony dataset.",
|
131 |
+
)
|
132 |
+
parser.add_argument(
|
133 |
+
"--evaluation_audios_dir",
|
134 |
+
type=str,
|
135 |
+
required=True,
|
136 |
+
help="The directory to write out randomly selected and mixed audio segments.",
|
137 |
+
)
|
138 |
+
parser.add_argument(
|
139 |
+
"--sample_rate",
|
140 |
+
type=int,
|
141 |
+
required=True,
|
142 |
+
help="Sample rate.",
|
143 |
+
)
|
144 |
+
parser.add_argument(
|
145 |
+
"--channels",
|
146 |
+
type=int,
|
147 |
+
required=True,
|
148 |
+
help="Audio channels, e.g, 1 or 2.",
|
149 |
+
)
|
150 |
+
parser.add_argument(
|
151 |
+
"--evaluation_segments_num",
|
152 |
+
type=int,
|
153 |
+
required=True,
|
154 |
+
help="The number of segments to create for evaluation.",
|
155 |
+
)
|
156 |
+
|
157 |
+
# Parse arguments.
|
158 |
+
args = parser.parse_args()
|
159 |
+
|
160 |
+
create_evaluation(args)
|
bytesep/dataset_creation/create_evaluation_audios/vctk-musdb18.py
ADDED
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
import soundfile
|
4 |
+
from typing import NoReturn
|
5 |
+
|
6 |
+
import musdb
|
7 |
+
import numpy as np
|
8 |
+
|
9 |
+
from bytesep.utils import load_audio
|
10 |
+
|
11 |
+
|
12 |
+
def create_evaluation(args) -> NoReturn:
|
13 |
+
r"""Random mix and write out audios for evaluation.
|
14 |
+
|
15 |
+
Args:
|
16 |
+
vctk_dataset_dir: str, the directory of the VCTK dataset
|
17 |
+
symphony_dataset_dir: str, the directory of the symphony dataset
|
18 |
+
evaluation_audios_dir: str, the directory to write out randomly selected and mixed audio segments
|
19 |
+
sample_rate: int
|
20 |
+
channels: int, e.g., 1 | 2
|
21 |
+
evaluation_segments_num: int
|
22 |
+
mono: bool
|
23 |
+
|
24 |
+
Returns:
|
25 |
+
NoReturn
|
26 |
+
"""
|
27 |
+
|
28 |
+
# arguments & parameters
|
29 |
+
vctk_dataset_dir = args.vctk_dataset_dir
|
30 |
+
musdb18_dataset_dir = args.musdb18_dataset_dir
|
31 |
+
evaluation_audios_dir = args.evaluation_audios_dir
|
32 |
+
sample_rate = args.sample_rate
|
33 |
+
channels = args.channels
|
34 |
+
evaluation_segments_num = args.evaluation_segments_num
|
35 |
+
mono = True if channels == 1 else False
|
36 |
+
|
37 |
+
split = 'test'
|
38 |
+
random_state = np.random.RandomState(1234)
|
39 |
+
|
40 |
+
# paths
|
41 |
+
audios_dir = os.path.join(vctk_dataset_dir, "wav48", split)
|
42 |
+
|
43 |
+
for source_type in ['speech', 'music', 'mixture']:
|
44 |
+
output_dir = os.path.join(evaluation_audios_dir, split, source_type)
|
45 |
+
os.makedirs(output_dir, exist_ok=True)
|
46 |
+
|
47 |
+
# Get VCTK audio paths.
|
48 |
+
speech_audio_paths = []
|
49 |
+
speaker_ids = sorted(os.listdir(audios_dir))
|
50 |
+
|
51 |
+
for speaker_id in speaker_ids:
|
52 |
+
speaker_audios_dir = os.path.join(audios_dir, speaker_id)
|
53 |
+
|
54 |
+
audio_names = sorted(os.listdir(speaker_audios_dir))
|
55 |
+
|
56 |
+
for audio_name in audio_names:
|
57 |
+
speaker_audio_path = os.path.join(speaker_audios_dir, audio_name)
|
58 |
+
speech_audio_paths.append(speaker_audio_path)
|
59 |
+
|
60 |
+
# Get Musdb18 audio paths.
|
61 |
+
mus = musdb.DB(root=musdb18_dataset_dir, subsets=[split])
|
62 |
+
track_indexes = np.arange(len(mus.tracks))
|
63 |
+
|
64 |
+
for n in range(evaluation_segments_num):
|
65 |
+
|
66 |
+
print('{} / {}'.format(n, evaluation_segments_num))
|
67 |
+
|
68 |
+
# Randomly select and write out a clean speech segment.
|
69 |
+
speech_audio_path = random_state.choice(speech_audio_paths)
|
70 |
+
|
71 |
+
speech_audio = load_audio(
|
72 |
+
audio_path=speech_audio_path, mono=mono, sample_rate=sample_rate
|
73 |
+
)
|
74 |
+
# (channels_num, audio_samples)
|
75 |
+
|
76 |
+
if channels == 2:
|
77 |
+
speech_audio = np.tile(speech_audio, (2, 1))
|
78 |
+
# (channels_num, audio_samples)
|
79 |
+
|
80 |
+
output_speech_path = os.path.join(
|
81 |
+
evaluation_audios_dir, split, 'speech', '{:04d}.wav'.format(n)
|
82 |
+
)
|
83 |
+
soundfile.write(
|
84 |
+
file=output_speech_path, data=speech_audio.T, samplerate=sample_rate
|
85 |
+
)
|
86 |
+
print("Write out to {}".format(output_speech_path))
|
87 |
+
|
88 |
+
# Randomly select and write out a clean music segment.
|
89 |
+
track_index = random_state.choice(track_indexes)
|
90 |
+
track = mus[track_index]
|
91 |
+
|
92 |
+
segment_samples = speech_audio.shape[1]
|
93 |
+
start_sample = int(
|
94 |
+
random_state.uniform(0.0, segment_samples - speech_audio.shape[1])
|
95 |
+
)
|
96 |
+
|
97 |
+
music_audio = track.audio[start_sample : start_sample + segment_samples, :].T
|
98 |
+
# (channels_num, audio_samples)
|
99 |
+
|
100 |
+
output_music_path = os.path.join(
|
101 |
+
evaluation_audios_dir, split, 'music', '{:04d}.wav'.format(n)
|
102 |
+
)
|
103 |
+
soundfile.write(
|
104 |
+
file=output_music_path, data=music_audio.T, samplerate=sample_rate
|
105 |
+
)
|
106 |
+
print("Write out to {}".format(output_music_path))
|
107 |
+
|
108 |
+
# Mix speech and music segments and write out a mixture segment.
|
109 |
+
mixture_audio = speech_audio + music_audio
|
110 |
+
# (channels_num, audio_samples)
|
111 |
+
|
112 |
+
output_mixture_path = os.path.join(
|
113 |
+
evaluation_audios_dir, split, 'mixture', '{:04d}.wav'.format(n)
|
114 |
+
)
|
115 |
+
soundfile.write(
|
116 |
+
file=output_mixture_path, data=mixture_audio.T, samplerate=sample_rate
|
117 |
+
)
|
118 |
+
print("Write out to {}".format(output_mixture_path))
|
119 |
+
|
120 |
+
|
121 |
+
if __name__ == "__main__":
|
122 |
+
parser = argparse.ArgumentParser()
|
123 |
+
|
124 |
+
parser.add_argument(
|
125 |
+
"--vctk_dataset_dir",
|
126 |
+
type=str,
|
127 |
+
required=True,
|
128 |
+
help="The directory of the VCTK dataset.",
|
129 |
+
)
|
130 |
+
parser.add_argument(
|
131 |
+
"--musdb18_dataset_dir",
|
132 |
+
type=str,
|
133 |
+
required=True,
|
134 |
+
help="The directory of the MUSDB18 dataset.",
|
135 |
+
)
|
136 |
+
parser.add_argument(
|
137 |
+
"--evaluation_audios_dir",
|
138 |
+
type=str,
|
139 |
+
required=True,
|
140 |
+
help="The directory to write out randomly selected and mixed audio segments.",
|
141 |
+
)
|
142 |
+
parser.add_argument(
|
143 |
+
"--sample_rate",
|
144 |
+
type=int,
|
145 |
+
required=True,
|
146 |
+
help="Sample rate",
|
147 |
+
)
|
148 |
+
parser.add_argument(
|
149 |
+
"--channels",
|
150 |
+
type=int,
|
151 |
+
required=True,
|
152 |
+
help="Audio channels, e.g, 1 or 2.",
|
153 |
+
)
|
154 |
+
parser.add_argument(
|
155 |
+
"--evaluation_segments_num",
|
156 |
+
type=int,
|
157 |
+
required=True,
|
158 |
+
help="The number of segments to create for evaluation.",
|
159 |
+
)
|
160 |
+
|
161 |
+
# Parse arguments.
|
162 |
+
args = parser.parse_args()
|
163 |
+
|
164 |
+
create_evaluation(args)
|
bytesep/dataset_creation/create_evaluation_audios/violin-piano.py
ADDED
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
from typing import NoReturn
|
4 |
+
|
5 |
+
import librosa
|
6 |
+
import numpy as np
|
7 |
+
import soundfile
|
8 |
+
|
9 |
+
from bytesep.dataset_creation.pack_audios_to_hdf5s.instruments_solo import (
|
10 |
+
read_csv as read_instruments_solo_csv,
|
11 |
+
)
|
12 |
+
from bytesep.dataset_creation.pack_audios_to_hdf5s.maestro import (
|
13 |
+
read_csv as read_maestro_csv,
|
14 |
+
)
|
15 |
+
from bytesep.utils import load_random_segment
|
16 |
+
|
17 |
+
|
18 |
+
def create_evaluation(args) -> NoReturn:
|
19 |
+
r"""Random mix and write out audios for evaluation.
|
20 |
+
|
21 |
+
Args:
|
22 |
+
violin_dataset_dir: str, the directory of the violin dataset
|
23 |
+
piano_dataset_dir: str, the directory of the piano dataset
|
24 |
+
evaluation_audios_dir: str, the directory to write out randomly selected and mixed audio segments
|
25 |
+
sample_rate: int
|
26 |
+
channels: int, e.g., 1 | 2
|
27 |
+
evaluation_segments_num: int
|
28 |
+
mono: bool
|
29 |
+
|
30 |
+
Returns:
|
31 |
+
NoReturn
|
32 |
+
"""
|
33 |
+
|
34 |
+
# arguments & parameters
|
35 |
+
violin_dataset_dir = args.violin_dataset_dir
|
36 |
+
piano_dataset_dir = args.piano_dataset_dir
|
37 |
+
evaluation_audios_dir = args.evaluation_audios_dir
|
38 |
+
sample_rate = args.sample_rate
|
39 |
+
channels = args.channels
|
40 |
+
evaluation_segments_num = args.evaluation_segments_num
|
41 |
+
mono = True if channels == 1 else False
|
42 |
+
|
43 |
+
split = 'test'
|
44 |
+
segment_seconds = 10.0
|
45 |
+
|
46 |
+
random_state = np.random.RandomState(1234)
|
47 |
+
|
48 |
+
violin_meta_csv = os.path.join(violin_dataset_dir, 'validation.csv')
|
49 |
+
violin_names_dict = read_instruments_solo_csv(violin_meta_csv)
|
50 |
+
violin_audio_names = violin_names_dict['{}'.format(split)]
|
51 |
+
|
52 |
+
piano_meta_csv = os.path.join(piano_dataset_dir, 'maestro-v2.0.0.csv')
|
53 |
+
piano_names_dict = read_maestro_csv(piano_meta_csv)
|
54 |
+
piano_audio_names = piano_names_dict['{}'.format(split)]
|
55 |
+
|
56 |
+
for source_type in ['violin', 'piano', 'mixture']:
|
57 |
+
output_dir = os.path.join(evaluation_audios_dir, split, source_type)
|
58 |
+
os.makedirs(output_dir, exist_ok=True)
|
59 |
+
|
60 |
+
for n in range(evaluation_segments_num):
|
61 |
+
|
62 |
+
print('{} / {}'.format(n, evaluation_segments_num))
|
63 |
+
|
64 |
+
# Randomly select and write out a clean violin segment.
|
65 |
+
violin_audio_name = random_state.choice(violin_audio_names)
|
66 |
+
violin_audio_path = os.path.join(violin_dataset_dir, "mp3s", violin_audio_name)
|
67 |
+
|
68 |
+
violin_audio = load_random_segment(
|
69 |
+
audio_path=violin_audio_path,
|
70 |
+
random_state=random_state,
|
71 |
+
segment_seconds=segment_seconds,
|
72 |
+
mono=mono,
|
73 |
+
sample_rate=sample_rate,
|
74 |
+
)
|
75 |
+
# (channels_num, audio_samples)
|
76 |
+
|
77 |
+
output_violin_path = os.path.join(
|
78 |
+
evaluation_audios_dir, split, 'violin', '{:04d}.wav'.format(n)
|
79 |
+
)
|
80 |
+
soundfile.write(
|
81 |
+
file=output_violin_path, data=violin_audio.T, samplerate=sample_rate
|
82 |
+
)
|
83 |
+
print("Write out to {}".format(output_violin_path))
|
84 |
+
|
85 |
+
# Randomly select and write out a clean piano segment.
|
86 |
+
piano_audio_name = random_state.choice(piano_audio_names)
|
87 |
+
piano_audio_path = os.path.join(piano_dataset_dir, piano_audio_name)
|
88 |
+
|
89 |
+
piano_audio = load_random_segment(
|
90 |
+
audio_path=piano_audio_path,
|
91 |
+
random_state=random_state,
|
92 |
+
segment_seconds=segment_seconds,
|
93 |
+
mono=mono,
|
94 |
+
sample_rate=sample_rate,
|
95 |
+
)
|
96 |
+
# (channels_num, audio_samples)
|
97 |
+
|
98 |
+
output_piano_path = os.path.join(
|
99 |
+
evaluation_audios_dir, split, 'piano', '{:04d}.wav'.format(n)
|
100 |
+
)
|
101 |
+
soundfile.write(
|
102 |
+
file=output_piano_path, data=piano_audio.T, samplerate=sample_rate
|
103 |
+
)
|
104 |
+
print("Write out to {}".format(output_piano_path))
|
105 |
+
|
106 |
+
# Mix violin and piano segments and write out a mixture segment.
|
107 |
+
mixture_audio = violin_audio + piano_audio
|
108 |
+
# (channels_num, audio_samples)
|
109 |
+
|
110 |
+
output_mixture_path = os.path.join(
|
111 |
+
evaluation_audios_dir, split, 'mixture', '{:04d}.wav'.format(n)
|
112 |
+
)
|
113 |
+
soundfile.write(
|
114 |
+
file=output_mixture_path, data=mixture_audio.T, samplerate=sample_rate
|
115 |
+
)
|
116 |
+
print("Write out to {}".format(output_mixture_path))
|
117 |
+
|
118 |
+
|
119 |
+
if __name__ == "__main__":
|
120 |
+
parser = argparse.ArgumentParser()
|
121 |
+
|
122 |
+
parser.add_argument(
|
123 |
+
"--violin_dataset_dir",
|
124 |
+
type=str,
|
125 |
+
required=True,
|
126 |
+
help="The directory of the violin dataset.",
|
127 |
+
)
|
128 |
+
parser.add_argument(
|
129 |
+
"--piano_dataset_dir",
|
130 |
+
type=str,
|
131 |
+
required=True,
|
132 |
+
help="The directory of the piano dataset.",
|
133 |
+
)
|
134 |
+
parser.add_argument(
|
135 |
+
"--evaluation_audios_dir",
|
136 |
+
type=str,
|
137 |
+
required=True,
|
138 |
+
help="The directory to write out randomly selected and mixed audio segments.",
|
139 |
+
)
|
140 |
+
parser.add_argument(
|
141 |
+
"--sample_rate",
|
142 |
+
type=int,
|
143 |
+
required=True,
|
144 |
+
help="Sample rate",
|
145 |
+
)
|
146 |
+
parser.add_argument(
|
147 |
+
"--channels",
|
148 |
+
type=int,
|
149 |
+
required=True,
|
150 |
+
help="Audio channels, e.g, 1 or 2.",
|
151 |
+
)
|
152 |
+
parser.add_argument(
|
153 |
+
"--evaluation_segments_num",
|
154 |
+
type=int,
|
155 |
+
required=True,
|
156 |
+
help="The number of segments to create for evaluation.",
|
157 |
+
)
|
158 |
+
|
159 |
+
# Parse arguments.
|
160 |
+
args = parser.parse_args()
|
161 |
+
|
162 |
+
create_evaluation(args)
|
bytesep/dataset_creation/create_indexes/__init__.py
ADDED
File without changes
|
bytesep/dataset_creation/create_indexes/create_indexes.py
ADDED
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
import pickle
|
4 |
+
from typing import NoReturn
|
5 |
+
|
6 |
+
import h5py
|
7 |
+
|
8 |
+
from bytesep.utils import read_yaml
|
9 |
+
|
10 |
+
|
11 |
+
def create_indexes(args) -> NoReturn:
|
12 |
+
r"""Create and write out training indexes into disk. The indexes may contain
|
13 |
+
information from multiple datasets. During training, training indexes will
|
14 |
+
be shuffled and iterated for selecting segments to be mixed. E.g., the
|
15 |
+
training indexes_dict looks like: {
|
16 |
+
'vocals': [
|
17 |
+
{'hdf5_path': '.../songA.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 0, 'end_sample': 132300}
|
18 |
+
{'hdf5_path': '.../songB.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 4410, 'end_sample': 136710}
|
19 |
+
...
|
20 |
+
]
|
21 |
+
'accompaniment': [
|
22 |
+
{'hdf5_path': '.../songA.h5', 'key_in_hdf5': 'accompaniment', 'begin_sample': 0, 'end_sample': 132300}
|
23 |
+
{'hdf5_path': '.../songB.h5', 'key_in_hdf5': 'accompaniment', 'begin_sample': 4410, 'end_sample': 136710}
|
24 |
+
...
|
25 |
+
]
|
26 |
+
}
|
27 |
+
"""
|
28 |
+
|
29 |
+
# Arugments & parameters
|
30 |
+
workspace = args.workspace
|
31 |
+
config_yaml = args.config_yaml
|
32 |
+
|
33 |
+
# Only create indexes for training, because evalution is on entire pieces.
|
34 |
+
split = "train"
|
35 |
+
|
36 |
+
# Read config file.
|
37 |
+
configs = read_yaml(config_yaml)
|
38 |
+
|
39 |
+
sample_rate = configs["sample_rate"]
|
40 |
+
segment_samples = int(configs["segment_seconds"] * sample_rate)
|
41 |
+
|
42 |
+
# Path to write out index.
|
43 |
+
indexes_path = os.path.join(workspace, configs[split]["indexes"])
|
44 |
+
os.makedirs(os.path.dirname(indexes_path), exist_ok=True)
|
45 |
+
|
46 |
+
source_types = configs[split]["source_types"].keys()
|
47 |
+
# E.g., ['vocals', 'accompaniment']
|
48 |
+
|
49 |
+
indexes_dict = {source_type: [] for source_type in source_types}
|
50 |
+
# E.g., indexes_dict will looks like: {
|
51 |
+
# 'vocals': [
|
52 |
+
# {'hdf5_path': '.../songA.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 0, 'end_sample': 132300}
|
53 |
+
# {'hdf5_path': '.../songB.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 4410, 'end_sample': 136710}
|
54 |
+
# ...
|
55 |
+
# ]
|
56 |
+
# 'accompaniment': [
|
57 |
+
# {'hdf5_path': '.../songA.h5', 'key_in_hdf5': 'accompaniment', 'begin_sample': 0, 'end_sample': 132300}
|
58 |
+
# {'hdf5_path': '.../songB.h5', 'key_in_hdf5': 'accompaniment', 'begin_sample': 4410, 'end_sample': 136710}
|
59 |
+
# ...
|
60 |
+
# ]
|
61 |
+
# }
|
62 |
+
|
63 |
+
# Get training indexes for each source type.
|
64 |
+
for source_type in source_types:
|
65 |
+
# E.g., ['vocals', 'bass', ...]
|
66 |
+
|
67 |
+
print("--- {} ---".format(source_type))
|
68 |
+
|
69 |
+
dataset_types = configs[split]["source_types"][source_type]
|
70 |
+
# E.g., ['musdb18', ...]
|
71 |
+
|
72 |
+
# Each source can come from mulitple datasets.
|
73 |
+
for dataset_type in dataset_types:
|
74 |
+
|
75 |
+
hdf5s_dir = os.path.join(
|
76 |
+
workspace, dataset_types[dataset_type]["hdf5s_directory"]
|
77 |
+
)
|
78 |
+
|
79 |
+
hop_samples = int(dataset_types[dataset_type]["hop_seconds"] * sample_rate)
|
80 |
+
|
81 |
+
key_in_hdf5 = dataset_types[dataset_type]["key_in_hdf5"]
|
82 |
+
# E.g., 'vocals'
|
83 |
+
|
84 |
+
hdf5_names = sorted(os.listdir(hdf5s_dir))
|
85 |
+
print("Hdf5 files num: {}".format(len(hdf5_names)))
|
86 |
+
|
87 |
+
# Traverse all packed hdf5 files of a dataset.
|
88 |
+
for n, hdf5_name in enumerate(hdf5_names):
|
89 |
+
|
90 |
+
print(n, hdf5_name)
|
91 |
+
hdf5_path = os.path.join(hdf5s_dir, hdf5_name)
|
92 |
+
|
93 |
+
with h5py.File(hdf5_path, "r") as hf:
|
94 |
+
|
95 |
+
bgn_sample = 0
|
96 |
+
while bgn_sample + segment_samples < hf[key_in_hdf5].shape[-1]:
|
97 |
+
meta = {
|
98 |
+
'hdf5_path': hdf5_path,
|
99 |
+
'key_in_hdf5': key_in_hdf5,
|
100 |
+
'begin_sample': bgn_sample,
|
101 |
+
'end_sample': bgn_sample + segment_samples,
|
102 |
+
}
|
103 |
+
indexes_dict[source_type].append(meta)
|
104 |
+
|
105 |
+
bgn_sample += hop_samples
|
106 |
+
|
107 |
+
# If the audio length is shorter than the segment length,
|
108 |
+
# then use the entire audio as a segment.
|
109 |
+
if bgn_sample == 0:
|
110 |
+
meta = {
|
111 |
+
'hdf5_path': hdf5_path,
|
112 |
+
'key_in_hdf5': key_in_hdf5,
|
113 |
+
'begin_sample': 0,
|
114 |
+
'end_sample': segment_samples,
|
115 |
+
}
|
116 |
+
indexes_dict[source_type].append(meta)
|
117 |
+
|
118 |
+
print(
|
119 |
+
"Total indexes for {}: {}".format(
|
120 |
+
source_type, len(indexes_dict[source_type])
|
121 |
+
)
|
122 |
+
)
|
123 |
+
|
124 |
+
pickle.dump(indexes_dict, open(indexes_path, "wb"))
|
125 |
+
print("Write index dict to {}".format(indexes_path))
|
126 |
+
|
127 |
+
|
128 |
+
if __name__ == "__main__":
|
129 |
+
parser = argparse.ArgumentParser()
|
130 |
+
|
131 |
+
parser.add_argument(
|
132 |
+
"--workspace", type=str, required=True, help="Directory of workspace."
|
133 |
+
)
|
134 |
+
parser.add_argument(
|
135 |
+
"--config_yaml", type=str, required=True, help="User defined config file."
|
136 |
+
)
|
137 |
+
|
138 |
+
# Parse arguments.
|
139 |
+
args = parser.parse_args()
|
140 |
+
|
141 |
+
# Create training indexes.
|
142 |
+
create_indexes(args)
|
bytesep/dataset_creation/pack_audios_to_hdf5s/__init__.py
ADDED
File without changes
|
bytesep/dataset_creation/pack_audios_to_hdf5s/instruments_solo.py
ADDED
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
import pathlib
|
4 |
+
import time
|
5 |
+
from concurrent.futures import ProcessPoolExecutor
|
6 |
+
from typing import Dict, List, NoReturn
|
7 |
+
|
8 |
+
import h5py
|
9 |
+
import numpy as np
|
10 |
+
import pandas as pd
|
11 |
+
|
12 |
+
from bytesep.utils import float32_to_int16, load_audio
|
13 |
+
|
14 |
+
|
15 |
+
def read_csv(meta_csv) -> Dict:
|
16 |
+
r"""Get train & test names from csv.
|
17 |
+
|
18 |
+
Args:
|
19 |
+
meta_csv: str
|
20 |
+
|
21 |
+
Returns:
|
22 |
+
names_dict: dict, e.g., {
|
23 |
+
'train', ['songA.mp3', 'songB.mp3', ...],
|
24 |
+
'test': ['songE.mp3', 'songF.mp3', ...]
|
25 |
+
}
|
26 |
+
"""
|
27 |
+
df = pd.read_csv(meta_csv, sep=',')
|
28 |
+
|
29 |
+
names_dict = {}
|
30 |
+
|
31 |
+
for split in ['train', 'test']:
|
32 |
+
audio_indexes = df['split'] == split
|
33 |
+
audio_names = list(df['audio_name'][audio_indexes])
|
34 |
+
audio_names = [
|
35 |
+
'{}.mp3'.format(pathlib.Path(audio_name).stem) for audio_name in audio_names
|
36 |
+
]
|
37 |
+
names_dict[split] = audio_names
|
38 |
+
|
39 |
+
return names_dict
|
40 |
+
|
41 |
+
|
42 |
+
def pack_audios_to_hdf5s(args) -> NoReturn:
|
43 |
+
r"""Pack (resampled) audio files into hdf5 files to speed up loading.
|
44 |
+
|
45 |
+
Args:
|
46 |
+
dataset_dir: str
|
47 |
+
split: str, 'train' | 'test'
|
48 |
+
source_type: str
|
49 |
+
hdf5s_dir: str, directory to write out hdf5 files
|
50 |
+
sample_rate: int
|
51 |
+
channels_num: int
|
52 |
+
mono: bool
|
53 |
+
|
54 |
+
Returns:
|
55 |
+
NoReturn
|
56 |
+
"""
|
57 |
+
|
58 |
+
# arguments & parameters
|
59 |
+
dataset_dir = args.dataset_dir
|
60 |
+
split = args.split
|
61 |
+
source_type = args.source_type
|
62 |
+
hdf5s_dir = args.hdf5s_dir
|
63 |
+
sample_rate = args.sample_rate
|
64 |
+
channels = args.channels
|
65 |
+
mono = True if channels == 1 else False
|
66 |
+
|
67 |
+
# Only pack data for training data.
|
68 |
+
assert split == "train"
|
69 |
+
|
70 |
+
# paths
|
71 |
+
audios_dir = os.path.join(dataset_dir, 'mp3s')
|
72 |
+
meta_csv = os.path.join(dataset_dir, 'validation.csv')
|
73 |
+
|
74 |
+
os.makedirs(hdf5s_dir, exist_ok=True)
|
75 |
+
|
76 |
+
# Read train & test names.
|
77 |
+
names_dict = read_csv(meta_csv)
|
78 |
+
|
79 |
+
audio_names = names_dict[split]
|
80 |
+
|
81 |
+
params = []
|
82 |
+
|
83 |
+
for audio_index, audio_name in enumerate(audio_names):
|
84 |
+
|
85 |
+
audio_path = os.path.join(audios_dir, audio_name)
|
86 |
+
|
87 |
+
hdf5_path = os.path.join(
|
88 |
+
hdf5s_dir, "{}.h5".format(pathlib.Path(audio_name).stem)
|
89 |
+
)
|
90 |
+
|
91 |
+
param = (
|
92 |
+
audio_index,
|
93 |
+
audio_name,
|
94 |
+
source_type,
|
95 |
+
audio_path,
|
96 |
+
mono,
|
97 |
+
sample_rate,
|
98 |
+
hdf5_path,
|
99 |
+
)
|
100 |
+
params.append(param)
|
101 |
+
|
102 |
+
# Uncomment for debug.
|
103 |
+
# write_single_audio_to_hdf5(params[0])
|
104 |
+
# os._exit()
|
105 |
+
|
106 |
+
pack_hdf5s_time = time.time()
|
107 |
+
|
108 |
+
with ProcessPoolExecutor(max_workers=None) as pool:
|
109 |
+
# Maximum works on the machine
|
110 |
+
pool.map(write_single_audio_to_hdf5, params)
|
111 |
+
|
112 |
+
print("Pack hdf5 time: {:.3f} s".format(time.time() - pack_hdf5s_time))
|
113 |
+
|
114 |
+
|
115 |
+
def write_single_audio_to_hdf5(param: List) -> NoReturn:
|
116 |
+
r"""Write single audio into hdf5 file."""
|
117 |
+
|
118 |
+
(
|
119 |
+
audio_index,
|
120 |
+
audio_name,
|
121 |
+
source_type,
|
122 |
+
audio_path,
|
123 |
+
mono,
|
124 |
+
sample_rate,
|
125 |
+
hdf5_path,
|
126 |
+
) = param
|
127 |
+
|
128 |
+
with h5py.File(hdf5_path, "w") as hf:
|
129 |
+
|
130 |
+
hf.attrs.create("audio_name", data=audio_name, dtype="S100")
|
131 |
+
hf.attrs.create("sample_rate", data=sample_rate, dtype=np.int32)
|
132 |
+
|
133 |
+
audio = load_audio(audio_path=audio_path, mono=mono, sample_rate=sample_rate)
|
134 |
+
# audio: (channels_num, audio_samples)
|
135 |
+
|
136 |
+
hf.create_dataset(
|
137 |
+
name=source_type, data=float32_to_int16(audio), dtype=np.int16
|
138 |
+
)
|
139 |
+
|
140 |
+
print('{} Write hdf5 to {}'.format(audio_index, hdf5_path))
|
141 |
+
|
142 |
+
|
143 |
+
if __name__ == "__main__":
|
144 |
+
parser = argparse.ArgumentParser()
|
145 |
+
|
146 |
+
parser.add_argument(
|
147 |
+
"--dataset_dir",
|
148 |
+
type=str,
|
149 |
+
required=True,
|
150 |
+
help="Directory of the instruments solo dataset.",
|
151 |
+
)
|
152 |
+
parser.add_argument("--split", type=str, required=True, choices=["train", "test"])
|
153 |
+
parser.add_argument(
|
154 |
+
"--source_type",
|
155 |
+
type=str,
|
156 |
+
required=True,
|
157 |
+
)
|
158 |
+
parser.add_argument(
|
159 |
+
"--hdf5s_dir",
|
160 |
+
type=str,
|
161 |
+
required=True,
|
162 |
+
help="Directory to write out hdf5 files.",
|
163 |
+
)
|
164 |
+
parser.add_argument("--sample_rate", type=int, required=True, help="Sample rate.")
|
165 |
+
parser.add_argument(
|
166 |
+
"--channels", type=int, required=True, help="Use 1 for mono, 2 for stereo."
|
167 |
+
)
|
168 |
+
|
169 |
+
# Parse arguments.
|
170 |
+
args = parser.parse_args()
|
171 |
+
|
172 |
+
# Pack audios to hdf5 files.
|
173 |
+
pack_audios_to_hdf5s(args)
|
bytesep/dataset_creation/pack_audios_to_hdf5s/maestro.py
ADDED
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
import pathlib
|
4 |
+
import time
|
5 |
+
from concurrent.futures import ProcessPoolExecutor
|
6 |
+
from typing import Dict, NoReturn
|
7 |
+
|
8 |
+
import pandas as pd
|
9 |
+
|
10 |
+
from bytesep.dataset_creation.pack_audios_to_hdf5s.instruments_solo import (
|
11 |
+
write_single_audio_to_hdf5,
|
12 |
+
)
|
13 |
+
|
14 |
+
|
15 |
+
def read_csv(meta_csv) -> Dict:
|
16 |
+
r"""Get train & test names from csv.
|
17 |
+
|
18 |
+
Args:
|
19 |
+
meta_csv: str
|
20 |
+
|
21 |
+
Returns:
|
22 |
+
names_dict: dict, e.g., {
|
23 |
+
'train', ['a1.mp3', 'a2.mp3'],
|
24 |
+
'test': ['b1.mp3', 'b2.mp3']
|
25 |
+
}
|
26 |
+
"""
|
27 |
+
df = pd.read_csv(meta_csv, sep=',')
|
28 |
+
|
29 |
+
names_dict = {}
|
30 |
+
|
31 |
+
for split in ['train', 'test']:
|
32 |
+
audio_indexes = df['split'] == split
|
33 |
+
audio_names = list(df['audio_filename'][audio_indexes])
|
34 |
+
names_dict[split] = audio_names
|
35 |
+
|
36 |
+
return names_dict
|
37 |
+
|
38 |
+
|
39 |
+
def pack_audios_to_hdf5s(args) -> NoReturn:
|
40 |
+
r"""Pack (resampled) audio files into hdf5 files to speed up loading.
|
41 |
+
|
42 |
+
Args:
|
43 |
+
dataset_dir: str
|
44 |
+
split: str, 'train' | 'test'
|
45 |
+
hdf5s_dir: str, directory to write out hdf5 files
|
46 |
+
sample_rate: int
|
47 |
+
channels_num: int
|
48 |
+
mono: bool
|
49 |
+
|
50 |
+
Returns:
|
51 |
+
NoReturn
|
52 |
+
"""
|
53 |
+
|
54 |
+
# arguments & parameters
|
55 |
+
dataset_dir = args.dataset_dir
|
56 |
+
split = args.split
|
57 |
+
hdf5s_dir = args.hdf5s_dir
|
58 |
+
sample_rate = args.sample_rate
|
59 |
+
channels = args.channels
|
60 |
+
mono = True if channels == 1 else False
|
61 |
+
|
62 |
+
source_type = "piano"
|
63 |
+
|
64 |
+
# Only pack data for training data.
|
65 |
+
assert split == "train"
|
66 |
+
|
67 |
+
# paths
|
68 |
+
meta_csv = os.path.join(dataset_dir, 'maestro-v2.0.0.csv')
|
69 |
+
|
70 |
+
os.makedirs(hdf5s_dir, exist_ok=True)
|
71 |
+
|
72 |
+
# Read train & test names.
|
73 |
+
names_dict = read_csv(meta_csv)
|
74 |
+
|
75 |
+
audio_names = names_dict['{}'.format(split)]
|
76 |
+
|
77 |
+
params = []
|
78 |
+
|
79 |
+
for audio_index, audio_name in enumerate(audio_names):
|
80 |
+
|
81 |
+
audio_path = os.path.join(dataset_dir, audio_name)
|
82 |
+
|
83 |
+
hdf5_path = os.path.join(
|
84 |
+
hdf5s_dir, "{}.h5".format(pathlib.Path(audio_name).stem)
|
85 |
+
)
|
86 |
+
|
87 |
+
param = (
|
88 |
+
audio_index,
|
89 |
+
audio_name,
|
90 |
+
source_type,
|
91 |
+
audio_path,
|
92 |
+
mono,
|
93 |
+
sample_rate,
|
94 |
+
hdf5_path,
|
95 |
+
)
|
96 |
+
params.append(param)
|
97 |
+
|
98 |
+
# Uncomment for debug.
|
99 |
+
# write_single_audio_to_hdf5(params[0])
|
100 |
+
# os._exit(0)
|
101 |
+
|
102 |
+
pack_hdf5s_time = time.time()
|
103 |
+
|
104 |
+
with ProcessPoolExecutor(max_workers=None) as pool:
|
105 |
+
# Maximum works on the machine
|
106 |
+
pool.map(write_single_audio_to_hdf5, params)
|
107 |
+
|
108 |
+
print("Pack hdf5 time: {:.3f} s".format(time.time() - pack_hdf5s_time))
|
109 |
+
|
110 |
+
|
111 |
+
if __name__ == "__main__":
|
112 |
+
parser = argparse.ArgumentParser()
|
113 |
+
|
114 |
+
parser.add_argument(
|
115 |
+
"--dataset_dir",
|
116 |
+
type=str,
|
117 |
+
required=True,
|
118 |
+
help="Directory of the MAESTRO dataset.",
|
119 |
+
)
|
120 |
+
parser.add_argument("--split", type=str, required=True, choices=["train", "test"])
|
121 |
+
parser.add_argument(
|
122 |
+
"--hdf5s_dir",
|
123 |
+
type=str,
|
124 |
+
required=True,
|
125 |
+
help="Directory to write out hdf5 files.",
|
126 |
+
)
|
127 |
+
parser.add_argument("--sample_rate", type=int, required=True, help="Sample rate.")
|
128 |
+
parser.add_argument(
|
129 |
+
"--channels", type=int, required=True, help="Use 1 for mono, 2 for stereo."
|
130 |
+
)
|
131 |
+
|
132 |
+
# Parse arguments.
|
133 |
+
args = parser.parse_args()
|
134 |
+
|
135 |
+
# Pack audios to hdf5 files.
|
136 |
+
pack_audios_to_hdf5s(args)
|
bytesep/dataset_creation/pack_audios_to_hdf5s/musdb18.py
ADDED
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
import time
|
4 |
+
from concurrent.futures import ProcessPoolExecutor
|
5 |
+
from typing import NoReturn
|
6 |
+
|
7 |
+
import h5py
|
8 |
+
import librosa
|
9 |
+
import musdb
|
10 |
+
import numpy as np
|
11 |
+
|
12 |
+
from bytesep.utils import float32_to_int16
|
13 |
+
|
14 |
+
# Source types of the MUSDB18 dataset.
|
15 |
+
SOURCE_TYPES = ["vocals", "drums", "bass", "other", "accompaniment"]
|
16 |
+
|
17 |
+
|
18 |
+
def pack_audios_to_hdf5s(args) -> NoReturn:
|
19 |
+
r"""Pack (resampled) audio files into hdf5 files to speed up loading.
|
20 |
+
|
21 |
+
Args:
|
22 |
+
dataset_dir: str
|
23 |
+
subset: str, 'train' | 'test'
|
24 |
+
split: str, '' | 'train' | 'valid'
|
25 |
+
hdf5s_dir: str, directory to write out hdf5 files
|
26 |
+
sample_rate: int
|
27 |
+
channels_num: int
|
28 |
+
mono: bool
|
29 |
+
|
30 |
+
Returns:
|
31 |
+
NoReturn
|
32 |
+
"""
|
33 |
+
|
34 |
+
# arguments & parameters
|
35 |
+
dataset_dir = args.dataset_dir
|
36 |
+
subset = args.subset
|
37 |
+
split = None if args.split == "" else args.split
|
38 |
+
hdf5s_dir = args.hdf5s_dir
|
39 |
+
sample_rate = args.sample_rate
|
40 |
+
channels = args.channels
|
41 |
+
|
42 |
+
mono = True if channels == 1 else False
|
43 |
+
source_types = SOURCE_TYPES
|
44 |
+
resample_type = "kaiser_fast"
|
45 |
+
|
46 |
+
# Paths
|
47 |
+
os.makedirs(hdf5s_dir, exist_ok=True)
|
48 |
+
|
49 |
+
# Dataset of corresponding subset and split.
|
50 |
+
mus = musdb.DB(root=dataset_dir, subsets=[subset], split=split)
|
51 |
+
print("Subset: {}, Split: {}, Total pieces: {}".format(subset, split, len(mus)))
|
52 |
+
|
53 |
+
params = [] # A list of params for multiple processing.
|
54 |
+
|
55 |
+
for track_index in range(len(mus.tracks)):
|
56 |
+
|
57 |
+
param = (
|
58 |
+
dataset_dir,
|
59 |
+
subset,
|
60 |
+
split,
|
61 |
+
track_index,
|
62 |
+
source_types,
|
63 |
+
mono,
|
64 |
+
sample_rate,
|
65 |
+
resample_type,
|
66 |
+
hdf5s_dir,
|
67 |
+
)
|
68 |
+
|
69 |
+
params.append(param)
|
70 |
+
|
71 |
+
# Uncomment for debug.
|
72 |
+
# write_single_audio_to_hdf5(params[0])
|
73 |
+
# os._exit(0)
|
74 |
+
|
75 |
+
pack_hdf5s_time = time.time()
|
76 |
+
|
77 |
+
with ProcessPoolExecutor(max_workers=None) as pool:
|
78 |
+
# Maximum works on the machine
|
79 |
+
pool.map(write_single_audio_to_hdf5, params)
|
80 |
+
|
81 |
+
print("Pack hdf5 time: {:.3f} s".format(time.time() - pack_hdf5s_time))
|
82 |
+
|
83 |
+
|
84 |
+
def write_single_audio_to_hdf5(param) -> NoReturn:
|
85 |
+
r"""Write single audio into hdf5 file."""
|
86 |
+
(
|
87 |
+
dataset_dir,
|
88 |
+
subset,
|
89 |
+
split,
|
90 |
+
track_index,
|
91 |
+
source_types,
|
92 |
+
mono,
|
93 |
+
sample_rate,
|
94 |
+
resample_type,
|
95 |
+
hdf5s_dir,
|
96 |
+
) = param
|
97 |
+
|
98 |
+
# Dataset of corresponding subset and split.
|
99 |
+
mus = musdb.DB(root=dataset_dir, subsets=[subset], split=split)
|
100 |
+
track = mus.tracks[track_index]
|
101 |
+
|
102 |
+
# Path to write out hdf5 file.
|
103 |
+
hdf5_path = os.path.join(hdf5s_dir, "{}.h5".format(track.name))
|
104 |
+
|
105 |
+
with h5py.File(hdf5_path, "w") as hf:
|
106 |
+
|
107 |
+
hf.attrs.create("audio_name", data=track.name.encode(), dtype="S100")
|
108 |
+
hf.attrs.create("sample_rate", data=sample_rate, dtype=np.int32)
|
109 |
+
|
110 |
+
for source_type in source_types:
|
111 |
+
|
112 |
+
audio = track.targets[source_type].audio.T
|
113 |
+
# (channels_num, audio_samples)
|
114 |
+
|
115 |
+
# Preprocess audio to mono / stereo, and resample.
|
116 |
+
audio = preprocess_audio(
|
117 |
+
audio, mono, track.rate, sample_rate, resample_type
|
118 |
+
)
|
119 |
+
# audio = load_audio(audio_path=audio_path, mono=mono, sample_rate=sample_rate)
|
120 |
+
# (channels_num, audio_samples) | (audio_samples,)
|
121 |
+
|
122 |
+
hf.create_dataset(
|
123 |
+
name=source_type, data=float32_to_int16(audio), dtype=np.int16
|
124 |
+
)
|
125 |
+
|
126 |
+
# Mixture
|
127 |
+
audio = track.audio.T
|
128 |
+
# (channels_num, audio_samples)
|
129 |
+
|
130 |
+
# Preprocess audio to mono / stereo, and resample.
|
131 |
+
audio = preprocess_audio(audio, mono, track.rate, sample_rate, resample_type)
|
132 |
+
# (channels_num, audio_samples)
|
133 |
+
|
134 |
+
hf.create_dataset(name="mixture", data=float32_to_int16(audio), dtype=np.int16)
|
135 |
+
|
136 |
+
print("{} Write to {}, {}".format(track_index, hdf5_path, audio.shape))
|
137 |
+
|
138 |
+
|
139 |
+
def preprocess_audio(audio, mono, origin_sr, sr, resample_type) -> np.array:
|
140 |
+
r"""Preprocess audio to mono / stereo, and resample.
|
141 |
+
|
142 |
+
Args:
|
143 |
+
audio: (channels_num, audio_samples), input audio
|
144 |
+
mono: bool
|
145 |
+
origin_sr: float, original sample rate
|
146 |
+
sr: float, target sample rate
|
147 |
+
resample_type: str, e.g., 'kaiser_fast'
|
148 |
+
|
149 |
+
Returns:
|
150 |
+
output: ndarray, output audio
|
151 |
+
"""
|
152 |
+
if mono:
|
153 |
+
audio = np.mean(audio, axis=0)
|
154 |
+
# (audio_samples,)
|
155 |
+
|
156 |
+
output = librosa.core.resample(
|
157 |
+
audio, orig_sr=origin_sr, target_sr=sr, res_type=resample_type
|
158 |
+
)
|
159 |
+
# (audio_samples,) | (channels_num, audio_samples)
|
160 |
+
|
161 |
+
if output.ndim == 1:
|
162 |
+
output = output[None, :]
|
163 |
+
# (1, audio_samples,)
|
164 |
+
|
165 |
+
return output
|
166 |
+
|
167 |
+
|
168 |
+
if __name__ == "__main__":
|
169 |
+
parser = argparse.ArgumentParser()
|
170 |
+
|
171 |
+
parser.add_argument(
|
172 |
+
"--dataset_dir",
|
173 |
+
type=str,
|
174 |
+
required=True,
|
175 |
+
help="Directory of the MUSDB18 dataset.",
|
176 |
+
)
|
177 |
+
parser.add_argument(
|
178 |
+
"--subset",
|
179 |
+
type=str,
|
180 |
+
required=True,
|
181 |
+
choices=["train", "test"],
|
182 |
+
help="Train subset: 100 pieces; test subset: 50 pieces.",
|
183 |
+
)
|
184 |
+
parser.add_argument(
|
185 |
+
"--split",
|
186 |
+
type=str,
|
187 |
+
required=True,
|
188 |
+
choices=["", "train", "valid"],
|
189 |
+
help="Use '' to use all 100 pieces to train. Use 'train' to use 86 \
|
190 |
+
pieces for train, and use 'test' to use 14 pieces for valid.",
|
191 |
+
)
|
192 |
+
parser.add_argument(
|
193 |
+
"--hdf5s_dir",
|
194 |
+
type=str,
|
195 |
+
required=True,
|
196 |
+
help="Directory to write out hdf5 files.",
|
197 |
+
)
|
198 |
+
parser.add_argument("--sample_rate", type=int, required=True, help="Sample rate.")
|
199 |
+
parser.add_argument(
|
200 |
+
"--channels", type=int, required=True, help="Use 1 for mono, 2 for stereo."
|
201 |
+
)
|
202 |
+
|
203 |
+
# Parse arguments.
|
204 |
+
args = parser.parse_args()
|
205 |
+
|
206 |
+
# Pack audios into hdf5 files.
|
207 |
+
pack_audios_to_hdf5s(args)
|
bytesep/dataset_creation/pack_audios_to_hdf5s/vctk.py
ADDED
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
import pathlib
|
4 |
+
import time
|
5 |
+
from concurrent.futures import ProcessPoolExecutor
|
6 |
+
from typing import NoReturn
|
7 |
+
|
8 |
+
from bytesep.dataset_creation.pack_audios_to_hdf5s.instruments_solo import (
|
9 |
+
write_single_audio_to_hdf5,
|
10 |
+
)
|
11 |
+
|
12 |
+
|
13 |
+
def pack_audios_to_hdf5s(args) -> NoReturn:
|
14 |
+
r"""Pack (resampled) audio files into hdf5 files to speed up loading.
|
15 |
+
|
16 |
+
Args:
|
17 |
+
dataset_dir: str
|
18 |
+
split: str, 'train' | 'test'
|
19 |
+
hdf5s_dir: str, directory to write out hdf5 files
|
20 |
+
sample_rate: int
|
21 |
+
channels_num: int
|
22 |
+
mono: bool
|
23 |
+
|
24 |
+
Returns:
|
25 |
+
NoReturn
|
26 |
+
"""
|
27 |
+
|
28 |
+
# arguments & parameters
|
29 |
+
dataset_dir = args.dataset_dir
|
30 |
+
split = args.split
|
31 |
+
hdf5s_dir = args.hdf5s_dir
|
32 |
+
sample_rate = args.sample_rate
|
33 |
+
channels = args.channels
|
34 |
+
mono = True if channels == 1 else False
|
35 |
+
|
36 |
+
source_type = "speech"
|
37 |
+
|
38 |
+
# Only pack data for training data.
|
39 |
+
assert split == "train"
|
40 |
+
|
41 |
+
audios_dir = os.path.join(dataset_dir, 'wav48', split)
|
42 |
+
os.makedirs(hdf5s_dir, exist_ok=True)
|
43 |
+
|
44 |
+
speaker_ids = sorted(os.listdir(audios_dir))
|
45 |
+
|
46 |
+
params = []
|
47 |
+
audio_index = 0
|
48 |
+
|
49 |
+
for speaker_id in speaker_ids:
|
50 |
+
|
51 |
+
speaker_audios_dir = os.path.join(audios_dir, speaker_id)
|
52 |
+
|
53 |
+
audio_names = sorted(os.listdir(speaker_audios_dir))
|
54 |
+
|
55 |
+
for audio_name in audio_names:
|
56 |
+
|
57 |
+
audio_path = os.path.join(speaker_audios_dir, audio_name)
|
58 |
+
|
59 |
+
hdf5_path = os.path.join(
|
60 |
+
hdf5s_dir, "{}.h5".format(pathlib.Path(audio_name).stem)
|
61 |
+
)
|
62 |
+
|
63 |
+
param = (
|
64 |
+
audio_index,
|
65 |
+
audio_name,
|
66 |
+
source_type,
|
67 |
+
audio_path,
|
68 |
+
mono,
|
69 |
+
sample_rate,
|
70 |
+
hdf5_path,
|
71 |
+
)
|
72 |
+
params.append(param)
|
73 |
+
|
74 |
+
audio_index += 1
|
75 |
+
|
76 |
+
# Uncomment for debug.
|
77 |
+
# write_single_audio_to_hdf5(params[0])
|
78 |
+
# os._exit(0)
|
79 |
+
|
80 |
+
pack_hdf5s_time = time.time()
|
81 |
+
|
82 |
+
with ProcessPoolExecutor(max_workers=None) as pool:
|
83 |
+
# Maximum works on the machine
|
84 |
+
pool.map(write_single_audio_to_hdf5, params)
|
85 |
+
|
86 |
+
print("Pack hdf5 time: {:.3f} s".format(time.time() - pack_hdf5s_time))
|
87 |
+
|
88 |
+
|
89 |
+
if __name__ == "__main__":
|
90 |
+
parser = argparse.ArgumentParser()
|
91 |
+
|
92 |
+
parser.add_argument(
|
93 |
+
"--dataset_dir",
|
94 |
+
type=str,
|
95 |
+
required=True,
|
96 |
+
help="Directory of the VCTK dataset.",
|
97 |
+
)
|
98 |
+
parser.add_argument("--split", type=str, required=True, choices=["train", "test"])
|
99 |
+
parser.add_argument(
|
100 |
+
"--hdf5s_dir",
|
101 |
+
type=str,
|
102 |
+
required=True,
|
103 |
+
help="Directory to write out hdf5 files.",
|
104 |
+
)
|
105 |
+
parser.add_argument("--sample_rate", type=int, required=True, help="Sample rate.")
|
106 |
+
parser.add_argument(
|
107 |
+
"--channels", type=int, required=True, help="Use 1 for mono, 2 for stereo."
|
108 |
+
)
|
109 |
+
|
110 |
+
# Parse arguments.
|
111 |
+
args = parser.parse_args()
|
112 |
+
|
113 |
+
# Pack audios into hdf5 files.
|
114 |
+
pack_audios_to_hdf5s(args)
|
bytesep/dataset_creation/pack_audios_to_hdf5s/voicebank-demand.py
ADDED
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
import pathlib
|
4 |
+
import time
|
5 |
+
from concurrent.futures import ProcessPoolExecutor
|
6 |
+
from typing import List, NoReturn
|
7 |
+
|
8 |
+
import h5py
|
9 |
+
import numpy as np
|
10 |
+
|
11 |
+
from bytesep.utils import float32_to_int16, load_audio
|
12 |
+
|
13 |
+
|
14 |
+
def pack_audios_to_hdf5s(args) -> NoReturn:
|
15 |
+
r"""Pack (resampled) audio files into hdf5 files to speed up loading.
|
16 |
+
|
17 |
+
Args:
|
18 |
+
dataset_dir: str
|
19 |
+
split: str, 'train' | 'test'
|
20 |
+
hdf5s_dir: str, directory to write out hdf5 files
|
21 |
+
sample_rate: int
|
22 |
+
channels_num: int
|
23 |
+
mono: bool
|
24 |
+
|
25 |
+
Returns:
|
26 |
+
NoReturn
|
27 |
+
"""
|
28 |
+
|
29 |
+
# arguments & parameters
|
30 |
+
dataset_dir = args.dataset_dir
|
31 |
+
split = args.split
|
32 |
+
hdf5s_dir = args.hdf5s_dir
|
33 |
+
sample_rate = args.sample_rate
|
34 |
+
channels = args.channels
|
35 |
+
mono = True if channels == 1 else False
|
36 |
+
|
37 |
+
# Only pack data for training data.
|
38 |
+
assert split == "train"
|
39 |
+
|
40 |
+
speech_dir = os.path.join(dataset_dir, "clean_{}set_wav".format(split))
|
41 |
+
mixture_dir = os.path.join(dataset_dir, "noisy_{}set_wav".format(split))
|
42 |
+
|
43 |
+
os.makedirs(hdf5s_dir, exist_ok=True)
|
44 |
+
|
45 |
+
# Read names.
|
46 |
+
audio_names = sorted(os.listdir(speech_dir))
|
47 |
+
|
48 |
+
params = []
|
49 |
+
|
50 |
+
for audio_index, audio_name in enumerate(audio_names):
|
51 |
+
|
52 |
+
speech_path = os.path.join(speech_dir, audio_name)
|
53 |
+
mixture_path = os.path.join(mixture_dir, audio_name)
|
54 |
+
|
55 |
+
hdf5_path = os.path.join(
|
56 |
+
hdf5s_dir, "{}.h5".format(pathlib.Path(audio_name).stem)
|
57 |
+
)
|
58 |
+
|
59 |
+
param = (
|
60 |
+
audio_index,
|
61 |
+
audio_name,
|
62 |
+
speech_path,
|
63 |
+
mixture_path,
|
64 |
+
mono,
|
65 |
+
sample_rate,
|
66 |
+
hdf5_path,
|
67 |
+
)
|
68 |
+
params.append(param)
|
69 |
+
|
70 |
+
# Uncomment for debug.
|
71 |
+
# write_single_audio_to_hdf5(params[0])
|
72 |
+
# os._exit(0)
|
73 |
+
|
74 |
+
pack_hdf5s_time = time.time()
|
75 |
+
|
76 |
+
with ProcessPoolExecutor(max_workers=None) as pool:
|
77 |
+
# Maximum works on the machine
|
78 |
+
pool.map(write_single_audio_to_hdf5, params)
|
79 |
+
|
80 |
+
print("Pack hdf5 time: {:.3f} s".format(time.time() - pack_hdf5s_time))
|
81 |
+
|
82 |
+
|
83 |
+
def write_single_audio_to_hdf5(param: List) -> NoReturn:
|
84 |
+
r"""Write single audio into hdf5 file."""
|
85 |
+
|
86 |
+
(
|
87 |
+
audio_index,
|
88 |
+
audio_name,
|
89 |
+
speech_path,
|
90 |
+
mixture_path,
|
91 |
+
mono,
|
92 |
+
sample_rate,
|
93 |
+
hdf5_path,
|
94 |
+
) = param
|
95 |
+
|
96 |
+
with h5py.File(hdf5_path, "w") as hf:
|
97 |
+
|
98 |
+
hf.attrs.create("audio_name", data=audio_name, dtype="S100")
|
99 |
+
hf.attrs.create("sample_rate", data=sample_rate, dtype=np.int32)
|
100 |
+
|
101 |
+
speech = load_audio(audio_path=speech_path, mono=mono, sample_rate=sample_rate)
|
102 |
+
# speech: (channels_num, audio_samples)
|
103 |
+
|
104 |
+
mixture = load_audio(
|
105 |
+
audio_path=mixture_path, mono=mono, sample_rate=sample_rate
|
106 |
+
)
|
107 |
+
# mixture: (channels_num, audio_samples)
|
108 |
+
|
109 |
+
noise = mixture - speech
|
110 |
+
# noise: (channels_num, audio_samples)
|
111 |
+
|
112 |
+
hf.create_dataset(name='speech', data=float32_to_int16(speech), dtype=np.int16)
|
113 |
+
hf.create_dataset(name='noise', data=float32_to_int16(noise), dtype=np.int16)
|
114 |
+
|
115 |
+
print('{} Write hdf5 to {}'.format(audio_index, hdf5_path))
|
116 |
+
|
117 |
+
|
118 |
+
if __name__ == "__main__":
|
119 |
+
parser = argparse.ArgumentParser()
|
120 |
+
|
121 |
+
parser.add_argument(
|
122 |
+
"--dataset_dir",
|
123 |
+
type=str,
|
124 |
+
required=True,
|
125 |
+
help="Directory of the Voicebank-Demand dataset.",
|
126 |
+
)
|
127 |
+
parser.add_argument("--split", type=str, required=True, choices=["train", "test"])
|
128 |
+
parser.add_argument(
|
129 |
+
"--hdf5s_dir",
|
130 |
+
type=str,
|
131 |
+
required=True,
|
132 |
+
help="Directory to write out hdf5 files.",
|
133 |
+
)
|
134 |
+
parser.add_argument("--sample_rate", type=int, required=True, help="Sample rate.")
|
135 |
+
parser.add_argument(
|
136 |
+
"--channels", type=int, required=True, help="Use 1 for mono, 2 for stereo."
|
137 |
+
)
|
138 |
+
|
139 |
+
# Parse arguments.
|
140 |
+
args = parser.parse_args()
|
141 |
+
|
142 |
+
# Pack audios into hdf5 files.
|
143 |
+
pack_audios_to_hdf5s(args)
|
bytesep/inference.py
ADDED
@@ -0,0 +1,404 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
sys.path.append('.')
|
3 |
+
import argparse
|
4 |
+
import os
|
5 |
+
import time
|
6 |
+
from typing import Dict
|
7 |
+
import pathlib
|
8 |
+
|
9 |
+
import librosa
|
10 |
+
import numpy as np
|
11 |
+
import soundfile
|
12 |
+
import torch
|
13 |
+
import torch.nn as nn
|
14 |
+
|
15 |
+
from bytesep.models.lightning_modules import get_model_class
|
16 |
+
from bytesep.utils import read_yaml
|
17 |
+
|
18 |
+
|
19 |
+
class Separator:
|
20 |
+
def __init__(
|
21 |
+
self, model: nn.Module, segment_samples: int, batch_size: int, device: str
|
22 |
+
):
|
23 |
+
r"""Separate to separate an audio clip into a target source.
|
24 |
+
|
25 |
+
Args:
|
26 |
+
model: nn.Module, trained model
|
27 |
+
segment_samples: int, length of segments to be input to a model, e.g., 44100*30
|
28 |
+
batch_size, int, e.g., 12
|
29 |
+
device: str, e.g., 'cuda'
|
30 |
+
"""
|
31 |
+
self.model = model
|
32 |
+
self.segment_samples = segment_samples
|
33 |
+
self.batch_size = batch_size
|
34 |
+
self.device = device
|
35 |
+
|
36 |
+
def separate(self, input_dict: Dict) -> np.array:
|
37 |
+
r"""Separate an audio clip into a target source.
|
38 |
+
|
39 |
+
Args:
|
40 |
+
input_dict: dict, e.g., {
|
41 |
+
waveform: (channels_num, audio_samples),
|
42 |
+
...,
|
43 |
+
}
|
44 |
+
|
45 |
+
Returns:
|
46 |
+
sep_audio: (channels_num, audio_samples) | (target_sources_num, channels_num, audio_samples)
|
47 |
+
"""
|
48 |
+
audio = input_dict['waveform']
|
49 |
+
|
50 |
+
audio_samples = audio.shape[-1]
|
51 |
+
|
52 |
+
# Pad the audio with zero in the end so that the length of audio can be
|
53 |
+
# evenly divided by segment_samples.
|
54 |
+
audio = self.pad_audio(audio)
|
55 |
+
|
56 |
+
# Enframe long audio into segments.
|
57 |
+
segments = self.enframe(audio, self.segment_samples)
|
58 |
+
# (segments_num, channels_num, segment_samples)
|
59 |
+
|
60 |
+
segments_input_dict = {'waveform': segments}
|
61 |
+
|
62 |
+
if 'condition' in input_dict.keys():
|
63 |
+
segments_num = len(segments)
|
64 |
+
segments_input_dict['condition'] = np.tile(
|
65 |
+
input_dict['condition'][None, :], (segments_num, 1)
|
66 |
+
)
|
67 |
+
# (batch_size, segments_num)
|
68 |
+
|
69 |
+
# Separate in mini-batches.
|
70 |
+
sep_segments = self._forward_in_mini_batches(
|
71 |
+
self.model, segments_input_dict, self.batch_size
|
72 |
+
)['waveform']
|
73 |
+
# (segments_num, channels_num, segment_samples)
|
74 |
+
|
75 |
+
# Deframe segments into long audio.
|
76 |
+
sep_audio = self.deframe(sep_segments)
|
77 |
+
# (channels_num, padded_audio_samples)
|
78 |
+
|
79 |
+
sep_audio = sep_audio[:, 0:audio_samples]
|
80 |
+
# (channels_num, audio_samples)
|
81 |
+
|
82 |
+
return sep_audio
|
83 |
+
|
84 |
+
def pad_audio(self, audio: np.array) -> np.array:
|
85 |
+
r"""Pad the audio with zero in the end so that the length of audio can
|
86 |
+
be evenly divided by segment_samples.
|
87 |
+
|
88 |
+
Args:
|
89 |
+
audio: (channels_num, audio_samples)
|
90 |
+
|
91 |
+
Returns:
|
92 |
+
padded_audio: (channels_num, audio_samples)
|
93 |
+
"""
|
94 |
+
channels_num, audio_samples = audio.shape
|
95 |
+
|
96 |
+
# Number of segments
|
97 |
+
segments_num = int(np.ceil(audio_samples / self.segment_samples))
|
98 |
+
|
99 |
+
pad_samples = segments_num * self.segment_samples - audio_samples
|
100 |
+
|
101 |
+
padded_audio = np.concatenate(
|
102 |
+
(audio, np.zeros((channels_num, pad_samples))), axis=1
|
103 |
+
)
|
104 |
+
# (channels_num, padded_audio_samples)
|
105 |
+
|
106 |
+
return padded_audio
|
107 |
+
|
108 |
+
def enframe(self, audio: np.array, segment_samples: int) -> np.array:
|
109 |
+
r"""Enframe long audio into segments.
|
110 |
+
|
111 |
+
Args:
|
112 |
+
audio: (channels_num, audio_samples)
|
113 |
+
segment_samples: int
|
114 |
+
|
115 |
+
Returns:
|
116 |
+
segments: (segments_num, channels_num, segment_samples)
|
117 |
+
"""
|
118 |
+
audio_samples = audio.shape[1]
|
119 |
+
assert audio_samples % segment_samples == 0
|
120 |
+
|
121 |
+
hop_samples = segment_samples // 2
|
122 |
+
segments = []
|
123 |
+
|
124 |
+
pointer = 0
|
125 |
+
while pointer + segment_samples <= audio_samples:
|
126 |
+
segments.append(audio[:, pointer : pointer + segment_samples])
|
127 |
+
pointer += hop_samples
|
128 |
+
|
129 |
+
segments = np.array(segments)
|
130 |
+
|
131 |
+
return segments
|
132 |
+
|
133 |
+
def deframe(self, segments: np.array) -> np.array:
|
134 |
+
r"""Deframe segments into long audio.
|
135 |
+
|
136 |
+
Args:
|
137 |
+
segments: (segments_num, channels_num, segment_samples)
|
138 |
+
|
139 |
+
Returns:
|
140 |
+
output: (channels_num, audio_samples)
|
141 |
+
"""
|
142 |
+
(segments_num, _, segment_samples) = segments.shape
|
143 |
+
|
144 |
+
if segments_num == 1:
|
145 |
+
return segments[0]
|
146 |
+
|
147 |
+
assert self._is_integer(segment_samples * 0.25)
|
148 |
+
assert self._is_integer(segment_samples * 0.75)
|
149 |
+
|
150 |
+
output = []
|
151 |
+
|
152 |
+
output.append(segments[0, :, 0 : int(segment_samples * 0.75)])
|
153 |
+
|
154 |
+
for i in range(1, segments_num - 1):
|
155 |
+
output.append(
|
156 |
+
segments[
|
157 |
+
i, :, int(segment_samples * 0.25) : int(segment_samples * 0.75)
|
158 |
+
]
|
159 |
+
)
|
160 |
+
|
161 |
+
output.append(segments[-1, :, int(segment_samples * 0.25) :])
|
162 |
+
|
163 |
+
output = np.concatenate(output, axis=-1)
|
164 |
+
|
165 |
+
return output
|
166 |
+
|
167 |
+
def _is_integer(self, x: float) -> bool:
|
168 |
+
if x - int(x) < 1e-10:
|
169 |
+
return True
|
170 |
+
else:
|
171 |
+
return False
|
172 |
+
|
173 |
+
def _forward_in_mini_batches(
|
174 |
+
self, model: nn.Module, segments_input_dict: Dict, batch_size: int
|
175 |
+
) -> Dict:
|
176 |
+
r"""Forward data to model in mini-batch.
|
177 |
+
|
178 |
+
Args:
|
179 |
+
model: nn.Module
|
180 |
+
segments_input_dict: dict, e.g., {
|
181 |
+
'waveform': (segments_num, channels_num, segment_samples),
|
182 |
+
...,
|
183 |
+
}
|
184 |
+
batch_size: int
|
185 |
+
|
186 |
+
Returns:
|
187 |
+
output_dict: dict, e.g. {
|
188 |
+
'waveform': (segments_num, channels_num, segment_samples),
|
189 |
+
}
|
190 |
+
"""
|
191 |
+
output_dict = {}
|
192 |
+
|
193 |
+
pointer = 0
|
194 |
+
segments_num = len(segments_input_dict['waveform'])
|
195 |
+
|
196 |
+
while True:
|
197 |
+
if pointer >= segments_num:
|
198 |
+
break
|
199 |
+
|
200 |
+
batch_input_dict = {}
|
201 |
+
|
202 |
+
for key in segments_input_dict.keys():
|
203 |
+
batch_input_dict[key] = torch.Tensor(
|
204 |
+
segments_input_dict[key][pointer : pointer + batch_size]
|
205 |
+
).to(self.device)
|
206 |
+
|
207 |
+
pointer += batch_size
|
208 |
+
|
209 |
+
with torch.no_grad():
|
210 |
+
model.eval()
|
211 |
+
batch_output_dict = model(batch_input_dict)
|
212 |
+
|
213 |
+
for key in batch_output_dict.keys():
|
214 |
+
self._append_to_dict(
|
215 |
+
output_dict, key, batch_output_dict[key].data.cpu().numpy()
|
216 |
+
)
|
217 |
+
|
218 |
+
for key in output_dict.keys():
|
219 |
+
output_dict[key] = np.concatenate(output_dict[key], axis=0)
|
220 |
+
|
221 |
+
return output_dict
|
222 |
+
|
223 |
+
def _append_to_dict(self, dict, key, value):
|
224 |
+
if key in dict.keys():
|
225 |
+
dict[key].append(value)
|
226 |
+
else:
|
227 |
+
dict[key] = [value]
|
228 |
+
|
229 |
+
|
230 |
+
class SeparatorWrapper:
|
231 |
+
def __init__(
|
232 |
+
self, source_type='vocals', model=None, checkpoint_path=None, device='cuda'
|
233 |
+
):
|
234 |
+
|
235 |
+
input_channels = 2
|
236 |
+
target_sources_num = 1
|
237 |
+
model_type = "ResUNet143_Subbandtime"
|
238 |
+
segment_samples = 44100 * 10
|
239 |
+
batch_size = 1
|
240 |
+
|
241 |
+
self.checkpoint_path = self.download_checkpoints(checkpoint_path, source_type)
|
242 |
+
|
243 |
+
if device == 'cuda' and torch.cuda.is_available():
|
244 |
+
self.device = 'cuda'
|
245 |
+
else:
|
246 |
+
self.device = 'cpu'
|
247 |
+
|
248 |
+
# Get model class.
|
249 |
+
Model = get_model_class(model_type)
|
250 |
+
|
251 |
+
# Create model.
|
252 |
+
self.model = Model(
|
253 |
+
input_channels=input_channels, target_sources_num=target_sources_num
|
254 |
+
)
|
255 |
+
|
256 |
+
# Load checkpoint.
|
257 |
+
checkpoint = torch.load(self.checkpoint_path, map_location='cpu')
|
258 |
+
self.model.load_state_dict(checkpoint["model"])
|
259 |
+
|
260 |
+
# Move model to device.
|
261 |
+
self.model.to(self.device)
|
262 |
+
|
263 |
+
# Create separator.
|
264 |
+
self.separator = Separator(
|
265 |
+
model=self.model,
|
266 |
+
segment_samples=segment_samples,
|
267 |
+
batch_size=batch_size,
|
268 |
+
device=self.device,
|
269 |
+
)
|
270 |
+
|
271 |
+
def download_checkpoints(self, checkpoint_path, source_type):
|
272 |
+
|
273 |
+
if source_type == "vocals":
|
274 |
+
checkpoint_bare_name = "resunet143_subbtandtime_vocals_8.8dB_350k_steps"
|
275 |
+
|
276 |
+
elif source_type == "accompaniment":
|
277 |
+
checkpoint_bare_name = (
|
278 |
+
"resunet143_subbtandtime_accompaniment_16.4dB_350k_steps.pth"
|
279 |
+
)
|
280 |
+
|
281 |
+
else:
|
282 |
+
raise NotImplementedError
|
283 |
+
|
284 |
+
if not checkpoint_path:
|
285 |
+
checkpoint_path = '{}/bytesep_data/{}.pth'.format(
|
286 |
+
str(pathlib.Path.home()), checkpoint_bare_name
|
287 |
+
)
|
288 |
+
|
289 |
+
print('Checkpoint path: {}'.format(checkpoint_path))
|
290 |
+
|
291 |
+
if (
|
292 |
+
not os.path.exists(checkpoint_path)
|
293 |
+
or os.path.getsize(checkpoint_path) < 4e8
|
294 |
+
):
|
295 |
+
|
296 |
+
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
|
297 |
+
|
298 |
+
zenodo_dir = "https://zenodo.org/record/5507029/files"
|
299 |
+
zenodo_path = os.path.join(
|
300 |
+
zenodo_dir, "{}?download=1".format(checkpoint_bare_name)
|
301 |
+
)
|
302 |
+
|
303 |
+
os.system('wget -O "{}" "{}"'.format(checkpoint_path, zenodo_path))
|
304 |
+
|
305 |
+
return checkpoint_path
|
306 |
+
|
307 |
+
def separate(self, audio):
|
308 |
+
|
309 |
+
input_dict = {'waveform': audio}
|
310 |
+
|
311 |
+
sep_wav = self.separator.separate(input_dict)
|
312 |
+
|
313 |
+
return sep_wav
|
314 |
+
|
315 |
+
|
316 |
+
def inference(args):
|
317 |
+
|
318 |
+
# Need to use torch.distributed if models contain inplace_abn.abn.InPlaceABNSync.
|
319 |
+
import torch.distributed as dist
|
320 |
+
|
321 |
+
dist.init_process_group(
|
322 |
+
'gloo', init_method='file:///tmp/somefile', rank=0, world_size=1
|
323 |
+
)
|
324 |
+
|
325 |
+
# Arguments & parameters
|
326 |
+
config_yaml = args.config_yaml
|
327 |
+
checkpoint_path = args.checkpoint_path
|
328 |
+
audio_path = args.audio_path
|
329 |
+
output_path = args.output_path
|
330 |
+
device = (
|
331 |
+
torch.device('cuda')
|
332 |
+
if args.cuda and torch.cuda.is_available()
|
333 |
+
else torch.device('cpu')
|
334 |
+
)
|
335 |
+
|
336 |
+
configs = read_yaml(config_yaml)
|
337 |
+
sample_rate = configs['train']['sample_rate']
|
338 |
+
input_channels = configs['train']['channels']
|
339 |
+
target_source_types = configs['train']['target_source_types']
|
340 |
+
target_sources_num = len(target_source_types)
|
341 |
+
model_type = configs['train']['model_type']
|
342 |
+
|
343 |
+
segment_samples = int(30 * sample_rate)
|
344 |
+
batch_size = 1
|
345 |
+
|
346 |
+
print("Using {} for separating ..".format(device))
|
347 |
+
|
348 |
+
# paths
|
349 |
+
if os.path.dirname(output_path) != "":
|
350 |
+
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
351 |
+
|
352 |
+
# Get model class.
|
353 |
+
Model = get_model_class(model_type)
|
354 |
+
|
355 |
+
# Create model.
|
356 |
+
model = Model(input_channels=input_channels, target_sources_num=target_sources_num)
|
357 |
+
|
358 |
+
# Load checkpoint.
|
359 |
+
checkpoint = torch.load(checkpoint_path, map_location='cpu')
|
360 |
+
model.load_state_dict(checkpoint["model"])
|
361 |
+
|
362 |
+
# Move model to device.
|
363 |
+
model.to(device)
|
364 |
+
|
365 |
+
# Create separator.
|
366 |
+
separator = Separator(
|
367 |
+
model=model,
|
368 |
+
segment_samples=segment_samples,
|
369 |
+
batch_size=batch_size,
|
370 |
+
device=device,
|
371 |
+
)
|
372 |
+
|
373 |
+
# Load audio.
|
374 |
+
audio, _ = librosa.load(audio_path, sr=sample_rate, mono=False)
|
375 |
+
|
376 |
+
# audio = audio[None, :]
|
377 |
+
|
378 |
+
input_dict = {'waveform': audio}
|
379 |
+
|
380 |
+
# Separate
|
381 |
+
separate_time = time.time()
|
382 |
+
|
383 |
+
sep_wav = separator.separate(input_dict)
|
384 |
+
# (channels_num, audio_samples)
|
385 |
+
|
386 |
+
print('Separate time: {:.3f} s'.format(time.time() - separate_time))
|
387 |
+
|
388 |
+
# Write out separated audio.
|
389 |
+
soundfile.write(file='_zz.wav', data=sep_wav.T, samplerate=sample_rate)
|
390 |
+
os.system("ffmpeg -y -loglevel panic -i _zz.wav {}".format(output_path))
|
391 |
+
print('Write out to {}'.format(output_path))
|
392 |
+
|
393 |
+
|
394 |
+
if __name__ == "__main__":
|
395 |
+
|
396 |
+
parser = argparse.ArgumentParser(description="")
|
397 |
+
parser.add_argument("--config_yaml", type=str, required=True)
|
398 |
+
parser.add_argument("--checkpoint_path", type=str, required=True)
|
399 |
+
parser.add_argument("--audio_path", type=str, required=True)
|
400 |
+
parser.add_argument("--output_path", type=str, required=True)
|
401 |
+
parser.add_argument("--cuda", action='store_true', default=True)
|
402 |
+
|
403 |
+
args = parser.parse_args()
|
404 |
+
inference(args)
|
bytesep/inference_many.py
ADDED
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
import pathlib
|
4 |
+
import time
|
5 |
+
from typing import NoReturn
|
6 |
+
|
7 |
+
import librosa
|
8 |
+
import numpy as np
|
9 |
+
import soundfile
|
10 |
+
import torch
|
11 |
+
|
12 |
+
from bytesep.inference import Separator
|
13 |
+
from bytesep.models.lightning_modules import get_model_class
|
14 |
+
from bytesep.utils import read_yaml
|
15 |
+
|
16 |
+
|
17 |
+
def inference(args) -> NoReturn:
|
18 |
+
r"""Separate all audios in a directory.
|
19 |
+
|
20 |
+
Args:
|
21 |
+
config_yaml: str, the config file of a model being trained
|
22 |
+
checkpoint_path: str, the path of checkpoint to be loaded
|
23 |
+
audios_dir: str, the directory of audios to be separated
|
24 |
+
output_dir: str, the directory to write out separated audios
|
25 |
+
scale_volume: bool, if True then the volume is scaled to the maximum value of 1.
|
26 |
+
|
27 |
+
Returns:
|
28 |
+
NoReturn
|
29 |
+
"""
|
30 |
+
|
31 |
+
# Arguments & parameters
|
32 |
+
config_yaml = args.config_yaml
|
33 |
+
checkpoint_path = args.checkpoint_path
|
34 |
+
audios_dir = args.audios_dir
|
35 |
+
output_dir = args.output_dir
|
36 |
+
scale_volume = args.scale_volume
|
37 |
+
device = (
|
38 |
+
torch.device('cuda')
|
39 |
+
if args.cuda and torch.cuda.is_available()
|
40 |
+
else torch.device('cpu')
|
41 |
+
)
|
42 |
+
|
43 |
+
configs = read_yaml(config_yaml)
|
44 |
+
sample_rate = configs['train']['sample_rate']
|
45 |
+
input_channels = configs['train']['channels']
|
46 |
+
target_source_types = configs['train']['target_source_types']
|
47 |
+
target_sources_num = len(target_source_types)
|
48 |
+
model_type = configs['train']['model_type']
|
49 |
+
mono = input_channels == 1
|
50 |
+
|
51 |
+
segment_samples = int(30 * sample_rate)
|
52 |
+
batch_size = 1
|
53 |
+
device = "cuda"
|
54 |
+
|
55 |
+
models_contains_inplaceabn = True
|
56 |
+
|
57 |
+
# Need to use torch.distributed if models contain inplace_abn.abn.InPlaceABNSync.
|
58 |
+
if models_contains_inplaceabn:
|
59 |
+
|
60 |
+
import torch.distributed as dist
|
61 |
+
|
62 |
+
dist.init_process_group(
|
63 |
+
'gloo', init_method='file:///tmp/somefile', rank=0, world_size=1
|
64 |
+
)
|
65 |
+
|
66 |
+
print("Using {} for separating ..".format(device))
|
67 |
+
|
68 |
+
# paths
|
69 |
+
os.makedirs(output_dir, exist_ok=True)
|
70 |
+
|
71 |
+
# Get model class.
|
72 |
+
Model = get_model_class(model_type)
|
73 |
+
|
74 |
+
# Create model.
|
75 |
+
model = Model(input_channels=input_channels, target_sources_num=target_sources_num)
|
76 |
+
|
77 |
+
# Load checkpoint.
|
78 |
+
checkpoint = torch.load(checkpoint_path, map_location='cpu')
|
79 |
+
model.load_state_dict(checkpoint["model"])
|
80 |
+
|
81 |
+
# Move model to device.
|
82 |
+
model.to(device)
|
83 |
+
|
84 |
+
# Create separator.
|
85 |
+
separator = Separator(
|
86 |
+
model=model,
|
87 |
+
segment_samples=segment_samples,
|
88 |
+
batch_size=batch_size,
|
89 |
+
device=device,
|
90 |
+
)
|
91 |
+
|
92 |
+
audio_names = sorted(os.listdir(audios_dir))
|
93 |
+
|
94 |
+
for audio_name in audio_names:
|
95 |
+
audio_path = os.path.join(audios_dir, audio_name)
|
96 |
+
|
97 |
+
# Load audio.
|
98 |
+
audio, _ = librosa.load(audio_path, sr=sample_rate, mono=mono)
|
99 |
+
|
100 |
+
if audio.ndim == 1:
|
101 |
+
audio = audio[None, :]
|
102 |
+
|
103 |
+
input_dict = {'waveform': audio}
|
104 |
+
|
105 |
+
# Separate
|
106 |
+
separate_time = time.time()
|
107 |
+
|
108 |
+
sep_wav = separator.separate(input_dict)
|
109 |
+
# (channels_num, audio_samples)
|
110 |
+
|
111 |
+
print('Separate time: {:.3f} s'.format(time.time() - separate_time))
|
112 |
+
|
113 |
+
# Write out separated audio.
|
114 |
+
if scale_volume:
|
115 |
+
sep_wav /= np.max(np.abs(sep_wav))
|
116 |
+
|
117 |
+
soundfile.write(file='_zz.wav', data=sep_wav.T, samplerate=sample_rate)
|
118 |
+
|
119 |
+
output_path = os.path.join(
|
120 |
+
output_dir, '{}.mp3'.format(pathlib.Path(audio_name).stem)
|
121 |
+
)
|
122 |
+
os.system('ffmpeg -y -loglevel panic -i _zz.wav "{}"'.format(output_path))
|
123 |
+
print('Write out to {}'.format(output_path))
|
124 |
+
|
125 |
+
|
126 |
+
if __name__ == "__main__":
|
127 |
+
|
128 |
+
parser = argparse.ArgumentParser(description="")
|
129 |
+
parser.add_argument(
|
130 |
+
"--config_yaml",
|
131 |
+
type=str,
|
132 |
+
required=True,
|
133 |
+
help="The config file of a model being trained.",
|
134 |
+
)
|
135 |
+
parser.add_argument(
|
136 |
+
"--checkpoint_path",
|
137 |
+
type=str,
|
138 |
+
required=True,
|
139 |
+
help="The path of checkpoint to be loaded.",
|
140 |
+
)
|
141 |
+
parser.add_argument(
|
142 |
+
"--audios_dir",
|
143 |
+
type=str,
|
144 |
+
required=True,
|
145 |
+
help="The directory of audios to be separated.",
|
146 |
+
)
|
147 |
+
parser.add_argument(
|
148 |
+
"--output_dir",
|
149 |
+
type=str,
|
150 |
+
required=True,
|
151 |
+
help="The directory to write out separated audios.",
|
152 |
+
)
|
153 |
+
parser.add_argument(
|
154 |
+
'--scale_volume',
|
155 |
+
action='store_true',
|
156 |
+
default=False,
|
157 |
+
help="set to True if separated audios are scaled to the maximum value of 1.",
|
158 |
+
)
|
159 |
+
parser.add_argument("--cuda", action='store_true', default=True)
|
160 |
+
|
161 |
+
args = parser.parse_args()
|
162 |
+
|
163 |
+
inference(args)
|
bytesep/losses.py
ADDED
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from typing import Callable
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
from torchlibrosa.stft import STFT
|
7 |
+
|
8 |
+
from bytesep.models.pytorch_modules import Base
|
9 |
+
|
10 |
+
|
11 |
+
def l1(output: torch.Tensor, target: torch.Tensor, **kwargs) -> torch.Tensor:
|
12 |
+
r"""L1 loss.
|
13 |
+
|
14 |
+
Args:
|
15 |
+
output: torch.Tensor
|
16 |
+
target: torch.Tensor
|
17 |
+
|
18 |
+
Returns:
|
19 |
+
loss: torch.float
|
20 |
+
"""
|
21 |
+
return torch.mean(torch.abs(output - target))
|
22 |
+
|
23 |
+
|
24 |
+
def l1_wav(output: torch.Tensor, target: torch.Tensor, **kwargs) -> torch.Tensor:
|
25 |
+
r"""L1 loss in the time-domain.
|
26 |
+
|
27 |
+
Args:
|
28 |
+
output: torch.Tensor
|
29 |
+
target: torch.Tensor
|
30 |
+
|
31 |
+
Returns:
|
32 |
+
loss: torch.float
|
33 |
+
"""
|
34 |
+
return l1(output, target)
|
35 |
+
|
36 |
+
|
37 |
+
class L1_Wav_L1_Sp(nn.Module, Base):
|
38 |
+
def __init__(self):
|
39 |
+
r"""L1 loss in the time-domain and L1 loss on the spectrogram."""
|
40 |
+
super(L1_Wav_L1_Sp, self).__init__()
|
41 |
+
|
42 |
+
self.window_size = 2048
|
43 |
+
hop_size = 441
|
44 |
+
center = True
|
45 |
+
pad_mode = "reflect"
|
46 |
+
window = "hann"
|
47 |
+
|
48 |
+
self.stft = STFT(
|
49 |
+
n_fft=self.window_size,
|
50 |
+
hop_length=hop_size,
|
51 |
+
win_length=self.window_size,
|
52 |
+
window=window,
|
53 |
+
center=center,
|
54 |
+
pad_mode=pad_mode,
|
55 |
+
freeze_parameters=True,
|
56 |
+
)
|
57 |
+
|
58 |
+
def __call__(
|
59 |
+
self, output: torch.Tensor, target: torch.Tensor, **kwargs
|
60 |
+
) -> torch.Tensor:
|
61 |
+
r"""L1 loss in the time-domain and on the spectrogram.
|
62 |
+
|
63 |
+
Args:
|
64 |
+
output: torch.Tensor
|
65 |
+
target: torch.Tensor
|
66 |
+
|
67 |
+
Returns:
|
68 |
+
loss: torch.float
|
69 |
+
"""
|
70 |
+
|
71 |
+
# L1 loss in the time-domain.
|
72 |
+
wav_loss = l1_wav(output, target)
|
73 |
+
|
74 |
+
# L1 loss on the spectrogram.
|
75 |
+
sp_loss = l1(
|
76 |
+
self.wav_to_spectrogram(output, eps=1e-8),
|
77 |
+
self.wav_to_spectrogram(target, eps=1e-8),
|
78 |
+
)
|
79 |
+
|
80 |
+
# sp_loss /= math.sqrt(self.window_size)
|
81 |
+
# sp_loss *= 1.
|
82 |
+
|
83 |
+
# Total loss.
|
84 |
+
return wav_loss + sp_loss
|
85 |
+
|
86 |
+
return sp_loss
|
87 |
+
|
88 |
+
|
89 |
+
def get_loss_function(loss_type: str) -> Callable:
|
90 |
+
r"""Get loss function.
|
91 |
+
|
92 |
+
Args:
|
93 |
+
loss_type: str
|
94 |
+
|
95 |
+
Returns:
|
96 |
+
loss function: Callable
|
97 |
+
"""
|
98 |
+
|
99 |
+
if loss_type == "l1_wav":
|
100 |
+
return l1_wav
|
101 |
+
|
102 |
+
elif loss_type == "l1_wav_l1_sp":
|
103 |
+
return L1_Wav_L1_Sp()
|
104 |
+
|
105 |
+
else:
|
106 |
+
raise NotImplementedError
|
bytesep/models/__init__.py
ADDED
File without changes
|
bytesep/models/conditional_unet.py
ADDED
@@ -0,0 +1,496 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from typing import List
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import matplotlib.pyplot as plt
|
6 |
+
import pytorch_lightning as pl
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
import torch.nn.functional as F
|
10 |
+
import torch.optim as optim
|
11 |
+
from torch.optim.lr_scheduler import LambdaLR
|
12 |
+
from torchlibrosa.stft import STFT, ISTFT, magphase
|
13 |
+
|
14 |
+
from bytesep.models.pytorch_modules import (
|
15 |
+
Base,
|
16 |
+
init_bn,
|
17 |
+
init_embedding,
|
18 |
+
init_layer,
|
19 |
+
act,
|
20 |
+
Subband,
|
21 |
+
)
|
22 |
+
|
23 |
+
|
24 |
+
class ConvBlock(nn.Module):
|
25 |
+
def __init__(
|
26 |
+
self,
|
27 |
+
in_channels,
|
28 |
+
out_channels,
|
29 |
+
condition_size,
|
30 |
+
kernel_size,
|
31 |
+
activation,
|
32 |
+
momentum,
|
33 |
+
):
|
34 |
+
super(ConvBlock, self).__init__()
|
35 |
+
|
36 |
+
self.activation = activation
|
37 |
+
padding = (kernel_size[0] // 2, kernel_size[1] // 2)
|
38 |
+
|
39 |
+
self.conv1 = nn.Conv2d(
|
40 |
+
in_channels=in_channels,
|
41 |
+
out_channels=out_channels,
|
42 |
+
kernel_size=kernel_size,
|
43 |
+
stride=(1, 1),
|
44 |
+
dilation=(1, 1),
|
45 |
+
padding=padding,
|
46 |
+
bias=False,
|
47 |
+
)
|
48 |
+
|
49 |
+
self.bn1 = nn.BatchNorm2d(out_channels, momentum=momentum)
|
50 |
+
|
51 |
+
self.conv2 = nn.Conv2d(
|
52 |
+
in_channels=out_channels,
|
53 |
+
out_channels=out_channels,
|
54 |
+
kernel_size=kernel_size,
|
55 |
+
stride=(1, 1),
|
56 |
+
dilation=(1, 1),
|
57 |
+
padding=padding,
|
58 |
+
bias=False,
|
59 |
+
)
|
60 |
+
|
61 |
+
self.bn2 = nn.BatchNorm2d(out_channels, momentum=momentum)
|
62 |
+
|
63 |
+
self.beta1 = nn.Linear(condition_size, out_channels, bias=True)
|
64 |
+
self.beta2 = nn.Linear(condition_size, out_channels, bias=True)
|
65 |
+
|
66 |
+
self.init_weights()
|
67 |
+
|
68 |
+
def init_weights(self):
|
69 |
+
init_layer(self.conv1)
|
70 |
+
init_layer(self.conv2)
|
71 |
+
init_bn(self.bn1)
|
72 |
+
init_bn(self.bn2)
|
73 |
+
init_embedding(self.beta1)
|
74 |
+
init_embedding(self.beta2)
|
75 |
+
|
76 |
+
def forward(self, x, condition):
|
77 |
+
|
78 |
+
b1 = self.beta1(condition)[:, :, None, None]
|
79 |
+
b2 = self.beta2(condition)[:, :, None, None]
|
80 |
+
|
81 |
+
x = act(self.bn1(self.conv1(x)) + b1, self.activation)
|
82 |
+
x = act(self.bn2(self.conv2(x)) + b2, self.activation)
|
83 |
+
return x
|
84 |
+
|
85 |
+
|
86 |
+
class EncoderBlock(nn.Module):
|
87 |
+
def __init__(
|
88 |
+
self,
|
89 |
+
in_channels,
|
90 |
+
out_channels,
|
91 |
+
condition_size,
|
92 |
+
kernel_size,
|
93 |
+
downsample,
|
94 |
+
activation,
|
95 |
+
momentum,
|
96 |
+
):
|
97 |
+
super(EncoderBlock, self).__init__()
|
98 |
+
|
99 |
+
self.conv_block = ConvBlock(
|
100 |
+
in_channels, out_channels, condition_size, kernel_size, activation, momentum
|
101 |
+
)
|
102 |
+
self.downsample = downsample
|
103 |
+
|
104 |
+
def forward(self, x, condition):
|
105 |
+
encoder = self.conv_block(x, condition)
|
106 |
+
encoder_pool = F.avg_pool2d(encoder, kernel_size=self.downsample)
|
107 |
+
return encoder_pool, encoder
|
108 |
+
|
109 |
+
|
110 |
+
class DecoderBlock(nn.Module):
|
111 |
+
def __init__(
|
112 |
+
self,
|
113 |
+
in_channels,
|
114 |
+
out_channels,
|
115 |
+
condition_size,
|
116 |
+
kernel_size,
|
117 |
+
upsample,
|
118 |
+
activation,
|
119 |
+
momentum,
|
120 |
+
):
|
121 |
+
super(DecoderBlock, self).__init__()
|
122 |
+
self.kernel_size = kernel_size
|
123 |
+
self.stride = upsample
|
124 |
+
self.activation = activation
|
125 |
+
|
126 |
+
self.conv1 = torch.nn.ConvTranspose2d(
|
127 |
+
in_channels=in_channels,
|
128 |
+
out_channels=out_channels,
|
129 |
+
kernel_size=self.stride,
|
130 |
+
stride=self.stride,
|
131 |
+
padding=(0, 0),
|
132 |
+
bias=False,
|
133 |
+
dilation=(1, 1),
|
134 |
+
)
|
135 |
+
|
136 |
+
self.bn1 = nn.BatchNorm2d(out_channels, momentum=momentum)
|
137 |
+
|
138 |
+
self.conv_block2 = ConvBlock(
|
139 |
+
out_channels * 2,
|
140 |
+
out_channels,
|
141 |
+
condition_size,
|
142 |
+
kernel_size,
|
143 |
+
activation,
|
144 |
+
momentum,
|
145 |
+
)
|
146 |
+
|
147 |
+
self.beta1 = nn.Linear(condition_size, out_channels, bias=True)
|
148 |
+
|
149 |
+
self.init_weights()
|
150 |
+
|
151 |
+
def init_weights(self):
|
152 |
+
init_layer(self.conv1)
|
153 |
+
init_bn(self.bn1)
|
154 |
+
init_embedding(self.beta1)
|
155 |
+
|
156 |
+
def forward(self, input_tensor, concat_tensor, condition):
|
157 |
+
b1 = self.beta1(condition)[:, :, None, None]
|
158 |
+
x = act(self.bn1(self.conv1(input_tensor)) + b1, self.activation)
|
159 |
+
x = torch.cat((x, concat_tensor), dim=1)
|
160 |
+
x = self.conv_block2(x, condition)
|
161 |
+
return x
|
162 |
+
|
163 |
+
|
164 |
+
class ConditionalUNet(nn.Module, Base):
|
165 |
+
def __init__(self, input_channels, target_sources_num):
|
166 |
+
super(ConditionalUNet, self).__init__()
|
167 |
+
|
168 |
+
self.input_channels = input_channels
|
169 |
+
condition_size = target_sources_num
|
170 |
+
self.output_sources_num = 1
|
171 |
+
|
172 |
+
window_size = 2048
|
173 |
+
hop_size = 441
|
174 |
+
center = True
|
175 |
+
pad_mode = "reflect"
|
176 |
+
window = "hann"
|
177 |
+
activation = "relu"
|
178 |
+
momentum = 0.01
|
179 |
+
|
180 |
+
self.subbands_num = 4
|
181 |
+
self.K = 3 # outputs: |M|, cos∠M, sin∠M
|
182 |
+
|
183 |
+
self.downsample_ratio = 2 ** 6 # This number equals 2^{#encoder_blcoks}
|
184 |
+
|
185 |
+
self.stft = STFT(
|
186 |
+
n_fft=window_size,
|
187 |
+
hop_length=hop_size,
|
188 |
+
win_length=window_size,
|
189 |
+
window=window,
|
190 |
+
center=center,
|
191 |
+
pad_mode=pad_mode,
|
192 |
+
freeze_parameters=True,
|
193 |
+
)
|
194 |
+
|
195 |
+
self.istft = ISTFT(
|
196 |
+
n_fft=window_size,
|
197 |
+
hop_length=hop_size,
|
198 |
+
win_length=window_size,
|
199 |
+
window=window,
|
200 |
+
center=center,
|
201 |
+
pad_mode=pad_mode,
|
202 |
+
freeze_parameters=True,
|
203 |
+
)
|
204 |
+
|
205 |
+
self.bn0 = nn.BatchNorm2d(window_size // 2 + 1, momentum=momentum)
|
206 |
+
|
207 |
+
self.subband = Subband(subbands_num=self.subbands_num)
|
208 |
+
|
209 |
+
self.encoder_block1 = EncoderBlock(
|
210 |
+
in_channels=input_channels * self.subbands_num,
|
211 |
+
out_channels=32,
|
212 |
+
condition_size=condition_size,
|
213 |
+
kernel_size=(3, 3),
|
214 |
+
downsample=(2, 2),
|
215 |
+
activation=activation,
|
216 |
+
momentum=momentum,
|
217 |
+
)
|
218 |
+
self.encoder_block2 = EncoderBlock(
|
219 |
+
in_channels=32,
|
220 |
+
out_channels=64,
|
221 |
+
condition_size=condition_size,
|
222 |
+
kernel_size=(3, 3),
|
223 |
+
downsample=(2, 2),
|
224 |
+
activation=activation,
|
225 |
+
momentum=momentum,
|
226 |
+
)
|
227 |
+
self.encoder_block3 = EncoderBlock(
|
228 |
+
in_channels=64,
|
229 |
+
out_channels=128,
|
230 |
+
condition_size=condition_size,
|
231 |
+
kernel_size=(3, 3),
|
232 |
+
downsample=(2, 2),
|
233 |
+
activation=activation,
|
234 |
+
momentum=momentum,
|
235 |
+
)
|
236 |
+
self.encoder_block4 = EncoderBlock(
|
237 |
+
in_channels=128,
|
238 |
+
out_channels=256,
|
239 |
+
condition_size=condition_size,
|
240 |
+
kernel_size=(3, 3),
|
241 |
+
downsample=(2, 2),
|
242 |
+
activation=activation,
|
243 |
+
momentum=momentum,
|
244 |
+
)
|
245 |
+
self.encoder_block5 = EncoderBlock(
|
246 |
+
in_channels=256,
|
247 |
+
out_channels=384,
|
248 |
+
condition_size=condition_size,
|
249 |
+
kernel_size=(3, 3),
|
250 |
+
downsample=(2, 2),
|
251 |
+
activation=activation,
|
252 |
+
momentum=momentum,
|
253 |
+
)
|
254 |
+
self.encoder_block6 = EncoderBlock(
|
255 |
+
in_channels=384,
|
256 |
+
out_channels=384,
|
257 |
+
condition_size=condition_size,
|
258 |
+
kernel_size=(3, 3),
|
259 |
+
downsample=(2, 2),
|
260 |
+
activation=activation,
|
261 |
+
momentum=momentum,
|
262 |
+
)
|
263 |
+
self.conv_block7 = ConvBlock(
|
264 |
+
in_channels=384,
|
265 |
+
out_channels=384,
|
266 |
+
condition_size=condition_size,
|
267 |
+
kernel_size=(3, 3),
|
268 |
+
activation=activation,
|
269 |
+
momentum=momentum,
|
270 |
+
)
|
271 |
+
self.decoder_block1 = DecoderBlock(
|
272 |
+
in_channels=384,
|
273 |
+
out_channels=384,
|
274 |
+
condition_size=condition_size,
|
275 |
+
kernel_size=(3, 3),
|
276 |
+
upsample=(2, 2),
|
277 |
+
activation=activation,
|
278 |
+
momentum=momentum,
|
279 |
+
)
|
280 |
+
self.decoder_block2 = DecoderBlock(
|
281 |
+
in_channels=384,
|
282 |
+
out_channels=384,
|
283 |
+
condition_size=condition_size,
|
284 |
+
kernel_size=(3, 3),
|
285 |
+
upsample=(2, 2),
|
286 |
+
activation=activation,
|
287 |
+
momentum=momentum,
|
288 |
+
)
|
289 |
+
self.decoder_block3 = DecoderBlock(
|
290 |
+
in_channels=384,
|
291 |
+
out_channels=256,
|
292 |
+
condition_size=condition_size,
|
293 |
+
kernel_size=(3, 3),
|
294 |
+
upsample=(2, 2),
|
295 |
+
activation=activation,
|
296 |
+
momentum=momentum,
|
297 |
+
)
|
298 |
+
self.decoder_block4 = DecoderBlock(
|
299 |
+
in_channels=256,
|
300 |
+
out_channels=128,
|
301 |
+
condition_size=condition_size,
|
302 |
+
kernel_size=(3, 3),
|
303 |
+
upsample=(2, 2),
|
304 |
+
activation=activation,
|
305 |
+
momentum=momentum,
|
306 |
+
)
|
307 |
+
self.decoder_block5 = DecoderBlock(
|
308 |
+
in_channels=128,
|
309 |
+
out_channels=64,
|
310 |
+
condition_size=condition_size,
|
311 |
+
kernel_size=(3, 3),
|
312 |
+
upsample=(2, 2),
|
313 |
+
activation=activation,
|
314 |
+
momentum=momentum,
|
315 |
+
)
|
316 |
+
self.decoder_block6 = DecoderBlock(
|
317 |
+
in_channels=64,
|
318 |
+
out_channels=32,
|
319 |
+
condition_size=condition_size,
|
320 |
+
kernel_size=(3, 3),
|
321 |
+
upsample=(2, 2),
|
322 |
+
activation=activation,
|
323 |
+
momentum=momentum,
|
324 |
+
)
|
325 |
+
|
326 |
+
self.after_conv_block1 = ConvBlock(
|
327 |
+
in_channels=32,
|
328 |
+
out_channels=32,
|
329 |
+
condition_size=condition_size,
|
330 |
+
kernel_size=(3, 3),
|
331 |
+
activation=activation,
|
332 |
+
momentum=momentum,
|
333 |
+
)
|
334 |
+
|
335 |
+
self.after_conv2 = nn.Conv2d(
|
336 |
+
in_channels=32,
|
337 |
+
out_channels=input_channels
|
338 |
+
* self.subbands_num
|
339 |
+
* self.output_sources_num
|
340 |
+
* self.K,
|
341 |
+
kernel_size=(1, 1),
|
342 |
+
stride=(1, 1),
|
343 |
+
padding=(0, 0),
|
344 |
+
bias=True,
|
345 |
+
)
|
346 |
+
|
347 |
+
self.init_weights()
|
348 |
+
|
349 |
+
def init_weights(self):
|
350 |
+
init_bn(self.bn0)
|
351 |
+
init_layer(self.after_conv2)
|
352 |
+
|
353 |
+
def feature_maps_to_wav(self, x, sp, sin_in, cos_in, audio_length):
|
354 |
+
|
355 |
+
batch_size, _, time_steps, freq_bins = x.shape
|
356 |
+
|
357 |
+
x = x.reshape(
|
358 |
+
batch_size,
|
359 |
+
self.output_sources_num,
|
360 |
+
self.input_channels,
|
361 |
+
self.K,
|
362 |
+
time_steps,
|
363 |
+
freq_bins,
|
364 |
+
)
|
365 |
+
# x: (batch_size, output_sources_num, input_channles, K, time_steps, freq_bins)
|
366 |
+
|
367 |
+
mask_mag = torch.sigmoid(x[:, :, :, 0, :, :])
|
368 |
+
_mask_real = torch.tanh(x[:, :, :, 1, :, :])
|
369 |
+
_mask_imag = torch.tanh(x[:, :, :, 2, :, :])
|
370 |
+
_, mask_cos, mask_sin = magphase(_mask_real, _mask_imag)
|
371 |
+
# mask_cos, mask_sin: (batch_size, output_sources_num, input_channles, time_steps, freq_bins)
|
372 |
+
|
373 |
+
# Y = |Y|cos∠Y + j|Y|sin∠Y
|
374 |
+
# = |Y|cos(∠X + ∠M) + j|Y|sin(∠X + ∠M)
|
375 |
+
# = |Y|(cos∠X cos∠M - sin∠X sin∠M) + j|Y|(sin∠X cos∠M + cos∠X sin∠M)
|
376 |
+
out_cos = (
|
377 |
+
cos_in[:, None, :, :, :] * mask_cos - sin_in[:, None, :, :, :] * mask_sin
|
378 |
+
)
|
379 |
+
out_sin = (
|
380 |
+
sin_in[:, None, :, :, :] * mask_cos + cos_in[:, None, :, :, :] * mask_sin
|
381 |
+
)
|
382 |
+
# out_cos, out_sin: (batch_size, output_sources_num, input_channles, time_steps, freq_bins)
|
383 |
+
|
384 |
+
# Calculate |Y|.
|
385 |
+
out_mag = F.relu_(sp[:, None, :, :, :] * mask_mag)
|
386 |
+
# out_mag: (batch_size, output_sources_num, input_channles, time_steps, freq_bins)
|
387 |
+
|
388 |
+
# Calculate Y_{real} and Y_{imag} for ISTFT.
|
389 |
+
out_real = out_mag * out_cos
|
390 |
+
out_imag = out_mag * out_sin
|
391 |
+
# out_real, out_imag: (batch_size, output_sources_num, input_channles, time_steps, freq_bins)
|
392 |
+
|
393 |
+
# Reformat shape to (n, 1, time_steps, freq_bins) for ISTFT.
|
394 |
+
shape = (
|
395 |
+
batch_size * self.output_sources_num * self.input_channels,
|
396 |
+
1,
|
397 |
+
time_steps,
|
398 |
+
freq_bins,
|
399 |
+
)
|
400 |
+
out_real = out_real.reshape(shape)
|
401 |
+
out_imag = out_imag.reshape(shape)
|
402 |
+
|
403 |
+
# ISTFT.
|
404 |
+
wav_out = self.istft(out_real, out_imag, audio_length)
|
405 |
+
# (batch_size * output_sources_num * input_channels, segments_num)
|
406 |
+
|
407 |
+
# Reshape.
|
408 |
+
wav_out = wav_out.reshape(
|
409 |
+
batch_size, self.output_sources_num * self.input_channels, audio_length
|
410 |
+
)
|
411 |
+
# (batch_size, output_sources_num * input_channels, segments_num)
|
412 |
+
|
413 |
+
return wav_out
|
414 |
+
|
415 |
+
def forward(self, input_dict):
|
416 |
+
"""
|
417 |
+
Args:
|
418 |
+
input: (batch_size, segment_samples, channels_num)
|
419 |
+
|
420 |
+
Outputs:
|
421 |
+
output_dict: {
|
422 |
+
'wav': (batch_size, segment_samples, channels_num),
|
423 |
+
'sp': (batch_size, channels_num, time_steps, freq_bins)}
|
424 |
+
"""
|
425 |
+
|
426 |
+
mixture = input_dict['waveform']
|
427 |
+
condition = input_dict['condition']
|
428 |
+
|
429 |
+
sp, cos_in, sin_in = self.wav_to_spectrogram_phase(mixture)
|
430 |
+
"""(batch_size, channels_num, time_steps, freq_bins)"""
|
431 |
+
|
432 |
+
# Batch normalization
|
433 |
+
x = sp.transpose(1, 3)
|
434 |
+
x = self.bn0(x)
|
435 |
+
x = x.transpose(1, 3)
|
436 |
+
"""(batch_size, chanenls, time_steps, freq_bins)"""
|
437 |
+
|
438 |
+
# Pad spectrogram to be evenly divided by downsample ratio.
|
439 |
+
origin_len = x.shape[2]
|
440 |
+
pad_len = (
|
441 |
+
int(np.ceil(x.shape[2] / self.downsample_ratio)) * self.downsample_ratio
|
442 |
+
- origin_len
|
443 |
+
)
|
444 |
+
x = F.pad(x, pad=(0, 0, 0, pad_len))
|
445 |
+
"""(batch_size, channels, padded_time_steps, freq_bins)"""
|
446 |
+
|
447 |
+
# Let frequency bins be evenly divided by 2, e.g., 513 -> 512
|
448 |
+
x = x[..., 0 : x.shape[-1] - 1] # (bs, channels, T, F)
|
449 |
+
|
450 |
+
x = self.subband.analysis(x)
|
451 |
+
|
452 |
+
# UNet
|
453 |
+
(x1_pool, x1) = self.encoder_block1(
|
454 |
+
x, condition
|
455 |
+
) # x1_pool: (bs, 32, T / 2, F / 2)
|
456 |
+
(x2_pool, x2) = self.encoder_block2(
|
457 |
+
x1_pool, condition
|
458 |
+
) # x2_pool: (bs, 64, T / 4, F / 4)
|
459 |
+
(x3_pool, x3) = self.encoder_block3(
|
460 |
+
x2_pool, condition
|
461 |
+
) # x3_pool: (bs, 128, T / 8, F / 8)
|
462 |
+
(x4_pool, x4) = self.encoder_block4(
|
463 |
+
x3_pool, condition
|
464 |
+
) # x4_pool: (bs, 256, T / 16, F / 16)
|
465 |
+
(x5_pool, x5) = self.encoder_block5(
|
466 |
+
x4_pool, condition
|
467 |
+
) # x5_pool: (bs, 512, T / 32, F / 32)
|
468 |
+
(x6_pool, x6) = self.encoder_block6(
|
469 |
+
x5_pool, condition
|
470 |
+
) # x6_pool: (bs, 1024, T / 64, F / 64)
|
471 |
+
x_center = self.conv_block7(x6_pool, condition) # (bs, 2048, T / 64, F / 64)
|
472 |
+
x7 = self.decoder_block1(x_center, x6, condition) # (bs, 1024, T / 32, F / 32)
|
473 |
+
x8 = self.decoder_block2(x7, x5, condition) # (bs, 512, T / 16, F / 16)
|
474 |
+
x9 = self.decoder_block3(x8, x4, condition) # (bs, 256, T / 8, F / 8)
|
475 |
+
x10 = self.decoder_block4(x9, x3, condition) # (bs, 128, T / 4, F / 4)
|
476 |
+
x11 = self.decoder_block5(x10, x2, condition) # (bs, 64, T / 2, F / 2)
|
477 |
+
x12 = self.decoder_block6(x11, x1, condition) # (bs, 32, T, F)
|
478 |
+
x = self.after_conv_block1(x12, condition) # (bs, 32, T, F)
|
479 |
+
x = self.after_conv2(x)
|
480 |
+
# (batch_size, input_channles * subbands_num * targets_num * k, T, F // subbands_num)
|
481 |
+
|
482 |
+
x = self.subband.synthesis(x)
|
483 |
+
# (batch_size, input_channles * targets_num * K, T, F)
|
484 |
+
|
485 |
+
# Recover shape
|
486 |
+
x = F.pad(x, pad=(0, 1)) # Pad frequency, e.g., 1024 -> 1025.
|
487 |
+
x = x[:, :, 0:origin_len, :] # (bs, feature_maps, T, F)
|
488 |
+
|
489 |
+
audio_length = mixture.shape[2]
|
490 |
+
|
491 |
+
separated_audio = self.feature_maps_to_wav(x, sp, sin_in, cos_in, audio_length)
|
492 |
+
# separated_audio: (batch_size, output_sources_num * input_channels, segments_num)
|
493 |
+
|
494 |
+
output_dict = {'waveform': separated_audio}
|
495 |
+
|
496 |
+
return output_dict
|
bytesep/models/lightning_modules.py
ADDED
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Any, Callable, Dict
|
2 |
+
|
3 |
+
import pytorch_lightning as pl
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.optim as optim
|
7 |
+
from torch.optim.lr_scheduler import LambdaLR
|
8 |
+
|
9 |
+
|
10 |
+
class LitSourceSeparation(pl.LightningModule):
|
11 |
+
def __init__(
|
12 |
+
self,
|
13 |
+
batch_data_preprocessor,
|
14 |
+
model: nn.Module,
|
15 |
+
loss_function: Callable,
|
16 |
+
optimizer_type: str,
|
17 |
+
learning_rate: float,
|
18 |
+
lr_lambda: Callable,
|
19 |
+
):
|
20 |
+
r"""Pytorch Lightning wrapper of PyTorch model, including forward,
|
21 |
+
optimization of model, etc.
|
22 |
+
|
23 |
+
Args:
|
24 |
+
batch_data_preprocessor: object, used for preparing inputs and
|
25 |
+
targets for training. E.g., BasicBatchDataPreprocessor is used
|
26 |
+
for preparing data in dictionary into tensor.
|
27 |
+
model: nn.Module
|
28 |
+
loss_function: function
|
29 |
+
learning_rate: float
|
30 |
+
lr_lambda: function
|
31 |
+
"""
|
32 |
+
super().__init__()
|
33 |
+
|
34 |
+
self.batch_data_preprocessor = batch_data_preprocessor
|
35 |
+
self.model = model
|
36 |
+
self.optimizer_type = optimizer_type
|
37 |
+
self.loss_function = loss_function
|
38 |
+
self.learning_rate = learning_rate
|
39 |
+
self.lr_lambda = lr_lambda
|
40 |
+
|
41 |
+
def training_step(self, batch_data_dict: Dict, batch_idx: int) -> torch.float:
|
42 |
+
r"""Forward a mini-batch data to model, calculate loss function, and
|
43 |
+
train for one step. A mini-batch data is evenly distributed to multiple
|
44 |
+
devices (if there are) for parallel training.
|
45 |
+
|
46 |
+
Args:
|
47 |
+
batch_data_dict: e.g. {
|
48 |
+
'vocals': (batch_size, channels_num, segment_samples),
|
49 |
+
'accompaniment': (batch_size, channels_num, segment_samples),
|
50 |
+
'mixture': (batch_size, channels_num, segment_samples)
|
51 |
+
}
|
52 |
+
batch_idx: int
|
53 |
+
|
54 |
+
Returns:
|
55 |
+
loss: float, loss function of this mini-batch
|
56 |
+
"""
|
57 |
+
input_dict, target_dict = self.batch_data_preprocessor(batch_data_dict)
|
58 |
+
# input_dict: {
|
59 |
+
# 'waveform': (batch_size, channels_num, segment_samples),
|
60 |
+
# (if_exist) 'condition': (batch_size, channels_num),
|
61 |
+
# }
|
62 |
+
# target_dict: {
|
63 |
+
# 'waveform': (batch_size, target_sources_num * channels_num, segment_samples),
|
64 |
+
# }
|
65 |
+
|
66 |
+
# Forward.
|
67 |
+
self.model.train()
|
68 |
+
|
69 |
+
output_dict = self.model(input_dict)
|
70 |
+
# output_dict: {
|
71 |
+
# 'waveform': (batch_size, target_sources_num * channels_num, segment_samples),
|
72 |
+
# }
|
73 |
+
|
74 |
+
outputs = output_dict['waveform']
|
75 |
+
# outputs:, e.g, (batch_size, target_sources_num * channels_num, segment_samples)
|
76 |
+
|
77 |
+
# Calculate loss.
|
78 |
+
loss = self.loss_function(
|
79 |
+
output=outputs,
|
80 |
+
target=target_dict['waveform'],
|
81 |
+
mixture=input_dict['waveform'],
|
82 |
+
)
|
83 |
+
|
84 |
+
return loss
|
85 |
+
|
86 |
+
def configure_optimizers(self) -> Any:
|
87 |
+
r"""Configure optimizer."""
|
88 |
+
|
89 |
+
if self.optimizer_type == "Adam":
|
90 |
+
optimizer = optim.Adam(
|
91 |
+
self.model.parameters(),
|
92 |
+
lr=self.learning_rate,
|
93 |
+
betas=(0.9, 0.999),
|
94 |
+
eps=1e-08,
|
95 |
+
weight_decay=0.0,
|
96 |
+
amsgrad=True,
|
97 |
+
)
|
98 |
+
|
99 |
+
elif self.optimizer_type == "AdamW":
|
100 |
+
optimizer = optim.AdamW(
|
101 |
+
self.model.parameters(),
|
102 |
+
lr=self.learning_rate,
|
103 |
+
betas=(0.9, 0.999),
|
104 |
+
eps=1e-08,
|
105 |
+
weight_decay=0.0,
|
106 |
+
amsgrad=True,
|
107 |
+
)
|
108 |
+
|
109 |
+
else:
|
110 |
+
raise NotImplementedError
|
111 |
+
|
112 |
+
scheduler = {
|
113 |
+
'scheduler': LambdaLR(optimizer, self.lr_lambda),
|
114 |
+
'interval': 'step',
|
115 |
+
'frequency': 1,
|
116 |
+
}
|
117 |
+
|
118 |
+
return [optimizer], [scheduler]
|
119 |
+
|
120 |
+
|
121 |
+
def get_model_class(model_type):
|
122 |
+
r"""Get model.
|
123 |
+
|
124 |
+
Args:
|
125 |
+
model_type: str, e.g., 'ResUNet143_DecouplePlusInplaceABN'
|
126 |
+
|
127 |
+
Returns:
|
128 |
+
nn.Module
|
129 |
+
"""
|
130 |
+
if model_type == 'ResUNet143_DecouplePlusInplaceABN_ISMIR2021':
|
131 |
+
from bytesep.models.resunet_ismir2021 import (
|
132 |
+
ResUNet143_DecouplePlusInplaceABN_ISMIR2021,
|
133 |
+
)
|
134 |
+
|
135 |
+
return ResUNet143_DecouplePlusInplaceABN_ISMIR2021
|
136 |
+
|
137 |
+
elif model_type == 'UNet':
|
138 |
+
from bytesep.models.unet import UNet
|
139 |
+
|
140 |
+
return UNet
|
141 |
+
|
142 |
+
elif model_type == 'UNetSubbandTime':
|
143 |
+
from bytesep.models.unet_subbandtime import UNetSubbandTime
|
144 |
+
|
145 |
+
return UNetSubbandTime
|
146 |
+
|
147 |
+
elif model_type == 'ResUNet143_Subbandtime':
|
148 |
+
from bytesep.models.resunet_subbandtime import ResUNet143_Subbandtime
|
149 |
+
|
150 |
+
return ResUNet143_Subbandtime
|
151 |
+
|
152 |
+
elif model_type == 'ResUNet143_DecouplePlus':
|
153 |
+
from bytesep.models.resunet import ResUNet143_DecouplePlus
|
154 |
+
|
155 |
+
return ResUNet143_DecouplePlus
|
156 |
+
|
157 |
+
elif model_type == 'ConditionalUNet':
|
158 |
+
from bytesep.models.conditional_unet import ConditionalUNet
|
159 |
+
|
160 |
+
return ConditionalUNet
|
161 |
+
|
162 |
+
elif model_type == 'LevelRNN':
|
163 |
+
from bytesep.models.levelrnn import LevelRNN
|
164 |
+
|
165 |
+
return LevelRNN
|
166 |
+
|
167 |
+
elif model_type == 'WavUNet':
|
168 |
+
from bytesep.models.wavunet import WavUNet
|
169 |
+
|
170 |
+
return WavUNet
|
171 |
+
|
172 |
+
elif model_type == 'WavUNetLevelRNN':
|
173 |
+
from bytesep.models.wavunet_levelrnn import WavUNetLevelRNN
|
174 |
+
|
175 |
+
return WavUNetLevelRNN
|
176 |
+
|
177 |
+
elif model_type == 'TTnet':
|
178 |
+
from bytesep.models.ttnet import TTnet
|
179 |
+
|
180 |
+
return TTnet
|
181 |
+
|
182 |
+
elif model_type == 'TTnetNoTransformer':
|
183 |
+
from bytesep.models.ttnet_no_transformer import TTnetNoTransformer
|
184 |
+
|
185 |
+
return TTnetNoTransformer
|
186 |
+
|
187 |
+
else:
|
188 |
+
raise NotImplementedError
|
bytesep/models/pytorch_modules.py
ADDED
@@ -0,0 +1,204 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, NoReturn
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
|
8 |
+
|
9 |
+
def init_embedding(layer: nn.Module) -> NoReturn:
|
10 |
+
r"""Initialize a Linear or Convolutional layer."""
|
11 |
+
nn.init.uniform_(layer.weight, -1.0, 1.0)
|
12 |
+
|
13 |
+
if hasattr(layer, 'bias'):
|
14 |
+
if layer.bias is not None:
|
15 |
+
layer.bias.data.fill_(0.0)
|
16 |
+
|
17 |
+
|
18 |
+
def init_layer(layer: nn.Module) -> NoReturn:
|
19 |
+
r"""Initialize a Linear or Convolutional layer."""
|
20 |
+
nn.init.xavier_uniform_(layer.weight)
|
21 |
+
|
22 |
+
if hasattr(layer, "bias"):
|
23 |
+
if layer.bias is not None:
|
24 |
+
layer.bias.data.fill_(0.0)
|
25 |
+
|
26 |
+
|
27 |
+
def init_bn(bn: nn.Module) -> NoReturn:
|
28 |
+
r"""Initialize a Batchnorm layer."""
|
29 |
+
bn.bias.data.fill_(0.0)
|
30 |
+
bn.weight.data.fill_(1.0)
|
31 |
+
bn.running_mean.data.fill_(0.0)
|
32 |
+
bn.running_var.data.fill_(1.0)
|
33 |
+
|
34 |
+
|
35 |
+
def act(x: torch.Tensor, activation: str) -> torch.Tensor:
|
36 |
+
|
37 |
+
if activation == "relu":
|
38 |
+
return F.relu_(x)
|
39 |
+
|
40 |
+
elif activation == "leaky_relu":
|
41 |
+
return F.leaky_relu_(x, negative_slope=0.01)
|
42 |
+
|
43 |
+
elif activation == "swish":
|
44 |
+
return x * torch.sigmoid(x)
|
45 |
+
|
46 |
+
else:
|
47 |
+
raise Exception("Incorrect activation!")
|
48 |
+
|
49 |
+
|
50 |
+
class Base:
|
51 |
+
def __init__(self):
|
52 |
+
r"""Base function for extracting spectrogram, cos, and sin, etc."""
|
53 |
+
pass
|
54 |
+
|
55 |
+
def spectrogram(self, input: torch.Tensor, eps: float = 0.0) -> torch.Tensor:
|
56 |
+
r"""Calculate spectrogram.
|
57 |
+
|
58 |
+
Args:
|
59 |
+
input: (batch_size, segments_num)
|
60 |
+
eps: float
|
61 |
+
|
62 |
+
Returns:
|
63 |
+
spectrogram: (batch_size, time_steps, freq_bins)
|
64 |
+
"""
|
65 |
+
(real, imag) = self.stft(input)
|
66 |
+
return torch.clamp(real ** 2 + imag ** 2, eps, np.inf) ** 0.5
|
67 |
+
|
68 |
+
def spectrogram_phase(
|
69 |
+
self, input: torch.Tensor, eps: float = 0.0
|
70 |
+
) -> List[torch.Tensor]:
|
71 |
+
r"""Calculate the magnitude, cos, and sin of the STFT of input.
|
72 |
+
|
73 |
+
Args:
|
74 |
+
input: (batch_size, segments_num)
|
75 |
+
eps: float
|
76 |
+
|
77 |
+
Returns:
|
78 |
+
mag: (batch_size, time_steps, freq_bins)
|
79 |
+
cos: (batch_size, time_steps, freq_bins)
|
80 |
+
sin: (batch_size, time_steps, freq_bins)
|
81 |
+
"""
|
82 |
+
(real, imag) = self.stft(input)
|
83 |
+
mag = torch.clamp(real ** 2 + imag ** 2, eps, np.inf) ** 0.5
|
84 |
+
cos = real / mag
|
85 |
+
sin = imag / mag
|
86 |
+
return mag, cos, sin
|
87 |
+
|
88 |
+
def wav_to_spectrogram_phase(
|
89 |
+
self, input: torch.Tensor, eps: float = 1e-10
|
90 |
+
) -> List[torch.Tensor]:
|
91 |
+
r"""Convert waveforms to magnitude, cos, and sin of STFT.
|
92 |
+
|
93 |
+
Args:
|
94 |
+
input: (batch_size, channels_num, segment_samples)
|
95 |
+
eps: float
|
96 |
+
|
97 |
+
Outputs:
|
98 |
+
mag: (batch_size, channels_num, time_steps, freq_bins)
|
99 |
+
cos: (batch_size, channels_num, time_steps, freq_bins)
|
100 |
+
sin: (batch_size, channels_num, time_steps, freq_bins)
|
101 |
+
"""
|
102 |
+
batch_size, channels_num, segment_samples = input.shape
|
103 |
+
|
104 |
+
# Reshape input with shapes of (n, segments_num) to meet the
|
105 |
+
# requirements of the stft function.
|
106 |
+
x = input.reshape(batch_size * channels_num, segment_samples)
|
107 |
+
|
108 |
+
mag, cos, sin = self.spectrogram_phase(x, eps=eps)
|
109 |
+
# mag, cos, sin: (batch_size * channels_num, 1, time_steps, freq_bins)
|
110 |
+
|
111 |
+
_, _, time_steps, freq_bins = mag.shape
|
112 |
+
mag = mag.reshape(batch_size, channels_num, time_steps, freq_bins)
|
113 |
+
cos = cos.reshape(batch_size, channels_num, time_steps, freq_bins)
|
114 |
+
sin = sin.reshape(batch_size, channels_num, time_steps, freq_bins)
|
115 |
+
|
116 |
+
return mag, cos, sin
|
117 |
+
|
118 |
+
def wav_to_spectrogram(
|
119 |
+
self, input: torch.Tensor, eps: float = 1e-10
|
120 |
+
) -> List[torch.Tensor]:
|
121 |
+
|
122 |
+
mag, cos, sin = self.wav_to_spectrogram_phase(input, eps)
|
123 |
+
return mag
|
124 |
+
|
125 |
+
|
126 |
+
class Subband:
|
127 |
+
def __init__(self, subbands_num: int):
|
128 |
+
r"""Warning!! This class is not used!!
|
129 |
+
|
130 |
+
This class does not work as good as [1] which split subbands in the
|
131 |
+
time-domain. Please refere to [1] for formal implementation.
|
132 |
+
|
133 |
+
[1] Liu, Haohe, et al. "Channel-wise subband input for better voice and
|
134 |
+
accompaniment separation on high resolution music." arXiv preprint arXiv:2008.05216 (2020).
|
135 |
+
|
136 |
+
Args:
|
137 |
+
subbands_num: int, e.g., 4
|
138 |
+
"""
|
139 |
+
self.subbands_num = subbands_num
|
140 |
+
|
141 |
+
def analysis(self, x: torch.Tensor) -> torch.Tensor:
|
142 |
+
r"""Analysis time-frequency representation into subbands. Stack the
|
143 |
+
subbands along the channel axis.
|
144 |
+
|
145 |
+
Args:
|
146 |
+
x: (batch_size, channels_num, time_steps, freq_bins)
|
147 |
+
|
148 |
+
Returns:
|
149 |
+
output: (batch_size, channels_num * subbands_num, time_steps, freq_bins // subbands_num)
|
150 |
+
"""
|
151 |
+
batch_size, channels_num, time_steps, freq_bins = x.shape
|
152 |
+
|
153 |
+
x = x.reshape(
|
154 |
+
batch_size,
|
155 |
+
channels_num,
|
156 |
+
time_steps,
|
157 |
+
self.subbands_num,
|
158 |
+
freq_bins // self.subbands_num,
|
159 |
+
)
|
160 |
+
# x: (batch_size, channels_num, time_steps, subbands_num, freq_bins // subbands_num)
|
161 |
+
|
162 |
+
x = x.transpose(2, 3)
|
163 |
+
|
164 |
+
output = x.reshape(
|
165 |
+
batch_size,
|
166 |
+
channels_num * self.subbands_num,
|
167 |
+
time_steps,
|
168 |
+
freq_bins // self.subbands_num,
|
169 |
+
)
|
170 |
+
# output: (batch_size, channels_num * subbands_num, time_steps, freq_bins // subbands_num)
|
171 |
+
|
172 |
+
return output
|
173 |
+
|
174 |
+
def synthesis(self, x: torch.Tensor) -> torch.Tensor:
|
175 |
+
r"""Synthesis subband time-frequency representations into original
|
176 |
+
time-frequency representation.
|
177 |
+
|
178 |
+
Args:
|
179 |
+
x: (batch_size, channels_num * subbands_num, time_steps, freq_bins // subbands_num)
|
180 |
+
|
181 |
+
Returns:
|
182 |
+
output: (batch_size, channels_num, time_steps, freq_bins)
|
183 |
+
"""
|
184 |
+
batch_size, subband_channels_num, time_steps, subband_freq_bins = x.shape
|
185 |
+
|
186 |
+
channels_num = subband_channels_num // self.subbands_num
|
187 |
+
freq_bins = subband_freq_bins * self.subbands_num
|
188 |
+
|
189 |
+
x = x.reshape(
|
190 |
+
batch_size,
|
191 |
+
channels_num,
|
192 |
+
self.subbands_num,
|
193 |
+
time_steps,
|
194 |
+
subband_freq_bins,
|
195 |
+
)
|
196 |
+
# x: (batch_size, channels_num, subbands_num, time_steps, freq_bins // subbands_num)
|
197 |
+
|
198 |
+
x = x.transpose(2, 3)
|
199 |
+
# x: (batch_size, channels_num, time_steps, subbands_num, freq_bins // subbands_num)
|
200 |
+
|
201 |
+
output = x.reshape(batch_size, channels_num, time_steps, freq_bins)
|
202 |
+
# x: (batch_size, channels_num, time_steps, freq_bins)
|
203 |
+
|
204 |
+
return output
|
bytesep/models/resunet.py
ADDED
@@ -0,0 +1,516 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from torchlibrosa.stft import ISTFT, STFT, magphase
|
6 |
+
|
7 |
+
from bytesep.models.pytorch_modules import Base, Subband, act, init_bn, init_layer
|
8 |
+
|
9 |
+
|
10 |
+
class ConvBlockRes(nn.Module):
|
11 |
+
def __init__(self, in_channels, out_channels, kernel_size, activation, momentum):
|
12 |
+
r"""Residual block."""
|
13 |
+
super(ConvBlockRes, self).__init__()
|
14 |
+
|
15 |
+
self.activation = activation
|
16 |
+
padding = [kernel_size[0] // 2, kernel_size[1] // 2]
|
17 |
+
|
18 |
+
self.bn1 = nn.BatchNorm2d(in_channels, momentum=momentum)
|
19 |
+
self.bn2 = nn.BatchNorm2d(out_channels, momentum=momentum)
|
20 |
+
|
21 |
+
self.conv1 = nn.Conv2d(
|
22 |
+
in_channels=in_channels,
|
23 |
+
out_channels=out_channels,
|
24 |
+
kernel_size=kernel_size,
|
25 |
+
stride=(1, 1),
|
26 |
+
dilation=(1, 1),
|
27 |
+
padding=padding,
|
28 |
+
bias=False,
|
29 |
+
)
|
30 |
+
|
31 |
+
self.conv2 = nn.Conv2d(
|
32 |
+
in_channels=out_channels,
|
33 |
+
out_channels=out_channels,
|
34 |
+
kernel_size=kernel_size,
|
35 |
+
stride=(1, 1),
|
36 |
+
dilation=(1, 1),
|
37 |
+
padding=padding,
|
38 |
+
bias=False,
|
39 |
+
)
|
40 |
+
|
41 |
+
if in_channels != out_channels:
|
42 |
+
self.shortcut = nn.Conv2d(
|
43 |
+
in_channels=in_channels,
|
44 |
+
out_channels=out_channels,
|
45 |
+
kernel_size=(1, 1),
|
46 |
+
stride=(1, 1),
|
47 |
+
padding=(0, 0),
|
48 |
+
)
|
49 |
+
|
50 |
+
self.is_shortcut = True
|
51 |
+
else:
|
52 |
+
self.is_shortcut = False
|
53 |
+
|
54 |
+
self.init_weights()
|
55 |
+
|
56 |
+
def init_weights(self):
|
57 |
+
init_bn(self.bn1)
|
58 |
+
init_bn(self.bn2)
|
59 |
+
init_layer(self.conv1)
|
60 |
+
init_layer(self.conv2)
|
61 |
+
|
62 |
+
if self.is_shortcut:
|
63 |
+
init_layer(self.shortcut)
|
64 |
+
|
65 |
+
def forward(self, x):
|
66 |
+
origin = x
|
67 |
+
x = self.conv1(act(self.bn1(x), self.activation))
|
68 |
+
x = self.conv2(act(self.bn2(x), self.activation))
|
69 |
+
|
70 |
+
if self.is_shortcut:
|
71 |
+
return self.shortcut(origin) + x
|
72 |
+
else:
|
73 |
+
return origin + x
|
74 |
+
|
75 |
+
|
76 |
+
class EncoderBlockRes4B(nn.Module):
|
77 |
+
def __init__(
|
78 |
+
self, in_channels, out_channels, kernel_size, downsample, activation, momentum
|
79 |
+
):
|
80 |
+
r"""Encoder block, contains 8 convolutional layers."""
|
81 |
+
super(EncoderBlockRes4B, self).__init__()
|
82 |
+
|
83 |
+
self.conv_block1 = ConvBlockRes(
|
84 |
+
in_channels, out_channels, kernel_size, activation, momentum
|
85 |
+
)
|
86 |
+
self.conv_block2 = ConvBlockRes(
|
87 |
+
out_channels, out_channels, kernel_size, activation, momentum
|
88 |
+
)
|
89 |
+
self.conv_block3 = ConvBlockRes(
|
90 |
+
out_channels, out_channels, kernel_size, activation, momentum
|
91 |
+
)
|
92 |
+
self.conv_block4 = ConvBlockRes(
|
93 |
+
out_channels, out_channels, kernel_size, activation, momentum
|
94 |
+
)
|
95 |
+
self.downsample = downsample
|
96 |
+
|
97 |
+
def forward(self, x):
|
98 |
+
encoder = self.conv_block1(x)
|
99 |
+
encoder = self.conv_block2(encoder)
|
100 |
+
encoder = self.conv_block3(encoder)
|
101 |
+
encoder = self.conv_block4(encoder)
|
102 |
+
encoder_pool = F.avg_pool2d(encoder, kernel_size=self.downsample)
|
103 |
+
return encoder_pool, encoder
|
104 |
+
|
105 |
+
|
106 |
+
class DecoderBlockRes4B(nn.Module):
|
107 |
+
def __init__(
|
108 |
+
self, in_channels, out_channels, kernel_size, upsample, activation, momentum
|
109 |
+
):
|
110 |
+
r"""Decoder block, contains 1 transpose convolutional and 8 convolutional layers."""
|
111 |
+
super(DecoderBlockRes4B, self).__init__()
|
112 |
+
self.kernel_size = kernel_size
|
113 |
+
self.stride = upsample
|
114 |
+
self.activation = activation
|
115 |
+
|
116 |
+
self.conv1 = torch.nn.ConvTranspose2d(
|
117 |
+
in_channels=in_channels,
|
118 |
+
out_channels=out_channels,
|
119 |
+
kernel_size=self.stride,
|
120 |
+
stride=self.stride,
|
121 |
+
padding=(0, 0),
|
122 |
+
bias=False,
|
123 |
+
dilation=(1, 1),
|
124 |
+
)
|
125 |
+
|
126 |
+
self.bn1 = nn.BatchNorm2d(in_channels, momentum=momentum)
|
127 |
+
self.conv_block2 = ConvBlockRes(
|
128 |
+
out_channels * 2, out_channels, kernel_size, activation, momentum
|
129 |
+
)
|
130 |
+
self.conv_block3 = ConvBlockRes(
|
131 |
+
out_channels, out_channels, kernel_size, activation, momentum
|
132 |
+
)
|
133 |
+
self.conv_block4 = ConvBlockRes(
|
134 |
+
out_channels, out_channels, kernel_size, activation, momentum
|
135 |
+
)
|
136 |
+
self.conv_block5 = ConvBlockRes(
|
137 |
+
out_channels, out_channels, kernel_size, activation, momentum
|
138 |
+
)
|
139 |
+
|
140 |
+
self.init_weights()
|
141 |
+
|
142 |
+
def init_weights(self):
|
143 |
+
init_bn(self.bn1)
|
144 |
+
init_layer(self.conv1)
|
145 |
+
|
146 |
+
def forward(self, input_tensor, concat_tensor):
|
147 |
+
x = self.conv1(act(self.bn1(input_tensor), self.activation))
|
148 |
+
x = torch.cat((x, concat_tensor), dim=1)
|
149 |
+
x = self.conv_block2(x)
|
150 |
+
x = self.conv_block3(x)
|
151 |
+
x = self.conv_block4(x)
|
152 |
+
x = self.conv_block5(x)
|
153 |
+
return x
|
154 |
+
|
155 |
+
|
156 |
+
class ResUNet143_DecouplePlus(nn.Module, Base):
|
157 |
+
def __init__(self, input_channels, target_sources_num):
|
158 |
+
super(ResUNet143_DecouplePlus, self).__init__()
|
159 |
+
|
160 |
+
self.input_channels = input_channels
|
161 |
+
self.target_sources_num = target_sources_num
|
162 |
+
|
163 |
+
window_size = 2048
|
164 |
+
hop_size = 441
|
165 |
+
center = True
|
166 |
+
pad_mode = "reflect"
|
167 |
+
window = "hann"
|
168 |
+
activation = "relu"
|
169 |
+
momentum = 0.01
|
170 |
+
|
171 |
+
self.subbands_num = 4
|
172 |
+
self.K = 4 # outputs: |M|, cos∠M, sin∠M, |M2|
|
173 |
+
|
174 |
+
self.downsample_ratio = 2 ** 6 # This number equals 2^{#encoder_blcoks}
|
175 |
+
|
176 |
+
self.stft = STFT(
|
177 |
+
n_fft=window_size,
|
178 |
+
hop_length=hop_size,
|
179 |
+
win_length=window_size,
|
180 |
+
window=window,
|
181 |
+
center=center,
|
182 |
+
pad_mode=pad_mode,
|
183 |
+
freeze_parameters=True,
|
184 |
+
)
|
185 |
+
|
186 |
+
self.istft = ISTFT(
|
187 |
+
n_fft=window_size,
|
188 |
+
hop_length=hop_size,
|
189 |
+
win_length=window_size,
|
190 |
+
window=window,
|
191 |
+
center=center,
|
192 |
+
pad_mode=pad_mode,
|
193 |
+
freeze_parameters=True,
|
194 |
+
)
|
195 |
+
|
196 |
+
self.bn0 = nn.BatchNorm2d(window_size // 2 + 1, momentum=momentum)
|
197 |
+
|
198 |
+
self.subband = Subband(subbands_num=self.subbands_num)
|
199 |
+
|
200 |
+
self.encoder_block1 = EncoderBlockRes4B(
|
201 |
+
in_channels=input_channels * self.subbands_num,
|
202 |
+
out_channels=32,
|
203 |
+
kernel_size=(3, 3),
|
204 |
+
downsample=(2, 2),
|
205 |
+
activation=activation,
|
206 |
+
momentum=momentum,
|
207 |
+
)
|
208 |
+
self.encoder_block2 = EncoderBlockRes4B(
|
209 |
+
in_channels=32,
|
210 |
+
out_channels=64,
|
211 |
+
kernel_size=(3, 3),
|
212 |
+
downsample=(2, 2),
|
213 |
+
activation=activation,
|
214 |
+
momentum=momentum,
|
215 |
+
)
|
216 |
+
self.encoder_block3 = EncoderBlockRes4B(
|
217 |
+
in_channels=64,
|
218 |
+
out_channels=128,
|
219 |
+
kernel_size=(3, 3),
|
220 |
+
downsample=(2, 2),
|
221 |
+
activation=activation,
|
222 |
+
momentum=momentum,
|
223 |
+
)
|
224 |
+
self.encoder_block4 = EncoderBlockRes4B(
|
225 |
+
in_channels=128,
|
226 |
+
out_channels=256,
|
227 |
+
kernel_size=(3, 3),
|
228 |
+
downsample=(2, 2),
|
229 |
+
activation=activation,
|
230 |
+
momentum=momentum,
|
231 |
+
)
|
232 |
+
self.encoder_block5 = EncoderBlockRes4B(
|
233 |
+
in_channels=256,
|
234 |
+
out_channels=384,
|
235 |
+
kernel_size=(3, 3),
|
236 |
+
downsample=(2, 2),
|
237 |
+
activation=activation,
|
238 |
+
momentum=momentum,
|
239 |
+
)
|
240 |
+
self.encoder_block6 = EncoderBlockRes4B(
|
241 |
+
in_channels=384,
|
242 |
+
out_channels=384,
|
243 |
+
kernel_size=(3, 3),
|
244 |
+
downsample=(1, 2),
|
245 |
+
activation=activation,
|
246 |
+
momentum=momentum,
|
247 |
+
)
|
248 |
+
self.conv_block7a = EncoderBlockRes4B(
|
249 |
+
in_channels=384,
|
250 |
+
out_channels=384,
|
251 |
+
kernel_size=(3, 3),
|
252 |
+
downsample=(1, 1),
|
253 |
+
activation=activation,
|
254 |
+
momentum=momentum,
|
255 |
+
)
|
256 |
+
self.conv_block7b = EncoderBlockRes4B(
|
257 |
+
in_channels=384,
|
258 |
+
out_channels=384,
|
259 |
+
kernel_size=(3, 3),
|
260 |
+
downsample=(1, 1),
|
261 |
+
activation=activation,
|
262 |
+
momentum=momentum,
|
263 |
+
)
|
264 |
+
self.conv_block7c = EncoderBlockRes4B(
|
265 |
+
in_channels=384,
|
266 |
+
out_channels=384,
|
267 |
+
kernel_size=(3, 3),
|
268 |
+
downsample=(1, 1),
|
269 |
+
activation=activation,
|
270 |
+
momentum=momentum,
|
271 |
+
)
|
272 |
+
self.conv_block7d = EncoderBlockRes4B(
|
273 |
+
in_channels=384,
|
274 |
+
out_channels=384,
|
275 |
+
kernel_size=(3, 3),
|
276 |
+
downsample=(1, 1),
|
277 |
+
activation=activation,
|
278 |
+
momentum=momentum,
|
279 |
+
)
|
280 |
+
self.decoder_block1 = DecoderBlockRes4B(
|
281 |
+
in_channels=384,
|
282 |
+
out_channels=384,
|
283 |
+
kernel_size=(3, 3),
|
284 |
+
upsample=(1, 2),
|
285 |
+
activation=activation,
|
286 |
+
momentum=momentum,
|
287 |
+
)
|
288 |
+
self.decoder_block2 = DecoderBlockRes4B(
|
289 |
+
in_channels=384,
|
290 |
+
out_channels=384,
|
291 |
+
kernel_size=(3, 3),
|
292 |
+
upsample=(2, 2),
|
293 |
+
activation=activation,
|
294 |
+
momentum=momentum,
|
295 |
+
)
|
296 |
+
self.decoder_block3 = DecoderBlockRes4B(
|
297 |
+
in_channels=384,
|
298 |
+
out_channels=256,
|
299 |
+
kernel_size=(3, 3),
|
300 |
+
upsample=(2, 2),
|
301 |
+
activation=activation,
|
302 |
+
momentum=momentum,
|
303 |
+
)
|
304 |
+
self.decoder_block4 = DecoderBlockRes4B(
|
305 |
+
in_channels=256,
|
306 |
+
out_channels=128,
|
307 |
+
kernel_size=(3, 3),
|
308 |
+
upsample=(2, 2),
|
309 |
+
activation=activation,
|
310 |
+
momentum=momentum,
|
311 |
+
)
|
312 |
+
self.decoder_block5 = DecoderBlockRes4B(
|
313 |
+
in_channels=128,
|
314 |
+
out_channels=64,
|
315 |
+
kernel_size=(3, 3),
|
316 |
+
upsample=(2, 2),
|
317 |
+
activation=activation,
|
318 |
+
momentum=momentum,
|
319 |
+
)
|
320 |
+
self.decoder_block6 = DecoderBlockRes4B(
|
321 |
+
in_channels=64,
|
322 |
+
out_channels=32,
|
323 |
+
kernel_size=(3, 3),
|
324 |
+
upsample=(2, 2),
|
325 |
+
activation=activation,
|
326 |
+
momentum=momentum,
|
327 |
+
)
|
328 |
+
|
329 |
+
self.after_conv_block1 = EncoderBlockRes4B(
|
330 |
+
in_channels=32,
|
331 |
+
out_channels=32,
|
332 |
+
kernel_size=(3, 3),
|
333 |
+
downsample=(1, 1),
|
334 |
+
activation=activation,
|
335 |
+
momentum=momentum,
|
336 |
+
)
|
337 |
+
|
338 |
+
self.after_conv2 = nn.Conv2d(
|
339 |
+
in_channels=32,
|
340 |
+
out_channels=input_channels
|
341 |
+
* self.subbands_num
|
342 |
+
* target_sources_num
|
343 |
+
* self.K,
|
344 |
+
kernel_size=(1, 1),
|
345 |
+
stride=(1, 1),
|
346 |
+
padding=(0, 0),
|
347 |
+
bias=True,
|
348 |
+
)
|
349 |
+
|
350 |
+
self.init_weights()
|
351 |
+
|
352 |
+
def init_weights(self):
|
353 |
+
init_bn(self.bn0)
|
354 |
+
init_layer(self.after_conv2)
|
355 |
+
|
356 |
+
def feature_maps_to_wav(
|
357 |
+
self,
|
358 |
+
input_tensor: torch.Tensor,
|
359 |
+
sp: torch.Tensor,
|
360 |
+
sin_in: torch.Tensor,
|
361 |
+
cos_in: torch.Tensor,
|
362 |
+
audio_length: int,
|
363 |
+
) -> torch.Tensor:
|
364 |
+
r"""Convert feature maps to waveform.
|
365 |
+
|
366 |
+
Args:
|
367 |
+
input_tensor: (batch_size, feature_maps, time_steps, freq_bins)
|
368 |
+
sp: (batch_size, feature_maps, time_steps, freq_bins)
|
369 |
+
sin_in: (batch_size, feature_maps, time_steps, freq_bins)
|
370 |
+
cos_in: (batch_size, feature_maps, time_steps, freq_bins)
|
371 |
+
|
372 |
+
Outputs:
|
373 |
+
waveform: (batch_size, target_sources_num * input_channels, segment_samples)
|
374 |
+
"""
|
375 |
+
batch_size, _, time_steps, freq_bins = input_tensor.shape
|
376 |
+
|
377 |
+
x = input_tensor.reshape(
|
378 |
+
batch_size,
|
379 |
+
self.target_sources_num,
|
380 |
+
self.input_channels,
|
381 |
+
self.K,
|
382 |
+
time_steps,
|
383 |
+
freq_bins,
|
384 |
+
)
|
385 |
+
# x: (batch_size, target_sources_num, input_channles, K, time_steps, freq_bins)
|
386 |
+
|
387 |
+
mask_mag = torch.sigmoid(x[:, :, :, 0, :, :])
|
388 |
+
_mask_real = torch.tanh(x[:, :, :, 1, :, :])
|
389 |
+
_mask_imag = torch.tanh(x[:, :, :, 2, :, :])
|
390 |
+
linear_mag = x[:, :, :, 3, :, :]
|
391 |
+
_, mask_cos, mask_sin = magphase(_mask_real, _mask_imag)
|
392 |
+
# mask_cos, mask_sin: (batch_size, target_sources_num, input_channles, time_steps, freq_bins)
|
393 |
+
|
394 |
+
# Y = |Y|cos∠Y + j|Y|sin∠Y
|
395 |
+
# = |Y|cos(∠X + ∠M) + j|Y|sin(∠X + ∠M)
|
396 |
+
# = |Y|(cos∠X cos∠M - sin∠X sin∠M) + j|Y|(sin∠X cos∠M + cos∠X sin∠M)
|
397 |
+
out_cos = (
|
398 |
+
cos_in[:, None, :, :, :] * mask_cos - sin_in[:, None, :, :, :] * mask_sin
|
399 |
+
)
|
400 |
+
out_sin = (
|
401 |
+
sin_in[:, None, :, :, :] * mask_cos + cos_in[:, None, :, :, :] * mask_sin
|
402 |
+
)
|
403 |
+
# out_cos: (batch_size, target_sources_num, input_channles, time_steps, freq_bins)
|
404 |
+
# out_sin: (batch_size, target_sources_num, input_channles, time_steps, freq_bins)
|
405 |
+
|
406 |
+
# Calculate |Y|.
|
407 |
+
out_mag = F.relu_(sp[:, None, :, :, :] * mask_mag + linear_mag)
|
408 |
+
# out_mag: (batch_size, target_sources_num, input_channles, time_steps, freq_bins)
|
409 |
+
|
410 |
+
# Calculate Y_{real} and Y_{imag} for ISTFT.
|
411 |
+
out_real = out_mag * out_cos
|
412 |
+
out_imag = out_mag * out_sin
|
413 |
+
# out_real, out_imag: (batch_size, target_sources_num, input_channles, time_steps, freq_bins)
|
414 |
+
|
415 |
+
# Reformat shape to (n, 1, time_steps, freq_bins) for ISTFT.
|
416 |
+
shape = (
|
417 |
+
batch_size * self.target_sources_num * self.input_channels,
|
418 |
+
1,
|
419 |
+
time_steps,
|
420 |
+
freq_bins,
|
421 |
+
)
|
422 |
+
out_real = out_real.reshape(shape)
|
423 |
+
out_imag = out_imag.reshape(shape)
|
424 |
+
|
425 |
+
# ISTFT.
|
426 |
+
x = self.istft(out_real, out_imag, audio_length)
|
427 |
+
# (batch_size * target_sources_num * input_channels, segments_num)
|
428 |
+
|
429 |
+
# Reshape.
|
430 |
+
waveform = x.reshape(
|
431 |
+
batch_size, self.target_sources_num * self.input_channels, audio_length
|
432 |
+
)
|
433 |
+
# (batch_size, target_sources_num * input_channels, segments_num)
|
434 |
+
|
435 |
+
return waveform
|
436 |
+
|
437 |
+
def forward(self, input_dict):
|
438 |
+
r"""
|
439 |
+
Args:
|
440 |
+
input: (batch_size, channels_num, segment_samples)
|
441 |
+
|
442 |
+
Outputs:
|
443 |
+
output_dict: {
|
444 |
+
'wav': (batch_size, channels_num, segment_samples)
|
445 |
+
}
|
446 |
+
"""
|
447 |
+
mixtures = input_dict['waveform']
|
448 |
+
# (batch_size, input_channels, segment_samples)
|
449 |
+
|
450 |
+
mag, cos_in, sin_in = self.wav_to_spectrogram_phase(mixtures)
|
451 |
+
# mag, cos_in, sin_in: (batch_size, input_channels, time_steps, freq_bins)
|
452 |
+
|
453 |
+
# Batch normalize on individual frequency bins.
|
454 |
+
x = mag.transpose(1, 3)
|
455 |
+
x = self.bn0(x)
|
456 |
+
x = x.transpose(1, 3)
|
457 |
+
"""(batch_size, input_channels, time_steps, freq_bins)"""
|
458 |
+
|
459 |
+
# Pad spectrogram to be evenly divided by downsample ratio.
|
460 |
+
origin_len = x.shape[2]
|
461 |
+
pad_len = (
|
462 |
+
int(np.ceil(x.shape[2] / self.downsample_ratio)) * self.downsample_ratio
|
463 |
+
- origin_len
|
464 |
+
)
|
465 |
+
x = F.pad(x, pad=(0, 0, 0, pad_len))
|
466 |
+
"""(batch_size, input_channels, padded_time_steps, freq_bins)"""
|
467 |
+
|
468 |
+
# Let frequency bins be evenly divided by 2, e.g., 1025 -> 1024
|
469 |
+
x = x[..., 0 : x.shape[-1] - 1] # (bs, input_channels, T, F)
|
470 |
+
|
471 |
+
x = self.subband.analysis(x)
|
472 |
+
# (bs, input_channels, T, F'), where F' = F // subbands_num
|
473 |
+
|
474 |
+
# UNet
|
475 |
+
(x1_pool, x1) = self.encoder_block1(x) # x1_pool: (bs, 32, T / 2, F / 2)
|
476 |
+
(x2_pool, x2) = self.encoder_block2(x1_pool) # x2_pool: (bs, 64, T / 4, F / 4)
|
477 |
+
(x3_pool, x3) = self.encoder_block3(x2_pool) # x3_pool: (bs, 128, T / 8, F / 8)
|
478 |
+
(x4_pool, x4) = self.encoder_block4(
|
479 |
+
x3_pool
|
480 |
+
) # x4_pool: (bs, 256, T / 16, F / 16)
|
481 |
+
(x5_pool, x5) = self.encoder_block5(
|
482 |
+
x4_pool
|
483 |
+
) # x5_pool: (bs, 384, T / 32, F / 32)
|
484 |
+
(x6_pool, x6) = self.encoder_block6(
|
485 |
+
x5_pool
|
486 |
+
) # x6_pool: (bs, 384, T / 32, F / 64)
|
487 |
+
(x_center, _) = self.conv_block7a(x6_pool) # (bs, 384, T / 32, F / 64)
|
488 |
+
(x_center, _) = self.conv_block7b(x_center) # (bs, 384, T / 32, F / 64)
|
489 |
+
(x_center, _) = self.conv_block7c(x_center) # (bs, 384, T / 32, F / 64)
|
490 |
+
(x_center, _) = self.conv_block7d(x_center) # (bs, 384, T / 32, F / 64)
|
491 |
+
x7 = self.decoder_block1(x_center, x6) # (bs, 384, T / 32, F / 32)
|
492 |
+
x8 = self.decoder_block2(x7, x5) # (bs, 384, T / 16, F / 16)
|
493 |
+
x9 = self.decoder_block3(x8, x4) # (bs, 256, T / 8, F / 8)
|
494 |
+
x10 = self.decoder_block4(x9, x3) # (bs, 128, T / 4, F / 4)
|
495 |
+
x11 = self.decoder_block5(x10, x2) # (bs, 64, T / 2, F / 2)
|
496 |
+
x12 = self.decoder_block6(x11, x1) # (bs, 32, T, F)
|
497 |
+
(x, _) = self.after_conv_block1(x12) # (bs, 32, T, F)
|
498 |
+
|
499 |
+
x = self.after_conv2(x) # (bs, channels * 3, T, F)
|
500 |
+
# (batch_size, input_channles * subbands_num * targets_num * k, T, F')
|
501 |
+
|
502 |
+
x = self.subband.synthesis(x)
|
503 |
+
# (batch_size, input_channles * targets_num * K, T, F)
|
504 |
+
|
505 |
+
# Recover shape
|
506 |
+
x = F.pad(x, pad=(0, 1)) # Pad frequency, e.g., 1024 -> 1025.
|
507 |
+
x = x[:, :, 0:origin_len, :] # (bs, feature_maps, time_steps, freq_bins)
|
508 |
+
|
509 |
+
audio_length = mixtures.shape[2]
|
510 |
+
|
511 |
+
separated_audio = self.feature_maps_to_wav(x, mag, sin_in, cos_in, audio_length)
|
512 |
+
# separated_audio: (batch_size, target_sources_num * input_channels, segments_num)
|
513 |
+
|
514 |
+
output_dict = {'waveform': separated_audio}
|
515 |
+
|
516 |
+
return output_dict
|
bytesep/models/resunet_ismir2021.py
ADDED
@@ -0,0 +1,534 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from inplace_abn.abn import InPlaceABNSync
|
6 |
+
from torchlibrosa.stft import ISTFT, STFT, magphase
|
7 |
+
|
8 |
+
from bytesep.models.pytorch_modules import Base, init_bn, init_layer
|
9 |
+
|
10 |
+
|
11 |
+
class ConvBlockRes(nn.Module):
|
12 |
+
def __init__(self, in_channels, out_channels, kernel_size, activation, momentum):
|
13 |
+
r"""Residual block."""
|
14 |
+
super(ConvBlockRes, self).__init__()
|
15 |
+
|
16 |
+
self.activation = activation
|
17 |
+
padding = [kernel_size[0] // 2, kernel_size[1] // 2]
|
18 |
+
|
19 |
+
# ABN is not used for bn1 because we found using abn1 will degrade performance.
|
20 |
+
self.bn1 = nn.BatchNorm2d(in_channels, momentum=momentum)
|
21 |
+
|
22 |
+
self.abn2 = InPlaceABNSync(
|
23 |
+
num_features=out_channels, momentum=momentum, activation='leaky_relu'
|
24 |
+
)
|
25 |
+
|
26 |
+
self.conv1 = nn.Conv2d(
|
27 |
+
in_channels=in_channels,
|
28 |
+
out_channels=out_channels,
|
29 |
+
kernel_size=kernel_size,
|
30 |
+
stride=(1, 1),
|
31 |
+
dilation=(1, 1),
|
32 |
+
padding=padding,
|
33 |
+
bias=False,
|
34 |
+
)
|
35 |
+
|
36 |
+
self.conv2 = nn.Conv2d(
|
37 |
+
in_channels=out_channels,
|
38 |
+
out_channels=out_channels,
|
39 |
+
kernel_size=kernel_size,
|
40 |
+
stride=(1, 1),
|
41 |
+
dilation=(1, 1),
|
42 |
+
padding=padding,
|
43 |
+
bias=False,
|
44 |
+
)
|
45 |
+
|
46 |
+
if in_channels != out_channels:
|
47 |
+
self.shortcut = nn.Conv2d(
|
48 |
+
in_channels=in_channels,
|
49 |
+
out_channels=out_channels,
|
50 |
+
kernel_size=(1, 1),
|
51 |
+
stride=(1, 1),
|
52 |
+
padding=(0, 0),
|
53 |
+
)
|
54 |
+
self.is_shortcut = True
|
55 |
+
else:
|
56 |
+
self.is_shortcut = False
|
57 |
+
|
58 |
+
self.init_weights()
|
59 |
+
|
60 |
+
def init_weights(self):
|
61 |
+
init_bn(self.bn1)
|
62 |
+
init_layer(self.conv1)
|
63 |
+
init_layer(self.conv2)
|
64 |
+
|
65 |
+
if self.is_shortcut:
|
66 |
+
init_layer(self.shortcut)
|
67 |
+
|
68 |
+
def forward(self, x):
|
69 |
+
origin = x
|
70 |
+
x = self.conv1(F.leaky_relu_(self.bn1(x), negative_slope=0.01))
|
71 |
+
x = self.conv2(self.abn2(x))
|
72 |
+
|
73 |
+
if self.is_shortcut:
|
74 |
+
return self.shortcut(origin) + x
|
75 |
+
else:
|
76 |
+
return origin + x
|
77 |
+
|
78 |
+
|
79 |
+
class EncoderBlockRes4B(nn.Module):
|
80 |
+
def __init__(
|
81 |
+
self, in_channels, out_channels, kernel_size, downsample, activation, momentum
|
82 |
+
):
|
83 |
+
r"""Encoder block, contains 8 convolutional layers."""
|
84 |
+
super(EncoderBlockRes4B, self).__init__()
|
85 |
+
|
86 |
+
self.conv_block1 = ConvBlockRes(
|
87 |
+
in_channels, out_channels, kernel_size, activation, momentum
|
88 |
+
)
|
89 |
+
self.conv_block2 = ConvBlockRes(
|
90 |
+
out_channels, out_channels, kernel_size, activation, momentum
|
91 |
+
)
|
92 |
+
self.conv_block3 = ConvBlockRes(
|
93 |
+
out_channels, out_channels, kernel_size, activation, momentum
|
94 |
+
)
|
95 |
+
self.conv_block4 = ConvBlockRes(
|
96 |
+
out_channels, out_channels, kernel_size, activation, momentum
|
97 |
+
)
|
98 |
+
self.downsample = downsample
|
99 |
+
|
100 |
+
def forward(self, x):
|
101 |
+
encoder = self.conv_block1(x)
|
102 |
+
encoder = self.conv_block2(encoder)
|
103 |
+
encoder = self.conv_block3(encoder)
|
104 |
+
encoder = self.conv_block4(encoder)
|
105 |
+
encoder_pool = F.avg_pool2d(encoder, kernel_size=self.downsample)
|
106 |
+
return encoder_pool, encoder
|
107 |
+
|
108 |
+
|
109 |
+
class DecoderBlockRes4B(nn.Module):
|
110 |
+
def __init__(
|
111 |
+
self, in_channels, out_channels, kernel_size, upsample, activation, momentum
|
112 |
+
):
|
113 |
+
r"""Decoder block, contains 1 transpose convolutional and 8 convolutional layers."""
|
114 |
+
super(DecoderBlockRes4B, self).__init__()
|
115 |
+
self.kernel_size = kernel_size
|
116 |
+
self.stride = upsample
|
117 |
+
self.activation = activation
|
118 |
+
|
119 |
+
self.conv1 = torch.nn.ConvTranspose2d(
|
120 |
+
in_channels=in_channels,
|
121 |
+
out_channels=out_channels,
|
122 |
+
kernel_size=self.stride,
|
123 |
+
stride=self.stride,
|
124 |
+
padding=(0, 0),
|
125 |
+
bias=False,
|
126 |
+
dilation=(1, 1),
|
127 |
+
)
|
128 |
+
|
129 |
+
self.bn1 = nn.BatchNorm2d(in_channels, momentum=momentum)
|
130 |
+
self.conv_block2 = ConvBlockRes(
|
131 |
+
out_channels * 2, out_channels, kernel_size, activation, momentum
|
132 |
+
)
|
133 |
+
self.conv_block3 = ConvBlockRes(
|
134 |
+
out_channels, out_channels, kernel_size, activation, momentum
|
135 |
+
)
|
136 |
+
self.conv_block4 = ConvBlockRes(
|
137 |
+
out_channels, out_channels, kernel_size, activation, momentum
|
138 |
+
)
|
139 |
+
self.conv_block5 = ConvBlockRes(
|
140 |
+
out_channels, out_channels, kernel_size, activation, momentum
|
141 |
+
)
|
142 |
+
|
143 |
+
self.init_weights()
|
144 |
+
|
145 |
+
def init_weights(self):
|
146 |
+
init_bn(self.bn1)
|
147 |
+
init_layer(self.conv1)
|
148 |
+
|
149 |
+
def forward(self, input_tensor, concat_tensor):
|
150 |
+
x = self.conv1(F.relu_(self.bn1(input_tensor)))
|
151 |
+
x = torch.cat((x, concat_tensor), dim=1)
|
152 |
+
x = self.conv_block2(x)
|
153 |
+
x = self.conv_block3(x)
|
154 |
+
x = self.conv_block4(x)
|
155 |
+
x = self.conv_block5(x)
|
156 |
+
return x
|
157 |
+
|
158 |
+
|
159 |
+
class ResUNet143_DecouplePlusInplaceABN_ISMIR2021(nn.Module, Base):
|
160 |
+
def __init__(self, input_channels, target_sources_num):
|
161 |
+
super(ResUNet143_DecouplePlusInplaceABN_ISMIR2021, self).__init__()
|
162 |
+
|
163 |
+
self.input_channels = input_channels
|
164 |
+
self.target_sources_num = target_sources_num
|
165 |
+
|
166 |
+
window_size = 2048
|
167 |
+
hop_size = 441
|
168 |
+
center = True
|
169 |
+
pad_mode = 'reflect'
|
170 |
+
window = 'hann'
|
171 |
+
activation = 'leaky_relu'
|
172 |
+
momentum = 0.01
|
173 |
+
|
174 |
+
self.subbands_num = 1
|
175 |
+
|
176 |
+
assert (
|
177 |
+
self.subbands_num == 1
|
178 |
+
), "Using subbands_num > 1 on spectrogram \
|
179 |
+
will lead to unexpected performance sometimes. Suggest to use \
|
180 |
+
subband method on waveform."
|
181 |
+
|
182 |
+
# Downsample rate along the time axis.
|
183 |
+
self.K = 4 # outputs: |M|, cos∠M, sin∠M, Q
|
184 |
+
self.time_downsample_ratio = 2 ** 5 # This number equals 2^{#encoder_blcoks}
|
185 |
+
|
186 |
+
self.stft = STFT(
|
187 |
+
n_fft=window_size,
|
188 |
+
hop_length=hop_size,
|
189 |
+
win_length=window_size,
|
190 |
+
window=window,
|
191 |
+
center=center,
|
192 |
+
pad_mode=pad_mode,
|
193 |
+
freeze_parameters=True,
|
194 |
+
)
|
195 |
+
|
196 |
+
self.istft = ISTFT(
|
197 |
+
n_fft=window_size,
|
198 |
+
hop_length=hop_size,
|
199 |
+
win_length=window_size,
|
200 |
+
window=window,
|
201 |
+
center=center,
|
202 |
+
pad_mode=pad_mode,
|
203 |
+
freeze_parameters=True,
|
204 |
+
)
|
205 |
+
|
206 |
+
self.bn0 = nn.BatchNorm2d(window_size // 2 + 1, momentum=momentum)
|
207 |
+
|
208 |
+
self.encoder_block1 = EncoderBlockRes4B(
|
209 |
+
in_channels=input_channels * self.subbands_num,
|
210 |
+
out_channels=32,
|
211 |
+
kernel_size=(3, 3),
|
212 |
+
downsample=(2, 2),
|
213 |
+
activation=activation,
|
214 |
+
momentum=momentum,
|
215 |
+
)
|
216 |
+
self.encoder_block2 = EncoderBlockRes4B(
|
217 |
+
in_channels=32,
|
218 |
+
out_channels=64,
|
219 |
+
kernel_size=(3, 3),
|
220 |
+
downsample=(2, 2),
|
221 |
+
activation=activation,
|
222 |
+
momentum=momentum,
|
223 |
+
)
|
224 |
+
self.encoder_block3 = EncoderBlockRes4B(
|
225 |
+
in_channels=64,
|
226 |
+
out_channels=128,
|
227 |
+
kernel_size=(3, 3),
|
228 |
+
downsample=(2, 2),
|
229 |
+
activation=activation,
|
230 |
+
momentum=momentum,
|
231 |
+
)
|
232 |
+
self.encoder_block4 = EncoderBlockRes4B(
|
233 |
+
in_channels=128,
|
234 |
+
out_channels=256,
|
235 |
+
kernel_size=(3, 3),
|
236 |
+
downsample=(2, 2),
|
237 |
+
activation=activation,
|
238 |
+
momentum=momentum,
|
239 |
+
)
|
240 |
+
self.encoder_block5 = EncoderBlockRes4B(
|
241 |
+
in_channels=256,
|
242 |
+
out_channels=384,
|
243 |
+
kernel_size=(3, 3),
|
244 |
+
downsample=(2, 2),
|
245 |
+
activation=activation,
|
246 |
+
momentum=momentum,
|
247 |
+
)
|
248 |
+
self.encoder_block6 = EncoderBlockRes4B(
|
249 |
+
in_channels=384,
|
250 |
+
out_channels=384,
|
251 |
+
kernel_size=(3, 3),
|
252 |
+
downsample=(1, 2),
|
253 |
+
activation=activation,
|
254 |
+
momentum=momentum,
|
255 |
+
)
|
256 |
+
self.conv_block7a = EncoderBlockRes4B(
|
257 |
+
in_channels=384,
|
258 |
+
out_channels=384,
|
259 |
+
kernel_size=(3, 3),
|
260 |
+
downsample=(1, 1),
|
261 |
+
activation=activation,
|
262 |
+
momentum=momentum,
|
263 |
+
)
|
264 |
+
self.conv_block7b = EncoderBlockRes4B(
|
265 |
+
in_channels=384,
|
266 |
+
out_channels=384,
|
267 |
+
kernel_size=(3, 3),
|
268 |
+
downsample=(1, 1),
|
269 |
+
activation=activation,
|
270 |
+
momentum=momentum,
|
271 |
+
)
|
272 |
+
self.conv_block7c = EncoderBlockRes4B(
|
273 |
+
in_channels=384,
|
274 |
+
out_channels=384,
|
275 |
+
kernel_size=(3, 3),
|
276 |
+
downsample=(1, 1),
|
277 |
+
activation=activation,
|
278 |
+
momentum=momentum,
|
279 |
+
)
|
280 |
+
self.conv_block7d = EncoderBlockRes4B(
|
281 |
+
in_channels=384,
|
282 |
+
out_channels=384,
|
283 |
+
kernel_size=(3, 3),
|
284 |
+
downsample=(1, 1),
|
285 |
+
activation=activation,
|
286 |
+
momentum=momentum,
|
287 |
+
)
|
288 |
+
self.decoder_block1 = DecoderBlockRes4B(
|
289 |
+
in_channels=384,
|
290 |
+
out_channels=384,
|
291 |
+
kernel_size=(3, 3),
|
292 |
+
upsample=(1, 2),
|
293 |
+
activation=activation,
|
294 |
+
momentum=momentum,
|
295 |
+
)
|
296 |
+
self.decoder_block2 = DecoderBlockRes4B(
|
297 |
+
in_channels=384,
|
298 |
+
out_channels=384,
|
299 |
+
kernel_size=(3, 3),
|
300 |
+
upsample=(2, 2),
|
301 |
+
activation=activation,
|
302 |
+
momentum=momentum,
|
303 |
+
)
|
304 |
+
self.decoder_block3 = DecoderBlockRes4B(
|
305 |
+
in_channels=384,
|
306 |
+
out_channels=256,
|
307 |
+
kernel_size=(3, 3),
|
308 |
+
upsample=(2, 2),
|
309 |
+
activation=activation,
|
310 |
+
momentum=momentum,
|
311 |
+
)
|
312 |
+
self.decoder_block4 = DecoderBlockRes4B(
|
313 |
+
in_channels=256,
|
314 |
+
out_channels=128,
|
315 |
+
kernel_size=(3, 3),
|
316 |
+
upsample=(2, 2),
|
317 |
+
activation=activation,
|
318 |
+
momentum=momentum,
|
319 |
+
)
|
320 |
+
self.decoder_block5 = DecoderBlockRes4B(
|
321 |
+
in_channels=128,
|
322 |
+
out_channels=64,
|
323 |
+
kernel_size=(3, 3),
|
324 |
+
upsample=(2, 2),
|
325 |
+
activation=activation,
|
326 |
+
momentum=momentum,
|
327 |
+
)
|
328 |
+
self.decoder_block6 = DecoderBlockRes4B(
|
329 |
+
in_channels=64,
|
330 |
+
out_channels=32,
|
331 |
+
kernel_size=(3, 3),
|
332 |
+
upsample=(2, 2),
|
333 |
+
activation=activation,
|
334 |
+
momentum=momentum,
|
335 |
+
)
|
336 |
+
|
337 |
+
self.after_conv_block1 = EncoderBlockRes4B(
|
338 |
+
in_channels=32,
|
339 |
+
out_channels=32,
|
340 |
+
kernel_size=(3, 3),
|
341 |
+
downsample=(1, 1),
|
342 |
+
activation=activation,
|
343 |
+
momentum=momentum,
|
344 |
+
)
|
345 |
+
|
346 |
+
self.after_conv2 = nn.Conv2d(
|
347 |
+
in_channels=32,
|
348 |
+
out_channels=target_sources_num
|
349 |
+
* input_channels
|
350 |
+
* self.K
|
351 |
+
* self.subbands_num,
|
352 |
+
kernel_size=(1, 1),
|
353 |
+
stride=(1, 1),
|
354 |
+
padding=(0, 0),
|
355 |
+
bias=True,
|
356 |
+
)
|
357 |
+
|
358 |
+
self.init_weights()
|
359 |
+
|
360 |
+
def init_weights(self):
|
361 |
+
init_bn(self.bn0)
|
362 |
+
init_layer(self.after_conv2)
|
363 |
+
|
364 |
+
def feature_maps_to_wav(
|
365 |
+
self,
|
366 |
+
input_tensor: torch.Tensor,
|
367 |
+
sp: torch.Tensor,
|
368 |
+
sin_in: torch.Tensor,
|
369 |
+
cos_in: torch.Tensor,
|
370 |
+
audio_length: int,
|
371 |
+
) -> torch.Tensor:
|
372 |
+
r"""Convert feature maps to waveform.
|
373 |
+
|
374 |
+
Args:
|
375 |
+
input_tensor: (batch_size, target_sources_num * input_channels * self.K, time_steps, freq_bins)
|
376 |
+
sp: (batch_size, target_sources_num * input_channels, time_steps, freq_bins)
|
377 |
+
sin_in: (batch_size, target_sources_num * input_channels, time_steps, freq_bins)
|
378 |
+
cos_in: (batch_size, target_sources_num * input_channels, time_steps, freq_bins)
|
379 |
+
|
380 |
+
Outputs:
|
381 |
+
waveform: (batch_size, target_sources_num * input_channels, segment_samples)
|
382 |
+
"""
|
383 |
+
batch_size, _, time_steps, freq_bins = input_tensor.shape
|
384 |
+
|
385 |
+
x = input_tensor.reshape(
|
386 |
+
batch_size,
|
387 |
+
self.target_sources_num,
|
388 |
+
self.input_channels,
|
389 |
+
self.K,
|
390 |
+
time_steps,
|
391 |
+
freq_bins,
|
392 |
+
)
|
393 |
+
# x: (batch_size, target_sources_num, input_channles, K, time_steps, freq_bins)
|
394 |
+
|
395 |
+
mask_mag = torch.sigmoid(x[:, :, :, 0, :, :])
|
396 |
+
_mask_real = torch.tanh(x[:, :, :, 1, :, :])
|
397 |
+
_mask_imag = torch.tanh(x[:, :, :, 2, :, :])
|
398 |
+
_, mask_cos, mask_sin = magphase(_mask_real, _mask_imag)
|
399 |
+
linear_mag = x[:, :, :, 3, :, :]
|
400 |
+
# mask_cos, mask_sin: (batch_size, target_sources_num, input_channles, time_steps, freq_bins)
|
401 |
+
|
402 |
+
# Y = |Y|cos∠Y + j|Y|sin∠Y
|
403 |
+
# = |Y|cos(∠X + ∠M) + j|Y|sin(∠X + ∠M)
|
404 |
+
# = |Y|(cos∠X cos∠M - sin∠X sin∠M) + j|Y|(sin∠X cos∠M + cos∠X sin∠M)
|
405 |
+
out_cos = (
|
406 |
+
cos_in[:, None, :, :, :] * mask_cos - sin_in[:, None, :, :, :] * mask_sin
|
407 |
+
)
|
408 |
+
out_sin = (
|
409 |
+
sin_in[:, None, :, :, :] * mask_cos + cos_in[:, None, :, :, :] * mask_sin
|
410 |
+
)
|
411 |
+
# out_cos: (batch_size, target_sources_num, input_channles, time_steps, freq_bins)
|
412 |
+
# out_sin: (batch_size, target_sources_num, input_channles, time_steps, freq_bins)
|
413 |
+
|
414 |
+
# Calculate |Y|.
|
415 |
+
out_mag = F.relu_(sp[:, None, :, :, :] * mask_mag + linear_mag)
|
416 |
+
# out_mag: (batch_size, target_sources_num, input_channles, time_steps, freq_bins)
|
417 |
+
|
418 |
+
# Calculate Y_{real} and Y_{imag} for ISTFT.
|
419 |
+
out_real = out_mag * out_cos
|
420 |
+
out_imag = out_mag * out_sin
|
421 |
+
# out_real, out_imag: (batch_size, target_sources_num, input_channles, time_steps, freq_bins)
|
422 |
+
|
423 |
+
# Reformat shape to (n, 1, time_steps, freq_bins) for ISTFT.
|
424 |
+
shape = (
|
425 |
+
batch_size * self.target_sources_num * self.input_channels,
|
426 |
+
1,
|
427 |
+
time_steps,
|
428 |
+
freq_bins,
|
429 |
+
)
|
430 |
+
out_real = out_real.reshape(shape)
|
431 |
+
out_imag = out_imag.reshape(shape)
|
432 |
+
|
433 |
+
# ISTFT.
|
434 |
+
x = self.istft(out_real, out_imag, audio_length)
|
435 |
+
# (batch_size * target_sources_num * input_channels, segments_num)
|
436 |
+
|
437 |
+
# Reshape.
|
438 |
+
waveform = x.reshape(
|
439 |
+
batch_size, self.target_sources_num * self.input_channels, audio_length
|
440 |
+
)
|
441 |
+
# (batch_size, target_sources_num * input_channels, segments_num)
|
442 |
+
|
443 |
+
return waveform
|
444 |
+
|
445 |
+
def forward(self, input_dict):
|
446 |
+
r"""Forward data into the module.
|
447 |
+
|
448 |
+
Args:
|
449 |
+
input_dict: dict, e.g., {
|
450 |
+
waveform: (batch_size, input_channels, segment_samples),
|
451 |
+
...,
|
452 |
+
}
|
453 |
+
|
454 |
+
Outputs:
|
455 |
+
output_dict: dict, e.g., {
|
456 |
+
'waveform': (batch_size, input_channels, segment_samples),
|
457 |
+
...,
|
458 |
+
}
|
459 |
+
"""
|
460 |
+
mixtures = input_dict['waveform']
|
461 |
+
# (batch_size, input_channels, segment_samples)
|
462 |
+
|
463 |
+
mag, cos_in, sin_in = self.wav_to_spectrogram_phase(mixtures)
|
464 |
+
# mag, cos_in, sin_in: (batch_size, input_channels, time_steps, freq_bins)
|
465 |
+
|
466 |
+
# Batch normalize on individual frequency bins.
|
467 |
+
x = mag.transpose(1, 3)
|
468 |
+
x = self.bn0(x)
|
469 |
+
x = x.transpose(1, 3)
|
470 |
+
# x: (batch_size, input_channels, time_steps, freq_bins)
|
471 |
+
|
472 |
+
# Pad spectrogram to be evenly divided by downsample ratio.
|
473 |
+
origin_len = x.shape[2]
|
474 |
+
pad_len = (
|
475 |
+
int(np.ceil(x.shape[2] / self.time_downsample_ratio))
|
476 |
+
* self.time_downsample_ratio
|
477 |
+
- origin_len
|
478 |
+
)
|
479 |
+
x = F.pad(x, pad=(0, 0, 0, pad_len))
|
480 |
+
# (batch_size, channels, padded_time_steps, freq_bins)
|
481 |
+
|
482 |
+
# Let frequency bins be evenly divided by 2, e.g., 1025 -> 1024.
|
483 |
+
x = x[..., 0 : x.shape[-1] - 1] # (bs, channels, T, F)
|
484 |
+
|
485 |
+
if self.subbands_num > 1:
|
486 |
+
x = self.subband.analysis(x)
|
487 |
+
# (bs, input_channels, T, F'), where F' = F // subbands_num
|
488 |
+
|
489 |
+
# UNet
|
490 |
+
(x1_pool, x1) = self.encoder_block1(x) # x1_pool: (bs, 32, T / 2, F / 2)
|
491 |
+
(x2_pool, x2) = self.encoder_block2(x1_pool) # x2_pool: (bs, 64, T / 4, F / 4)
|
492 |
+
(x3_pool, x3) = self.encoder_block3(x2_pool) # x3_pool: (bs, 128, T / 8, F / 8)
|
493 |
+
(x4_pool, x4) = self.encoder_block4(
|
494 |
+
x3_pool
|
495 |
+
) # x4_pool: (bs, 256, T / 16, F / 16)
|
496 |
+
(x5_pool, x5) = self.encoder_block5(
|
497 |
+
x4_pool
|
498 |
+
) # x5_pool: (bs, 384, T / 32, F / 32)
|
499 |
+
(x6_pool, x6) = self.encoder_block6(
|
500 |
+
x5_pool
|
501 |
+
) # x6_pool: (bs, 384, T / 32, F / 64)
|
502 |
+
(x_center, _) = self.conv_block7a(x6_pool) # (bs, 384, T / 32, F / 64)
|
503 |
+
(x_center, _) = self.conv_block7b(x_center) # (bs, 384, T / 32, F / 64)
|
504 |
+
(x_center, _) = self.conv_block7c(x_center) # (bs, 384, T / 32, F / 64)
|
505 |
+
(x_center, _) = self.conv_block7d(x_center) # (bs, 384, T / 32, F / 64)
|
506 |
+
x7 = self.decoder_block1(x_center, x6) # (bs, 384, T / 32, F / 32)
|
507 |
+
x8 = self.decoder_block2(x7, x5) # (bs, 384, T / 16, F / 16)
|
508 |
+
x9 = self.decoder_block3(x8, x4) # (bs, 256, T / 8, F / 8)
|
509 |
+
x10 = self.decoder_block4(x9, x3) # (bs, 128, T / 4, F / 4)
|
510 |
+
x11 = self.decoder_block5(x10, x2) # (bs, 64, T / 2, F / 2)
|
511 |
+
x12 = self.decoder_block6(x11, x1) # (bs, 32, T, F)
|
512 |
+
(x, _) = self.after_conv_block1(x12) # (bs, 32, T, F)
|
513 |
+
|
514 |
+
x = self.after_conv2(x) # (bs, channels * 3, T, F)
|
515 |
+
# (batch_size, target_sources_num * input_channles * self.K * subbands_num, T, F')
|
516 |
+
|
517 |
+
if self.subbands_num > 1:
|
518 |
+
x = self.subband.synthesis(x)
|
519 |
+
# (batch_size, target_sources_num * input_channles * self.K, T, F)
|
520 |
+
|
521 |
+
# Recover shape
|
522 |
+
x = F.pad(x, pad=(0, 1)) # Pad frequency, e.g., 1024 -> 1025.
|
523 |
+
|
524 |
+
x = x[:, :, 0:origin_len, :]
|
525 |
+
# (batch_size, target_sources_num * input_channles * self.K, T, F)
|
526 |
+
|
527 |
+
audio_length = mixtures.shape[2]
|
528 |
+
|
529 |
+
separated_audio = self.feature_maps_to_wav(x, mag, sin_in, cos_in, audio_length)
|
530 |
+
# separated_audio: (batch_size, target_sources_num * input_channels, segments_num)
|
531 |
+
|
532 |
+
output_dict = {'waveform': separated_audio}
|
533 |
+
|
534 |
+
return output_dict
|
bytesep/models/resunet_subbandtime.py
ADDED
@@ -0,0 +1,545 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from torchlibrosa.stft import ISTFT, STFT, magphase
|
6 |
+
|
7 |
+
from bytesep.models.pytorch_modules import Base, init_bn, init_layer
|
8 |
+
from bytesep.models.subband_tools.pqmf import PQMF
|
9 |
+
|
10 |
+
|
11 |
+
class ConvBlockRes(nn.Module):
|
12 |
+
def __init__(self, in_channels, out_channels, kernel_size, activation, momentum):
|
13 |
+
r"""Residual block."""
|
14 |
+
super(ConvBlockRes, self).__init__()
|
15 |
+
|
16 |
+
self.activation = activation
|
17 |
+
padding = [kernel_size[0] // 2, kernel_size[1] // 2]
|
18 |
+
|
19 |
+
self.bn1 = nn.BatchNorm2d(in_channels, momentum=momentum)
|
20 |
+
self.bn2 = nn.BatchNorm2d(out_channels, momentum=momentum)
|
21 |
+
|
22 |
+
self.conv1 = nn.Conv2d(
|
23 |
+
in_channels=in_channels,
|
24 |
+
out_channels=out_channels,
|
25 |
+
kernel_size=kernel_size,
|
26 |
+
stride=(1, 1),
|
27 |
+
dilation=(1, 1),
|
28 |
+
padding=padding,
|
29 |
+
bias=False,
|
30 |
+
)
|
31 |
+
|
32 |
+
self.conv2 = nn.Conv2d(
|
33 |
+
in_channels=out_channels,
|
34 |
+
out_channels=out_channels,
|
35 |
+
kernel_size=kernel_size,
|
36 |
+
stride=(1, 1),
|
37 |
+
dilation=(1, 1),
|
38 |
+
padding=padding,
|
39 |
+
bias=False,
|
40 |
+
)
|
41 |
+
|
42 |
+
if in_channels != out_channels:
|
43 |
+
self.shortcut = nn.Conv2d(
|
44 |
+
in_channels=in_channels,
|
45 |
+
out_channels=out_channels,
|
46 |
+
kernel_size=(1, 1),
|
47 |
+
stride=(1, 1),
|
48 |
+
padding=(0, 0),
|
49 |
+
)
|
50 |
+
self.is_shortcut = True
|
51 |
+
else:
|
52 |
+
self.is_shortcut = False
|
53 |
+
|
54 |
+
self.init_weights()
|
55 |
+
|
56 |
+
def init_weights(self):
|
57 |
+
init_bn(self.bn1)
|
58 |
+
init_bn(self.bn2)
|
59 |
+
init_layer(self.conv1)
|
60 |
+
init_layer(self.conv2)
|
61 |
+
|
62 |
+
if self.is_shortcut:
|
63 |
+
init_layer(self.shortcut)
|
64 |
+
|
65 |
+
def forward(self, x):
|
66 |
+
origin = x
|
67 |
+
x = self.conv1(F.leaky_relu_(self.bn1(x), negative_slope=0.01))
|
68 |
+
x = self.conv2(F.leaky_relu_(self.bn2(x), negative_slope=0.01))
|
69 |
+
|
70 |
+
if self.is_shortcut:
|
71 |
+
return self.shortcut(origin) + x
|
72 |
+
else:
|
73 |
+
return origin + x
|
74 |
+
|
75 |
+
|
76 |
+
class EncoderBlockRes4B(nn.Module):
|
77 |
+
def __init__(
|
78 |
+
self, in_channels, out_channels, kernel_size, downsample, activation, momentum
|
79 |
+
):
|
80 |
+
r"""Encoder block, contains 8 convolutional layers."""
|
81 |
+
super(EncoderBlockRes4B, self).__init__()
|
82 |
+
|
83 |
+
self.conv_block1 = ConvBlockRes(
|
84 |
+
in_channels, out_channels, kernel_size, activation, momentum
|
85 |
+
)
|
86 |
+
self.conv_block2 = ConvBlockRes(
|
87 |
+
out_channels, out_channels, kernel_size, activation, momentum
|
88 |
+
)
|
89 |
+
self.conv_block3 = ConvBlockRes(
|
90 |
+
out_channels, out_channels, kernel_size, activation, momentum
|
91 |
+
)
|
92 |
+
self.conv_block4 = ConvBlockRes(
|
93 |
+
out_channels, out_channels, kernel_size, activation, momentum
|
94 |
+
)
|
95 |
+
self.downsample = downsample
|
96 |
+
|
97 |
+
def forward(self, x):
|
98 |
+
encoder = self.conv_block1(x)
|
99 |
+
encoder = self.conv_block2(encoder)
|
100 |
+
encoder = self.conv_block3(encoder)
|
101 |
+
encoder = self.conv_block4(encoder)
|
102 |
+
encoder_pool = F.avg_pool2d(encoder, kernel_size=self.downsample)
|
103 |
+
return encoder_pool, encoder
|
104 |
+
|
105 |
+
|
106 |
+
class DecoderBlockRes4B(nn.Module):
|
107 |
+
def __init__(
|
108 |
+
self, in_channels, out_channels, kernel_size, upsample, activation, momentum
|
109 |
+
):
|
110 |
+
r"""Decoder block, contains 1 transpose convolutional and 8 convolutional layers."""
|
111 |
+
super(DecoderBlockRes4B, self).__init__()
|
112 |
+
self.kernel_size = kernel_size
|
113 |
+
self.stride = upsample
|
114 |
+
self.activation = activation
|
115 |
+
|
116 |
+
self.conv1 = torch.nn.ConvTranspose2d(
|
117 |
+
in_channels=in_channels,
|
118 |
+
out_channels=out_channels,
|
119 |
+
kernel_size=self.stride,
|
120 |
+
stride=self.stride,
|
121 |
+
padding=(0, 0),
|
122 |
+
bias=False,
|
123 |
+
dilation=(1, 1),
|
124 |
+
)
|
125 |
+
|
126 |
+
self.bn1 = nn.BatchNorm2d(in_channels, momentum=momentum)
|
127 |
+
self.conv_block2 = ConvBlockRes(
|
128 |
+
out_channels * 2, out_channels, kernel_size, activation, momentum
|
129 |
+
)
|
130 |
+
self.conv_block3 = ConvBlockRes(
|
131 |
+
out_channels, out_channels, kernel_size, activation, momentum
|
132 |
+
)
|
133 |
+
self.conv_block4 = ConvBlockRes(
|
134 |
+
out_channels, out_channels, kernel_size, activation, momentum
|
135 |
+
)
|
136 |
+
self.conv_block5 = ConvBlockRes(
|
137 |
+
out_channels, out_channels, kernel_size, activation, momentum
|
138 |
+
)
|
139 |
+
|
140 |
+
self.init_weights()
|
141 |
+
|
142 |
+
def init_weights(self):
|
143 |
+
init_bn(self.bn1)
|
144 |
+
init_layer(self.conv1)
|
145 |
+
|
146 |
+
def forward(self, input_tensor, concat_tensor):
|
147 |
+
x = self.conv1(F.relu_(self.bn1(input_tensor)))
|
148 |
+
x = torch.cat((x, concat_tensor), dim=1)
|
149 |
+
x = self.conv_block2(x)
|
150 |
+
x = self.conv_block3(x)
|
151 |
+
x = self.conv_block4(x)
|
152 |
+
x = self.conv_block5(x)
|
153 |
+
return x
|
154 |
+
|
155 |
+
|
156 |
+
class ResUNet143_Subbandtime(nn.Module, Base):
|
157 |
+
def __init__(self, input_channels, target_sources_num):
|
158 |
+
super(ResUNet143_Subbandtime, self).__init__()
|
159 |
+
|
160 |
+
self.input_channels = input_channels
|
161 |
+
self.target_sources_num = target_sources_num
|
162 |
+
|
163 |
+
window_size = 512
|
164 |
+
hop_size = 110
|
165 |
+
center = True
|
166 |
+
pad_mode = "reflect"
|
167 |
+
window = "hann"
|
168 |
+
activation = "leaky_relu"
|
169 |
+
momentum = 0.01
|
170 |
+
|
171 |
+
self.subbands_num = 4
|
172 |
+
self.K = 4 # outputs: |M|, cos∠M, sin∠M, Q
|
173 |
+
|
174 |
+
self.downsample_ratio = 2 ** 5 # This number equals 2^{#encoder_blcoks}
|
175 |
+
|
176 |
+
self.pqmf = PQMF(
|
177 |
+
N=self.subbands_num,
|
178 |
+
M=64,
|
179 |
+
project_root='bytesep/models/subband_tools/filters',
|
180 |
+
)
|
181 |
+
|
182 |
+
self.stft = STFT(
|
183 |
+
n_fft=window_size,
|
184 |
+
hop_length=hop_size,
|
185 |
+
win_length=window_size,
|
186 |
+
window=window,
|
187 |
+
center=center,
|
188 |
+
pad_mode=pad_mode,
|
189 |
+
freeze_parameters=True,
|
190 |
+
)
|
191 |
+
|
192 |
+
self.istft = ISTFT(
|
193 |
+
n_fft=window_size,
|
194 |
+
hop_length=hop_size,
|
195 |
+
win_length=window_size,
|
196 |
+
window=window,
|
197 |
+
center=center,
|
198 |
+
pad_mode=pad_mode,
|
199 |
+
freeze_parameters=True,
|
200 |
+
)
|
201 |
+
|
202 |
+
self.bn0 = nn.BatchNorm2d(window_size // 2 + 1, momentum=momentum)
|
203 |
+
|
204 |
+
self.encoder_block1 = EncoderBlockRes4B(
|
205 |
+
in_channels=input_channels * self.subbands_num,
|
206 |
+
out_channels=32,
|
207 |
+
kernel_size=(3, 3),
|
208 |
+
downsample=(2, 2),
|
209 |
+
activation=activation,
|
210 |
+
momentum=momentum,
|
211 |
+
)
|
212 |
+
self.encoder_block2 = EncoderBlockRes4B(
|
213 |
+
in_channels=32,
|
214 |
+
out_channels=64,
|
215 |
+
kernel_size=(3, 3),
|
216 |
+
downsample=(2, 2),
|
217 |
+
activation=activation,
|
218 |
+
momentum=momentum,
|
219 |
+
)
|
220 |
+
self.encoder_block3 = EncoderBlockRes4B(
|
221 |
+
in_channels=64,
|
222 |
+
out_channels=128,
|
223 |
+
kernel_size=(3, 3),
|
224 |
+
downsample=(2, 2),
|
225 |
+
activation=activation,
|
226 |
+
momentum=momentum,
|
227 |
+
)
|
228 |
+
self.encoder_block4 = EncoderBlockRes4B(
|
229 |
+
in_channels=128,
|
230 |
+
out_channels=256,
|
231 |
+
kernel_size=(3, 3),
|
232 |
+
downsample=(2, 2),
|
233 |
+
activation=activation,
|
234 |
+
momentum=momentum,
|
235 |
+
)
|
236 |
+
self.encoder_block5 = EncoderBlockRes4B(
|
237 |
+
in_channels=256,
|
238 |
+
out_channels=384,
|
239 |
+
kernel_size=(3, 3),
|
240 |
+
downsample=(2, 2),
|
241 |
+
activation=activation,
|
242 |
+
momentum=momentum,
|
243 |
+
)
|
244 |
+
self.encoder_block6 = EncoderBlockRes4B(
|
245 |
+
in_channels=384,
|
246 |
+
out_channels=384,
|
247 |
+
kernel_size=(3, 3),
|
248 |
+
downsample=(1, 2),
|
249 |
+
activation=activation,
|
250 |
+
momentum=momentum,
|
251 |
+
)
|
252 |
+
self.conv_block7a = EncoderBlockRes4B(
|
253 |
+
in_channels=384,
|
254 |
+
out_channels=384,
|
255 |
+
kernel_size=(3, 3),
|
256 |
+
downsample=(1, 1),
|
257 |
+
activation=activation,
|
258 |
+
momentum=momentum,
|
259 |
+
)
|
260 |
+
self.conv_block7b = EncoderBlockRes4B(
|
261 |
+
in_channels=384,
|
262 |
+
out_channels=384,
|
263 |
+
kernel_size=(3, 3),
|
264 |
+
downsample=(1, 1),
|
265 |
+
activation=activation,
|
266 |
+
momentum=momentum,
|
267 |
+
)
|
268 |
+
self.conv_block7c = EncoderBlockRes4B(
|
269 |
+
in_channels=384,
|
270 |
+
out_channels=384,
|
271 |
+
kernel_size=(3, 3),
|
272 |
+
downsample=(1, 1),
|
273 |
+
activation=activation,
|
274 |
+
momentum=momentum,
|
275 |
+
)
|
276 |
+
self.conv_block7d = EncoderBlockRes4B(
|
277 |
+
in_channels=384,
|
278 |
+
out_channels=384,
|
279 |
+
kernel_size=(3, 3),
|
280 |
+
downsample=(1, 1),
|
281 |
+
activation=activation,
|
282 |
+
momentum=momentum,
|
283 |
+
)
|
284 |
+
self.decoder_block1 = DecoderBlockRes4B(
|
285 |
+
in_channels=384,
|
286 |
+
out_channels=384,
|
287 |
+
kernel_size=(3, 3),
|
288 |
+
upsample=(1, 2),
|
289 |
+
activation=activation,
|
290 |
+
momentum=momentum,
|
291 |
+
)
|
292 |
+
self.decoder_block2 = DecoderBlockRes4B(
|
293 |
+
in_channels=384,
|
294 |
+
out_channels=384,
|
295 |
+
kernel_size=(3, 3),
|
296 |
+
upsample=(2, 2),
|
297 |
+
activation=activation,
|
298 |
+
momentum=momentum,
|
299 |
+
)
|
300 |
+
self.decoder_block3 = DecoderBlockRes4B(
|
301 |
+
in_channels=384,
|
302 |
+
out_channels=256,
|
303 |
+
kernel_size=(3, 3),
|
304 |
+
upsample=(2, 2),
|
305 |
+
activation=activation,
|
306 |
+
momentum=momentum,
|
307 |
+
)
|
308 |
+
self.decoder_block4 = DecoderBlockRes4B(
|
309 |
+
in_channels=256,
|
310 |
+
out_channels=128,
|
311 |
+
kernel_size=(3, 3),
|
312 |
+
upsample=(2, 2),
|
313 |
+
activation=activation,
|
314 |
+
momentum=momentum,
|
315 |
+
)
|
316 |
+
self.decoder_block5 = DecoderBlockRes4B(
|
317 |
+
in_channels=128,
|
318 |
+
out_channels=64,
|
319 |
+
kernel_size=(3, 3),
|
320 |
+
upsample=(2, 2),
|
321 |
+
activation=activation,
|
322 |
+
momentum=momentum,
|
323 |
+
)
|
324 |
+
self.decoder_block6 = DecoderBlockRes4B(
|
325 |
+
in_channels=64,
|
326 |
+
out_channels=32,
|
327 |
+
kernel_size=(3, 3),
|
328 |
+
upsample=(2, 2),
|
329 |
+
activation=activation,
|
330 |
+
momentum=momentum,
|
331 |
+
)
|
332 |
+
|
333 |
+
self.after_conv_block1 = EncoderBlockRes4B(
|
334 |
+
in_channels=32,
|
335 |
+
out_channels=32,
|
336 |
+
kernel_size=(3, 3),
|
337 |
+
downsample=(1, 1),
|
338 |
+
activation=activation,
|
339 |
+
momentum=momentum,
|
340 |
+
)
|
341 |
+
|
342 |
+
self.after_conv2 = nn.Conv2d(
|
343 |
+
in_channels=32,
|
344 |
+
out_channels=target_sources_num
|
345 |
+
* input_channels
|
346 |
+
* self.K
|
347 |
+
* self.subbands_num,
|
348 |
+
kernel_size=(1, 1),
|
349 |
+
stride=(1, 1),
|
350 |
+
padding=(0, 0),
|
351 |
+
bias=True,
|
352 |
+
)
|
353 |
+
|
354 |
+
self.init_weights()
|
355 |
+
|
356 |
+
def init_weights(self):
|
357 |
+
init_bn(self.bn0)
|
358 |
+
init_layer(self.after_conv2)
|
359 |
+
|
360 |
+
def feature_maps_to_wav(
|
361 |
+
self,
|
362 |
+
input_tensor: torch.Tensor,
|
363 |
+
sp: torch.Tensor,
|
364 |
+
sin_in: torch.Tensor,
|
365 |
+
cos_in: torch.Tensor,
|
366 |
+
audio_length: int,
|
367 |
+
) -> torch.Tensor:
|
368 |
+
r"""Convert feature maps to waveform.
|
369 |
+
|
370 |
+
Args:
|
371 |
+
input_tensor: (batch_size, target_sources_num * input_channels * self.K, time_steps, freq_bins)
|
372 |
+
sp: (batch_size, target_sources_num * input_channels, time_steps, freq_bins)
|
373 |
+
sin_in: (batch_size, target_sources_num * input_channels, time_steps, freq_bins)
|
374 |
+
cos_in: (batch_size, target_sources_num * input_channels, time_steps, freq_bins)
|
375 |
+
|
376 |
+
Outputs:
|
377 |
+
waveform: (batch_size, target_sources_num * input_channels, segment_samples)
|
378 |
+
"""
|
379 |
+
batch_size, _, time_steps, freq_bins = input_tensor.shape
|
380 |
+
|
381 |
+
x = input_tensor.reshape(
|
382 |
+
batch_size,
|
383 |
+
self.target_sources_num,
|
384 |
+
self.input_channels,
|
385 |
+
self.K,
|
386 |
+
time_steps,
|
387 |
+
freq_bins,
|
388 |
+
)
|
389 |
+
# x: (batch_size, target_sources_num, input_channles, K, time_steps, freq_bins)
|
390 |
+
|
391 |
+
mask_mag = torch.sigmoid(x[:, :, :, 0, :, :])
|
392 |
+
_mask_real = torch.tanh(x[:, :, :, 1, :, :])
|
393 |
+
_mask_imag = torch.tanh(x[:, :, :, 2, :, :])
|
394 |
+
linear_mag = torch.tanh(x[:, :, :, 3, :, :])
|
395 |
+
_, mask_cos, mask_sin = magphase(_mask_real, _mask_imag)
|
396 |
+
# mask_cos, mask_sin: (batch_size, target_sources_num, input_channles, time_steps, freq_bins)
|
397 |
+
|
398 |
+
# Y = |Y|cos∠Y + j|Y|sin∠Y
|
399 |
+
# = |Y|cos(∠X + ∠M) + j|Y|sin(∠X + ∠M)
|
400 |
+
# = |Y|(cos∠X cos∠M - sin∠X sin∠M) + j|Y|(sin∠X cos∠M + cos∠X sin∠M)
|
401 |
+
out_cos = (
|
402 |
+
cos_in[:, None, :, :, :] * mask_cos - sin_in[:, None, :, :, :] * mask_sin
|
403 |
+
)
|
404 |
+
out_sin = (
|
405 |
+
sin_in[:, None, :, :, :] * mask_cos + cos_in[:, None, :, :, :] * mask_sin
|
406 |
+
)
|
407 |
+
# out_cos: (batch_size, target_sources_num, input_channles, time_steps, freq_bins)
|
408 |
+
# out_sin: (batch_size, target_sources_num, input_channles, time_steps, freq_bins)
|
409 |
+
|
410 |
+
# Calculate |Y|.
|
411 |
+
out_mag = F.relu_(sp[:, None, :, :, :] * mask_mag + linear_mag)
|
412 |
+
# out_mag: (batch_size, target_sources_num, input_channles, time_steps, freq_bins)
|
413 |
+
|
414 |
+
# Calculate Y_{real} and Y_{imag} for ISTFT.
|
415 |
+
out_real = out_mag * out_cos
|
416 |
+
out_imag = out_mag * out_sin
|
417 |
+
# out_real, out_imag: (batch_size, target_sources_num, input_channles, time_steps, freq_bins)
|
418 |
+
|
419 |
+
# Reformat shape to (n, 1, time_steps, freq_bins) for ISTFT.
|
420 |
+
shape = (
|
421 |
+
batch_size * self.target_sources_num * self.input_channels,
|
422 |
+
1,
|
423 |
+
time_steps,
|
424 |
+
freq_bins,
|
425 |
+
)
|
426 |
+
out_real = out_real.reshape(shape)
|
427 |
+
out_imag = out_imag.reshape(shape)
|
428 |
+
|
429 |
+
# ISTFT.
|
430 |
+
x = self.istft(out_real, out_imag, audio_length)
|
431 |
+
# (batch_size * target_sources_num * input_channels, segments_num)
|
432 |
+
|
433 |
+
# Reshape.
|
434 |
+
waveform = x.reshape(
|
435 |
+
batch_size, self.target_sources_num * self.input_channels, audio_length
|
436 |
+
)
|
437 |
+
# (batch_size, target_sources_num * input_channels, segments_num)
|
438 |
+
|
439 |
+
return waveform
|
440 |
+
|
441 |
+
def forward(self, input_dict):
|
442 |
+
r"""Forward data into the module.
|
443 |
+
|
444 |
+
Args:
|
445 |
+
input_dict: dict, e.g., {
|
446 |
+
waveform: (batch_size, input_channels, segment_samples),
|
447 |
+
...,
|
448 |
+
}
|
449 |
+
|
450 |
+
Outputs:
|
451 |
+
output_dict: dict, e.g., {
|
452 |
+
'waveform': (batch_size, input_channels, segment_samples),
|
453 |
+
...,
|
454 |
+
}
|
455 |
+
"""
|
456 |
+
mixtures = input_dict['waveform']
|
457 |
+
# (batch_size, input_channels, segment_samples)
|
458 |
+
|
459 |
+
subband_x = self.pqmf.analysis(mixtures)
|
460 |
+
# subband_x: (batch_size, input_channels * subbands_num, segment_samples)
|
461 |
+
|
462 |
+
mag, cos_in, sin_in = self.wav_to_spectrogram_phase(subband_x)
|
463 |
+
# mag, cos_in, sin_in: (batch_size, input_channels * subbands_num, time_steps, freq_bins)
|
464 |
+
|
465 |
+
# Batch normalize on individual frequency bins.
|
466 |
+
x = mag.transpose(1, 3)
|
467 |
+
x = self.bn0(x)
|
468 |
+
x = x.transpose(1, 3)
|
469 |
+
# (batch_size, input_channels * subbands_num, time_steps, freq_bins)
|
470 |
+
|
471 |
+
# Pad spectrogram to be evenly divided by downsample ratio.
|
472 |
+
origin_len = x.shape[2]
|
473 |
+
pad_len = (
|
474 |
+
int(np.ceil(x.shape[2] / self.downsample_ratio)) * self.downsample_ratio
|
475 |
+
- origin_len
|
476 |
+
)
|
477 |
+
x = F.pad(x, pad=(0, 0, 0, pad_len))
|
478 |
+
# x: (batch_size, input_channels * subbands_num, padded_time_steps, freq_bins)
|
479 |
+
|
480 |
+
# Let frequency bins be evenly divided by 2, e.g., 257 -> 256
|
481 |
+
x = x[..., 0 : x.shape[-1] - 1] # (bs, input_channels, T, F)
|
482 |
+
# x: (batch_size, input_channels * subbands_num, padded_time_steps, freq_bins)
|
483 |
+
|
484 |
+
# UNet
|
485 |
+
(x1_pool, x1) = self.encoder_block1(x) # x1_pool: (bs, 32, T / 2, F / 2)
|
486 |
+
(x2_pool, x2) = self.encoder_block2(x1_pool) # x2_pool: (bs, 64, T / 4, F / 4)
|
487 |
+
(x3_pool, x3) = self.encoder_block3(x2_pool) # x3_pool: (bs, 128, T / 8, F / 8)
|
488 |
+
(x4_pool, x4) = self.encoder_block4(
|
489 |
+
x3_pool
|
490 |
+
) # x4_pool: (bs, 256, T / 16, F / 16)
|
491 |
+
(x5_pool, x5) = self.encoder_block5(
|
492 |
+
x4_pool
|
493 |
+
) # x5_pool: (bs, 384, T / 32, F / 32)
|
494 |
+
(x6_pool, x6) = self.encoder_block6(
|
495 |
+
x5_pool
|
496 |
+
) # x6_pool: (bs, 384, T / 32, F / 64)
|
497 |
+
(x_center, _) = self.conv_block7a(x6_pool) # (bs, 384, T / 32, F / 64)
|
498 |
+
(x_center, _) = self.conv_block7b(x_center) # (bs, 384, T / 32, F / 64)
|
499 |
+
(x_center, _) = self.conv_block7c(x_center) # (bs, 384, T / 32, F / 64)
|
500 |
+
(x_center, _) = self.conv_block7d(x_center) # (bs, 384, T / 32, F / 64)
|
501 |
+
x7 = self.decoder_block1(x_center, x6) # (bs, 384, T / 32, F / 32)
|
502 |
+
x8 = self.decoder_block2(x7, x5) # (bs, 384, T / 16, F / 16)
|
503 |
+
x9 = self.decoder_block3(x8, x4) # (bs, 256, T / 8, F / 8)
|
504 |
+
x10 = self.decoder_block4(x9, x3) # (bs, 128, T / 4, F / 4)
|
505 |
+
x11 = self.decoder_block5(x10, x2) # (bs, 64, T / 2, F / 2)
|
506 |
+
x12 = self.decoder_block6(x11, x1) # (bs, 32, T, F)
|
507 |
+
(x, _) = self.after_conv_block1(x12) # (bs, 32, T, F)
|
508 |
+
|
509 |
+
x = self.after_conv2(x)
|
510 |
+
# (batch_size, subbands_num * target_sources_num * input_channles * self.K, T, F')
|
511 |
+
|
512 |
+
# Recover shape
|
513 |
+
x = F.pad(x, pad=(0, 1)) # Pad frequency, e.g., 256 -> 257.
|
514 |
+
|
515 |
+
x = x[:, :, 0:origin_len, :]
|
516 |
+
# (batch_size, subbands_num * target_sources_num * input_channles * self.K, T, F')
|
517 |
+
|
518 |
+
audio_length = subband_x.shape[2]
|
519 |
+
|
520 |
+
# Recover each subband spectrograms to subband waveforms. Then synthesis
|
521 |
+
# the subband waveforms to a waveform.
|
522 |
+
C1 = x.shape[1] // self.subbands_num
|
523 |
+
C2 = mag.shape[1] // self.subbands_num
|
524 |
+
|
525 |
+
separated_subband_audio = torch.cat(
|
526 |
+
[
|
527 |
+
self.feature_maps_to_wav(
|
528 |
+
input_tensor=x[:, j * C1 : (j + 1) * C1, :, :],
|
529 |
+
sp=mag[:, j * C2 : (j + 1) * C2, :, :],
|
530 |
+
sin_in=sin_in[:, j * C2 : (j + 1) * C2, :, :],
|
531 |
+
cos_in=cos_in[:, j * C2 : (j + 1) * C2, :, :],
|
532 |
+
audio_length=audio_length,
|
533 |
+
)
|
534 |
+
for j in range(self.subbands_num)
|
535 |
+
],
|
536 |
+
dim=1,
|
537 |
+
)
|
538 |
+
# (batch_size, subbands_num * target_sources_num * input_channles, segment_samples)
|
539 |
+
|
540 |
+
separated_audio = self.pqmf.synthesis(separated_subband_audio)
|
541 |
+
# (batch_size, input_channles, segment_samples)
|
542 |
+
|
543 |
+
output_dict = {'waveform': separated_audio}
|
544 |
+
|
545 |
+
return output_dict
|
bytesep/models/subband_tools/__init__.py
ADDED
File without changes
|
bytesep/models/subband_tools/fDomainHelper.py
ADDED
@@ -0,0 +1,255 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torchlibrosa.stft import STFT, ISTFT, magphase
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import numpy as np
|
5 |
+
from tools.pytorch.modules.pqmf import PQMF
|
6 |
+
|
7 |
+
|
8 |
+
class FDomainHelper(nn.Module):
|
9 |
+
def __init__(
|
10 |
+
self,
|
11 |
+
window_size=2048,
|
12 |
+
hop_size=441,
|
13 |
+
center=True,
|
14 |
+
pad_mode='reflect',
|
15 |
+
window='hann',
|
16 |
+
freeze_parameters=True,
|
17 |
+
subband=None,
|
18 |
+
root="/Users/admin/Documents/projects/",
|
19 |
+
):
|
20 |
+
super(FDomainHelper, self).__init__()
|
21 |
+
self.subband = subband
|
22 |
+
if self.subband is None:
|
23 |
+
self.stft = STFT(
|
24 |
+
n_fft=window_size,
|
25 |
+
hop_length=hop_size,
|
26 |
+
win_length=window_size,
|
27 |
+
window=window,
|
28 |
+
center=center,
|
29 |
+
pad_mode=pad_mode,
|
30 |
+
freeze_parameters=freeze_parameters,
|
31 |
+
)
|
32 |
+
|
33 |
+
self.istft = ISTFT(
|
34 |
+
n_fft=window_size,
|
35 |
+
hop_length=hop_size,
|
36 |
+
win_length=window_size,
|
37 |
+
window=window,
|
38 |
+
center=center,
|
39 |
+
pad_mode=pad_mode,
|
40 |
+
freeze_parameters=freeze_parameters,
|
41 |
+
)
|
42 |
+
else:
|
43 |
+
self.stft = STFT(
|
44 |
+
n_fft=window_size // self.subband,
|
45 |
+
hop_length=hop_size // self.subband,
|
46 |
+
win_length=window_size // self.subband,
|
47 |
+
window=window,
|
48 |
+
center=center,
|
49 |
+
pad_mode=pad_mode,
|
50 |
+
freeze_parameters=freeze_parameters,
|
51 |
+
)
|
52 |
+
|
53 |
+
self.istft = ISTFT(
|
54 |
+
n_fft=window_size // self.subband,
|
55 |
+
hop_length=hop_size // self.subband,
|
56 |
+
win_length=window_size // self.subband,
|
57 |
+
window=window,
|
58 |
+
center=center,
|
59 |
+
pad_mode=pad_mode,
|
60 |
+
freeze_parameters=freeze_parameters,
|
61 |
+
)
|
62 |
+
|
63 |
+
if subband is not None and root is not None:
|
64 |
+
self.qmf = PQMF(subband, 64, root)
|
65 |
+
|
66 |
+
def complex_spectrogram(self, input, eps=0.0):
|
67 |
+
# [batchsize, samples]
|
68 |
+
# return [batchsize, 2, t-steps, f-bins]
|
69 |
+
real, imag = self.stft(input)
|
70 |
+
return torch.cat([real, imag], dim=1)
|
71 |
+
|
72 |
+
def reverse_complex_spectrogram(self, input, eps=0.0, length=None):
|
73 |
+
# [batchsize, 2[real,imag], t-steps, f-bins]
|
74 |
+
wav = self.istft(input[:, 0:1, ...], input[:, 1:2, ...], length=length)
|
75 |
+
return wav
|
76 |
+
|
77 |
+
def spectrogram(self, input, eps=0.0):
|
78 |
+
(real, imag) = self.stft(input.float())
|
79 |
+
return torch.clamp(real ** 2 + imag ** 2, eps, np.inf) ** 0.5
|
80 |
+
|
81 |
+
def spectrogram_phase(self, input, eps=0.0):
|
82 |
+
(real, imag) = self.stft(input.float())
|
83 |
+
mag = torch.clamp(real ** 2 + imag ** 2, eps, np.inf) ** 0.5
|
84 |
+
cos = real / mag
|
85 |
+
sin = imag / mag
|
86 |
+
return mag, cos, sin
|
87 |
+
|
88 |
+
def wav_to_spectrogram_phase(self, input, eps=1e-8):
|
89 |
+
"""Waveform to spectrogram.
|
90 |
+
|
91 |
+
Args:
|
92 |
+
input: (batch_size, channels_num, segment_samples)
|
93 |
+
|
94 |
+
Outputs:
|
95 |
+
output: (batch_size, channels_num, time_steps, freq_bins)
|
96 |
+
"""
|
97 |
+
sp_list = []
|
98 |
+
cos_list = []
|
99 |
+
sin_list = []
|
100 |
+
channels_num = input.shape[1]
|
101 |
+
for channel in range(channels_num):
|
102 |
+
mag, cos, sin = self.spectrogram_phase(input[:, channel, :], eps=eps)
|
103 |
+
sp_list.append(mag)
|
104 |
+
cos_list.append(cos)
|
105 |
+
sin_list.append(sin)
|
106 |
+
|
107 |
+
sps = torch.cat(sp_list, dim=1)
|
108 |
+
coss = torch.cat(cos_list, dim=1)
|
109 |
+
sins = torch.cat(sin_list, dim=1)
|
110 |
+
return sps, coss, sins
|
111 |
+
|
112 |
+
def spectrogram_phase_to_wav(self, sps, coss, sins, length):
|
113 |
+
channels_num = sps.size()[1]
|
114 |
+
res = []
|
115 |
+
for i in range(channels_num):
|
116 |
+
res.append(
|
117 |
+
self.istft(
|
118 |
+
sps[:, i : i + 1, ...] * coss[:, i : i + 1, ...],
|
119 |
+
sps[:, i : i + 1, ...] * sins[:, i : i + 1, ...],
|
120 |
+
length,
|
121 |
+
)
|
122 |
+
)
|
123 |
+
res[-1] = res[-1].unsqueeze(1)
|
124 |
+
return torch.cat(res, dim=1)
|
125 |
+
|
126 |
+
def wav_to_spectrogram(self, input, eps=1e-8):
|
127 |
+
"""Waveform to spectrogram.
|
128 |
+
|
129 |
+
Args:
|
130 |
+
input: (batch_size,channels_num, segment_samples)
|
131 |
+
|
132 |
+
Outputs:
|
133 |
+
output: (batch_size, channels_num, time_steps, freq_bins)
|
134 |
+
"""
|
135 |
+
sp_list = []
|
136 |
+
channels_num = input.shape[1]
|
137 |
+
for channel in range(channels_num):
|
138 |
+
sp_list.append(self.spectrogram(input[:, channel, :], eps=eps))
|
139 |
+
output = torch.cat(sp_list, dim=1)
|
140 |
+
return output
|
141 |
+
|
142 |
+
def spectrogram_to_wav(self, input, spectrogram, length=None):
|
143 |
+
"""Spectrogram to waveform.
|
144 |
+
Args:
|
145 |
+
input: (batch_size, segment_samples, channels_num)
|
146 |
+
spectrogram: (batch_size, channels_num, time_steps, freq_bins)
|
147 |
+
|
148 |
+
Outputs:
|
149 |
+
output: (batch_size, segment_samples, channels_num)
|
150 |
+
"""
|
151 |
+
channels_num = input.shape[1]
|
152 |
+
wav_list = []
|
153 |
+
for channel in range(channels_num):
|
154 |
+
(real, imag) = self.stft(input[:, channel, :])
|
155 |
+
(_, cos, sin) = magphase(real, imag)
|
156 |
+
wav_list.append(
|
157 |
+
self.istft(
|
158 |
+
spectrogram[:, channel : channel + 1, :, :] * cos,
|
159 |
+
spectrogram[:, channel : channel + 1, :, :] * sin,
|
160 |
+
length,
|
161 |
+
)
|
162 |
+
)
|
163 |
+
|
164 |
+
output = torch.stack(wav_list, dim=1)
|
165 |
+
return output
|
166 |
+
|
167 |
+
# todo the following code is not bug free!
|
168 |
+
def wav_to_complex_spectrogram(self, input, eps=0.0):
|
169 |
+
# [batchsize , channels, samples]
|
170 |
+
# [batchsize, 2[real,imag]*channels, t-steps, f-bins]
|
171 |
+
res = []
|
172 |
+
channels_num = input.shape[1]
|
173 |
+
for channel in range(channels_num):
|
174 |
+
res.append(self.complex_spectrogram(input[:, channel, :], eps=eps))
|
175 |
+
return torch.cat(res, dim=1)
|
176 |
+
|
177 |
+
def complex_spectrogram_to_wav(self, input, eps=0.0, length=None):
|
178 |
+
# [batchsize, 2[real,imag]*channels, t-steps, f-bins]
|
179 |
+
# return [batchsize, channels, samples]
|
180 |
+
channels = input.size()[1] // 2
|
181 |
+
wavs = []
|
182 |
+
for i in range(channels):
|
183 |
+
wavs.append(
|
184 |
+
self.reverse_complex_spectrogram(
|
185 |
+
input[:, 2 * i : 2 * i + 2, ...], eps=eps, length=length
|
186 |
+
)
|
187 |
+
)
|
188 |
+
wavs[-1] = wavs[-1].unsqueeze(1)
|
189 |
+
return torch.cat(wavs, dim=1)
|
190 |
+
|
191 |
+
def wav_to_complex_subband_spectrogram(self, input, eps=0.0):
|
192 |
+
# [batchsize, channels, samples]
|
193 |
+
# [batchsize, 2[real,imag]*subband*channels, t-steps, f-bins]
|
194 |
+
subwav = self.qmf.analysis(input) # [batchsize, subband*channels, samples]
|
195 |
+
subspec = self.wav_to_complex_spectrogram(subwav)
|
196 |
+
return subspec
|
197 |
+
|
198 |
+
def complex_subband_spectrogram_to_wav(self, input, eps=0.0):
|
199 |
+
# [batchsize, 2[real,imag]*subband*channels, t-steps, f-bins]
|
200 |
+
# [batchsize, channels, samples]
|
201 |
+
subwav = self.complex_spectrogram_to_wav(input)
|
202 |
+
data = self.qmf.synthesis(subwav)
|
203 |
+
return data
|
204 |
+
|
205 |
+
def wav_to_mag_phase_subband_spectrogram(self, input, eps=1e-8):
|
206 |
+
"""
|
207 |
+
:param input:
|
208 |
+
:param eps:
|
209 |
+
:return:
|
210 |
+
loss = torch.nn.L1Loss()
|
211 |
+
model = FDomainHelper(subband=4)
|
212 |
+
data = torch.randn((3,1, 44100*3))
|
213 |
+
|
214 |
+
sps, coss, sins = model.wav_to_mag_phase_subband_spectrogram(data)
|
215 |
+
wav = model.mag_phase_subband_spectrogram_to_wav(sps,coss,sins,44100*3//4)
|
216 |
+
|
217 |
+
print(loss(data,wav))
|
218 |
+
print(torch.max(torch.abs(data-wav)))
|
219 |
+
|
220 |
+
"""
|
221 |
+
# [batchsize, channels, samples]
|
222 |
+
# [batchsize, 2[real,imag]*subband*channels, t-steps, f-bins]
|
223 |
+
subwav = self.qmf.analysis(input) # [batchsize, subband*channels, samples]
|
224 |
+
sps, coss, sins = self.wav_to_spectrogram_phase(subwav, eps=eps)
|
225 |
+
return sps, coss, sins
|
226 |
+
|
227 |
+
def mag_phase_subband_spectrogram_to_wav(self, sps, coss, sins, length, eps=0.0):
|
228 |
+
# [batchsize, 2[real,imag]*subband*channels, t-steps, f-bins]
|
229 |
+
# [batchsize, channels, samples]
|
230 |
+
subwav = self.spectrogram_phase_to_wav(sps, coss, sins, length)
|
231 |
+
data = self.qmf.synthesis(subwav)
|
232 |
+
return data
|
233 |
+
|
234 |
+
|
235 |
+
if __name__ == "__main__":
|
236 |
+
# from thop import profile
|
237 |
+
# from thop import clever_format
|
238 |
+
# from tools.file.wav import *
|
239 |
+
# import time
|
240 |
+
#
|
241 |
+
# wav = torch.randn((1,2,44100))
|
242 |
+
# model = FDomainHelper()
|
243 |
+
|
244 |
+
from tools.file.wav import *
|
245 |
+
|
246 |
+
loss = torch.nn.L1Loss()
|
247 |
+
model = FDomainHelper()
|
248 |
+
data = torch.randn((3, 1, 44100 * 5))
|
249 |
+
|
250 |
+
sps = model.wav_to_complex_spectrogram(data)
|
251 |
+
print(sps.size())
|
252 |
+
wav = model.complex_spectrogram_to_wav(sps, 44100 * 5)
|
253 |
+
|
254 |
+
print(loss(data, wav))
|
255 |
+
print(torch.max(torch.abs(data - wav)))
|
bytesep/models/subband_tools/filters/f_4_64.mat
ADDED
Binary file (2.19 kB). View file
|
|
bytesep/models/subband_tools/filters/h_4_64.mat
ADDED
Binary file (2.19 kB). View file
|
|
bytesep/models/subband_tools/pqmf.py
ADDED
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
@File : subband_util.py
|
3 |
+
@Contact : liu.8948@buckeyemail.osu.edu
|
4 |
+
@License : (C)Copyright 2020-2021
|
5 |
+
@Modify Time @Author @Version @Desciption
|
6 |
+
------------ ------- -------- -----------
|
7 |
+
2020/4/3 4:54 PM Haohe Liu 1.0 None
|
8 |
+
'''
|
9 |
+
|
10 |
+
import torch
|
11 |
+
import torch.nn.functional as F
|
12 |
+
import torch.nn as nn
|
13 |
+
import numpy as np
|
14 |
+
import os.path as op
|
15 |
+
from scipy.io import loadmat
|
16 |
+
|
17 |
+
|
18 |
+
def load_mat2numpy(fname=""):
|
19 |
+
'''
|
20 |
+
Args:
|
21 |
+
fname: pth to mat
|
22 |
+
type:
|
23 |
+
Returns: dic object
|
24 |
+
'''
|
25 |
+
if len(fname) == 0:
|
26 |
+
return None
|
27 |
+
else:
|
28 |
+
return loadmat(fname)
|
29 |
+
|
30 |
+
|
31 |
+
class PQMF(nn.Module):
|
32 |
+
def __init__(self, N, M, project_root):
|
33 |
+
super().__init__()
|
34 |
+
self.N = N # nsubband
|
35 |
+
self.M = M # nfilter
|
36 |
+
try:
|
37 |
+
assert (N, M) in [(8, 64), (4, 64), (2, 64)]
|
38 |
+
except:
|
39 |
+
print("Warning:", N, "subbandand ", M, " filter is not supported")
|
40 |
+
self.pad_samples = 64
|
41 |
+
self.name = str(N) + "_" + str(M) + ".mat"
|
42 |
+
self.ana_conv_filter = nn.Conv1d(
|
43 |
+
1, out_channels=N, kernel_size=M, stride=N, bias=False
|
44 |
+
)
|
45 |
+
data = load_mat2numpy(op.join(project_root, "f_" + self.name))
|
46 |
+
data = data['f'].astype(np.float32) / N
|
47 |
+
data = np.flipud(data.T).T
|
48 |
+
data = np.reshape(data, (N, 1, M)).copy()
|
49 |
+
dict_new = self.ana_conv_filter.state_dict().copy()
|
50 |
+
dict_new['weight'] = torch.from_numpy(data)
|
51 |
+
self.ana_pad = nn.ConstantPad1d((M - N, 0), 0)
|
52 |
+
self.ana_conv_filter.load_state_dict(dict_new)
|
53 |
+
|
54 |
+
self.syn_pad = nn.ConstantPad1d((0, M // N - 1), 0)
|
55 |
+
self.syn_conv_filter = nn.Conv1d(
|
56 |
+
N, out_channels=N, kernel_size=M // N, stride=1, bias=False
|
57 |
+
)
|
58 |
+
gk = load_mat2numpy(op.join(project_root, "h_" + self.name))
|
59 |
+
gk = gk['h'].astype(np.float32)
|
60 |
+
gk = np.transpose(np.reshape(gk, (N, M // N, N)), (1, 0, 2)) * N
|
61 |
+
gk = np.transpose(gk[::-1, :, :], (2, 1, 0)).copy()
|
62 |
+
dict_new = self.syn_conv_filter.state_dict().copy()
|
63 |
+
dict_new['weight'] = torch.from_numpy(gk)
|
64 |
+
self.syn_conv_filter.load_state_dict(dict_new)
|
65 |
+
|
66 |
+
for param in self.parameters():
|
67 |
+
param.requires_grad = False
|
68 |
+
|
69 |
+
def __analysis_channel(self, inputs):
|
70 |
+
return self.ana_conv_filter(self.ana_pad(inputs))
|
71 |
+
|
72 |
+
def __systhesis_channel(self, inputs):
|
73 |
+
ret = self.syn_conv_filter(self.syn_pad(inputs)).permute(0, 2, 1)
|
74 |
+
return torch.reshape(ret, (ret.shape[0], 1, -1))
|
75 |
+
|
76 |
+
def analysis(self, inputs):
|
77 |
+
'''
|
78 |
+
:param inputs: [batchsize,channel,raw_wav],value:[0,1]
|
79 |
+
:return:
|
80 |
+
'''
|
81 |
+
inputs = F.pad(inputs, ((0, self.pad_samples)))
|
82 |
+
ret = None
|
83 |
+
for i in range(inputs.size()[1]): # channels
|
84 |
+
if ret is None:
|
85 |
+
ret = self.__analysis_channel(inputs[:, i : i + 1, :])
|
86 |
+
else:
|
87 |
+
ret = torch.cat(
|
88 |
+
(ret, self.__analysis_channel(inputs[:, i : i + 1, :])), dim=1
|
89 |
+
)
|
90 |
+
return ret
|
91 |
+
|
92 |
+
def synthesis(self, data):
|
93 |
+
'''
|
94 |
+
:param data: [batchsize,self.N*K,raw_wav_sub],value:[0,1]
|
95 |
+
:return:
|
96 |
+
'''
|
97 |
+
ret = None
|
98 |
+
# data = F.pad(data,((0,self.pad_samples//self.N)))
|
99 |
+
for i in range(data.size()[1]): # channels
|
100 |
+
if i % self.N == 0:
|
101 |
+
if ret is None:
|
102 |
+
ret = self.__systhesis_channel(data[:, i : i + self.N, :])
|
103 |
+
else:
|
104 |
+
new = self.__systhesis_channel(data[:, i : i + self.N, :])
|
105 |
+
ret = torch.cat((ret, new), dim=1)
|
106 |
+
ret = ret[..., : -self.pad_samples]
|
107 |
+
return ret
|
108 |
+
|
109 |
+
def forward(self, inputs):
|
110 |
+
return self.ana_conv_filter(self.ana_pad(inputs))
|
111 |
+
|
112 |
+
|
113 |
+
if __name__ == "__main__":
|
114 |
+
import torch
|
115 |
+
import numpy as np
|
116 |
+
import matplotlib.pyplot as plt
|
117 |
+
from tools.file.wav import *
|
118 |
+
|
119 |
+
pqmf = PQMF(N=4, M=64, project_root="/Users/admin/Documents/projects")
|
120 |
+
|
121 |
+
rs = np.random.RandomState(0)
|
122 |
+
x = torch.tensor(rs.rand(4, 2, 32000), dtype=torch.float32)
|
123 |
+
|
124 |
+
a1 = pqmf.analysis(x)
|
125 |
+
a2 = pqmf.synthesis(a1)
|
126 |
+
|
127 |
+
print(a2.size(), x.size())
|
128 |
+
|
129 |
+
plt.subplot(211)
|
130 |
+
plt.plot(x[0, 0, -500:])
|
131 |
+
plt.subplot(212)
|
132 |
+
plt.plot(a2[0, 0, -500:])
|
133 |
+
plt.plot(x[0, 0, -500:] - a2[0, 0, -500:])
|
134 |
+
plt.show()
|
135 |
+
|
136 |
+
print(torch.sum(torch.abs(x[...] - a2[...])))
|
bytesep/models/unet.py
ADDED
@@ -0,0 +1,532 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from typing import Dict, List, NoReturn, Tuple
|
3 |
+
|
4 |
+
import matplotlib.pyplot as plt
|
5 |
+
import numpy as np
|
6 |
+
import pytorch_lightning as pl
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
import torch.nn.functional as F
|
10 |
+
import torch.optim as optim
|
11 |
+
from torch.optim.lr_scheduler import LambdaLR
|
12 |
+
from torchlibrosa.stft import ISTFT, STFT, magphase
|
13 |
+
|
14 |
+
from bytesep.models.pytorch_modules import Base, Subband, act, init_bn, init_layer
|
15 |
+
|
16 |
+
|
17 |
+
class ConvBlock(nn.Module):
|
18 |
+
def __init__(
|
19 |
+
self,
|
20 |
+
in_channels: int,
|
21 |
+
out_channels: int,
|
22 |
+
kernel_size: Tuple,
|
23 |
+
activation: str,
|
24 |
+
momentum: float,
|
25 |
+
):
|
26 |
+
r"""Convolutional block."""
|
27 |
+
super(ConvBlock, self).__init__()
|
28 |
+
|
29 |
+
self.activation = activation
|
30 |
+
padding = (kernel_size[0] // 2, kernel_size[1] // 2)
|
31 |
+
|
32 |
+
self.conv1 = nn.Conv2d(
|
33 |
+
in_channels=in_channels,
|
34 |
+
out_channels=out_channels,
|
35 |
+
kernel_size=kernel_size,
|
36 |
+
stride=(1, 1),
|
37 |
+
dilation=(1, 1),
|
38 |
+
padding=padding,
|
39 |
+
bias=False,
|
40 |
+
)
|
41 |
+
|
42 |
+
self.bn1 = nn.BatchNorm2d(out_channels, momentum=momentum)
|
43 |
+
|
44 |
+
self.conv2 = nn.Conv2d(
|
45 |
+
in_channels=out_channels,
|
46 |
+
out_channels=out_channels,
|
47 |
+
kernel_size=kernel_size,
|
48 |
+
stride=(1, 1),
|
49 |
+
dilation=(1, 1),
|
50 |
+
padding=padding,
|
51 |
+
bias=False,
|
52 |
+
)
|
53 |
+
|
54 |
+
self.bn2 = nn.BatchNorm2d(out_channels, momentum=momentum)
|
55 |
+
|
56 |
+
self.init_weights()
|
57 |
+
|
58 |
+
def init_weights(self) -> NoReturn:
|
59 |
+
r"""Initialize weights."""
|
60 |
+
init_layer(self.conv1)
|
61 |
+
init_layer(self.conv2)
|
62 |
+
init_bn(self.bn1)
|
63 |
+
init_bn(self.bn2)
|
64 |
+
|
65 |
+
def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
|
66 |
+
r"""Forward data into the module.
|
67 |
+
|
68 |
+
Args:
|
69 |
+
input_tensor: (batch_size, in_feature_maps, time_steps, freq_bins)
|
70 |
+
|
71 |
+
Returns:
|
72 |
+
output_tensor: (batch_size, out_feature_maps, time_steps, freq_bins)
|
73 |
+
"""
|
74 |
+
x = act(self.bn1(self.conv1(input_tensor)), self.activation)
|
75 |
+
x = act(self.bn2(self.conv2(x)), self.activation)
|
76 |
+
output_tensor = x
|
77 |
+
|
78 |
+
return output_tensor
|
79 |
+
|
80 |
+
|
81 |
+
class EncoderBlock(nn.Module):
|
82 |
+
def __init__(
|
83 |
+
self,
|
84 |
+
in_channels: int,
|
85 |
+
out_channels: int,
|
86 |
+
kernel_size: Tuple,
|
87 |
+
downsample: Tuple,
|
88 |
+
activation: str,
|
89 |
+
momentum: float,
|
90 |
+
):
|
91 |
+
r"""Encoder block."""
|
92 |
+
super(EncoderBlock, self).__init__()
|
93 |
+
|
94 |
+
self.conv_block = ConvBlock(
|
95 |
+
in_channels, out_channels, kernel_size, activation, momentum
|
96 |
+
)
|
97 |
+
self.downsample = downsample
|
98 |
+
|
99 |
+
def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
|
100 |
+
r"""Forward data into the module.
|
101 |
+
|
102 |
+
Args:
|
103 |
+
input_tensor: (batch_size, in_feature_maps, time_steps, freq_bins)
|
104 |
+
|
105 |
+
Returns:
|
106 |
+
encoder_pool: (batch_size, out_feature_maps, downsampled_time_steps, downsampled_freq_bins)
|
107 |
+
encoder: (batch_size, out_feature_maps, time_steps, freq_bins)
|
108 |
+
"""
|
109 |
+
encoder_tensor = self.conv_block(input_tensor)
|
110 |
+
# encoder: (batch_size, out_feature_maps, time_steps, freq_bins)
|
111 |
+
|
112 |
+
encoder_pool = F.avg_pool2d(encoder_tensor, kernel_size=self.downsample)
|
113 |
+
# encoder_pool: (batch_size, out_feature_maps, downsampled_time_steps, downsampled_freq_bins)
|
114 |
+
|
115 |
+
return encoder_pool, encoder_tensor
|
116 |
+
|
117 |
+
|
118 |
+
class DecoderBlock(nn.Module):
|
119 |
+
def __init__(
|
120 |
+
self,
|
121 |
+
in_channels: int,
|
122 |
+
out_channels: int,
|
123 |
+
kernel_size: Tuple,
|
124 |
+
upsample: Tuple,
|
125 |
+
activation: str,
|
126 |
+
momentum: float,
|
127 |
+
):
|
128 |
+
r"""Decoder block."""
|
129 |
+
super(DecoderBlock, self).__init__()
|
130 |
+
|
131 |
+
self.kernel_size = kernel_size
|
132 |
+
self.stride = upsample
|
133 |
+
self.activation = activation
|
134 |
+
|
135 |
+
self.conv1 = torch.nn.ConvTranspose2d(
|
136 |
+
in_channels=in_channels,
|
137 |
+
out_channels=out_channels,
|
138 |
+
kernel_size=self.stride,
|
139 |
+
stride=self.stride,
|
140 |
+
padding=(0, 0),
|
141 |
+
bias=False,
|
142 |
+
dilation=(1, 1),
|
143 |
+
)
|
144 |
+
|
145 |
+
self.bn1 = nn.BatchNorm2d(out_channels, momentum=momentum)
|
146 |
+
|
147 |
+
self.conv_block2 = ConvBlock(
|
148 |
+
out_channels * 2, out_channels, kernel_size, activation, momentum
|
149 |
+
)
|
150 |
+
|
151 |
+
self.init_weights()
|
152 |
+
|
153 |
+
def init_weights(self):
|
154 |
+
r"""Initialize weights."""
|
155 |
+
init_layer(self.conv1)
|
156 |
+
init_bn(self.bn1)
|
157 |
+
|
158 |
+
def forward(
|
159 |
+
self, input_tensor: torch.Tensor, concat_tensor: torch.Tensor
|
160 |
+
) -> torch.Tensor:
|
161 |
+
r"""Forward data into the module.
|
162 |
+
|
163 |
+
Args:
|
164 |
+
torch_tensor: (batch_size, in_feature_maps, downsampled_time_steps, downsampled_freq_bins)
|
165 |
+
concat_tensor: (batch_size, in_feature_maps, time_steps, freq_bins)
|
166 |
+
|
167 |
+
Returns:
|
168 |
+
output_tensor: (batch_size, out_feature_maps, time_steps, freq_bins)
|
169 |
+
"""
|
170 |
+
x = act(self.bn1(self.conv1(input_tensor)), self.activation)
|
171 |
+
# (batch_size, in_feature_maps, time_steps, freq_bins)
|
172 |
+
|
173 |
+
x = torch.cat((x, concat_tensor), dim=1)
|
174 |
+
# (batch_size, in_feature_maps * 2, time_steps, freq_bins)
|
175 |
+
|
176 |
+
output_tensor = self.conv_block2(x)
|
177 |
+
# output_tensor: (batch_size, out_feature_maps, time_steps, freq_bins)
|
178 |
+
|
179 |
+
return output_tensor
|
180 |
+
|
181 |
+
|
182 |
+
class UNet(nn.Module, Base):
|
183 |
+
def __init__(self, input_channels: int, target_sources_num: int):
|
184 |
+
r"""UNet."""
|
185 |
+
super(UNet, self).__init__()
|
186 |
+
|
187 |
+
self.input_channels = input_channels
|
188 |
+
self.target_sources_num = target_sources_num
|
189 |
+
|
190 |
+
window_size = 2048
|
191 |
+
hop_size = 441
|
192 |
+
center = True
|
193 |
+
pad_mode = "reflect"
|
194 |
+
window = "hann"
|
195 |
+
activation = "leaky_relu"
|
196 |
+
momentum = 0.01
|
197 |
+
|
198 |
+
self.subbands_num = 1
|
199 |
+
|
200 |
+
assert (
|
201 |
+
self.subbands_num == 1
|
202 |
+
), "Using subbands_num > 1 on spectrogram \
|
203 |
+
will lead to unexpected performance sometimes. Suggest to use \
|
204 |
+
subband method on waveform."
|
205 |
+
|
206 |
+
self.K = 3 # outputs: |M|, cos∠M, sin∠M
|
207 |
+
self.downsample_ratio = 2 ** 6 # This number equals 2^{#encoder_blcoks}
|
208 |
+
|
209 |
+
self.stft = STFT(
|
210 |
+
n_fft=window_size,
|
211 |
+
hop_length=hop_size,
|
212 |
+
win_length=window_size,
|
213 |
+
window=window,
|
214 |
+
center=center,
|
215 |
+
pad_mode=pad_mode,
|
216 |
+
freeze_parameters=True,
|
217 |
+
)
|
218 |
+
|
219 |
+
self.istft = ISTFT(
|
220 |
+
n_fft=window_size,
|
221 |
+
hop_length=hop_size,
|
222 |
+
win_length=window_size,
|
223 |
+
window=window,
|
224 |
+
center=center,
|
225 |
+
pad_mode=pad_mode,
|
226 |
+
freeze_parameters=True,
|
227 |
+
)
|
228 |
+
|
229 |
+
self.bn0 = nn.BatchNorm2d(window_size // 2 + 1, momentum=momentum)
|
230 |
+
|
231 |
+
self.subband = Subband(subbands_num=self.subbands_num)
|
232 |
+
|
233 |
+
self.encoder_block1 = EncoderBlock(
|
234 |
+
in_channels=input_channels * self.subbands_num,
|
235 |
+
out_channels=32,
|
236 |
+
kernel_size=(3, 3),
|
237 |
+
downsample=(2, 2),
|
238 |
+
activation=activation,
|
239 |
+
momentum=momentum,
|
240 |
+
)
|
241 |
+
self.encoder_block2 = EncoderBlock(
|
242 |
+
in_channels=32,
|
243 |
+
out_channels=64,
|
244 |
+
kernel_size=(3, 3),
|
245 |
+
downsample=(2, 2),
|
246 |
+
activation=activation,
|
247 |
+
momentum=momentum,
|
248 |
+
)
|
249 |
+
self.encoder_block3 = EncoderBlock(
|
250 |
+
in_channels=64,
|
251 |
+
out_channels=128,
|
252 |
+
kernel_size=(3, 3),
|
253 |
+
downsample=(2, 2),
|
254 |
+
activation=activation,
|
255 |
+
momentum=momentum,
|
256 |
+
)
|
257 |
+
self.encoder_block4 = EncoderBlock(
|
258 |
+
in_channels=128,
|
259 |
+
out_channels=256,
|
260 |
+
kernel_size=(3, 3),
|
261 |
+
downsample=(2, 2),
|
262 |
+
activation=activation,
|
263 |
+
momentum=momentum,
|
264 |
+
)
|
265 |
+
self.encoder_block5 = EncoderBlock(
|
266 |
+
in_channels=256,
|
267 |
+
out_channels=384,
|
268 |
+
kernel_size=(3, 3),
|
269 |
+
downsample=(2, 2),
|
270 |
+
activation=activation,
|
271 |
+
momentum=momentum,
|
272 |
+
)
|
273 |
+
self.encoder_block6 = EncoderBlock(
|
274 |
+
in_channels=384,
|
275 |
+
out_channels=384,
|
276 |
+
kernel_size=(3, 3),
|
277 |
+
downsample=(2, 2),
|
278 |
+
activation=activation,
|
279 |
+
momentum=momentum,
|
280 |
+
)
|
281 |
+
self.conv_block7 = ConvBlock(
|
282 |
+
in_channels=384,
|
283 |
+
out_channels=384,
|
284 |
+
kernel_size=(3, 3),
|
285 |
+
activation=activation,
|
286 |
+
momentum=momentum,
|
287 |
+
)
|
288 |
+
self.decoder_block1 = DecoderBlock(
|
289 |
+
in_channels=384,
|
290 |
+
out_channels=384,
|
291 |
+
kernel_size=(3, 3),
|
292 |
+
upsample=(2, 2),
|
293 |
+
activation=activation,
|
294 |
+
momentum=momentum,
|
295 |
+
)
|
296 |
+
self.decoder_block2 = DecoderBlock(
|
297 |
+
in_channels=384,
|
298 |
+
out_channels=384,
|
299 |
+
kernel_size=(3, 3),
|
300 |
+
upsample=(2, 2),
|
301 |
+
activation=activation,
|
302 |
+
momentum=momentum,
|
303 |
+
)
|
304 |
+
self.decoder_block3 = DecoderBlock(
|
305 |
+
in_channels=384,
|
306 |
+
out_channels=256,
|
307 |
+
kernel_size=(3, 3),
|
308 |
+
upsample=(2, 2),
|
309 |
+
activation=activation,
|
310 |
+
momentum=momentum,
|
311 |
+
)
|
312 |
+
self.decoder_block4 = DecoderBlock(
|
313 |
+
in_channels=256,
|
314 |
+
out_channels=128,
|
315 |
+
kernel_size=(3, 3),
|
316 |
+
upsample=(2, 2),
|
317 |
+
activation=activation,
|
318 |
+
momentum=momentum,
|
319 |
+
)
|
320 |
+
self.decoder_block5 = DecoderBlock(
|
321 |
+
in_channels=128,
|
322 |
+
out_channels=64,
|
323 |
+
kernel_size=(3, 3),
|
324 |
+
upsample=(2, 2),
|
325 |
+
activation=activation,
|
326 |
+
momentum=momentum,
|
327 |
+
)
|
328 |
+
|
329 |
+
self.decoder_block6 = DecoderBlock(
|
330 |
+
in_channels=64,
|
331 |
+
out_channels=32,
|
332 |
+
kernel_size=(3, 3),
|
333 |
+
upsample=(2, 2),
|
334 |
+
activation=activation,
|
335 |
+
momentum=momentum,
|
336 |
+
)
|
337 |
+
|
338 |
+
self.after_conv_block1 = ConvBlock(
|
339 |
+
in_channels=32,
|
340 |
+
out_channels=32,
|
341 |
+
kernel_size=(3, 3),
|
342 |
+
activation=activation,
|
343 |
+
momentum=momentum,
|
344 |
+
)
|
345 |
+
|
346 |
+
self.after_conv2 = nn.Conv2d(
|
347 |
+
in_channels=32,
|
348 |
+
out_channels=target_sources_num
|
349 |
+
* input_channels
|
350 |
+
* self.K
|
351 |
+
* self.subbands_num,
|
352 |
+
kernel_size=(1, 1),
|
353 |
+
stride=(1, 1),
|
354 |
+
padding=(0, 0),
|
355 |
+
bias=True,
|
356 |
+
)
|
357 |
+
|
358 |
+
self.init_weights()
|
359 |
+
|
360 |
+
def init_weights(self):
|
361 |
+
r"""Initialize weights."""
|
362 |
+
init_bn(self.bn0)
|
363 |
+
init_layer(self.after_conv2)
|
364 |
+
|
365 |
+
def feature_maps_to_wav(
|
366 |
+
self,
|
367 |
+
input_tensor: torch.Tensor,
|
368 |
+
sp: torch.Tensor,
|
369 |
+
sin_in: torch.Tensor,
|
370 |
+
cos_in: torch.Tensor,
|
371 |
+
audio_length: int,
|
372 |
+
) -> torch.Tensor:
|
373 |
+
r"""Convert feature maps to waveform.
|
374 |
+
|
375 |
+
Args:
|
376 |
+
input_tensor: (batch_size, target_sources_num * input_channels * self.K, time_steps, freq_bins)
|
377 |
+
sp: (batch_size, target_sources_num * input_channels, time_steps, freq_bins)
|
378 |
+
sin_in: (batch_size, target_sources_num * input_channels, time_steps, freq_bins)
|
379 |
+
cos_in: (batch_size, target_sources_num * input_channels, time_steps, freq_bins)
|
380 |
+
|
381 |
+
Outputs:
|
382 |
+
waveform: (batch_size, target_sources_num * input_channels, segment_samples)
|
383 |
+
"""
|
384 |
+
batch_size, _, time_steps, freq_bins = input_tensor.shape
|
385 |
+
|
386 |
+
x = input_tensor.reshape(
|
387 |
+
batch_size,
|
388 |
+
self.target_sources_num,
|
389 |
+
self.input_channels,
|
390 |
+
self.K,
|
391 |
+
time_steps,
|
392 |
+
freq_bins,
|
393 |
+
)
|
394 |
+
# x: (batch_size, target_sources_num, input_channles, K, time_steps, freq_bins)
|
395 |
+
|
396 |
+
mask_mag = torch.sigmoid(x[:, :, :, 0, :, :])
|
397 |
+
_mask_real = torch.tanh(x[:, :, :, 1, :, :])
|
398 |
+
_mask_imag = torch.tanh(x[:, :, :, 2, :, :])
|
399 |
+
_, mask_cos, mask_sin = magphase(_mask_real, _mask_imag)
|
400 |
+
# mask_cos, mask_sin: (batch_size, target_sources_num, input_channles, time_steps, freq_bins)
|
401 |
+
|
402 |
+
# Y = |Y|cos∠Y + j|Y|sin∠Y
|
403 |
+
# = |Y|cos(∠X + ∠M) + j|Y|sin(∠X + ∠M)
|
404 |
+
# = |Y|(cos∠X cos∠M - sin∠X sin∠M) + j|Y|(sin∠X cos∠M + cos∠X sin∠M)
|
405 |
+
out_cos = (
|
406 |
+
cos_in[:, None, :, :, :] * mask_cos - sin_in[:, None, :, :, :] * mask_sin
|
407 |
+
)
|
408 |
+
out_sin = (
|
409 |
+
sin_in[:, None, :, :, :] * mask_cos + cos_in[:, None, :, :, :] * mask_sin
|
410 |
+
)
|
411 |
+
# out_cos: (batch_size, target_sources_num, input_channles, time_steps, freq_bins)
|
412 |
+
# out_sin: (batch_size, target_sources_num, input_channles, time_steps, freq_bins)
|
413 |
+
|
414 |
+
# Calculate |Y|.
|
415 |
+
out_mag = F.relu_(sp[:, None, :, :, :] * mask_mag)
|
416 |
+
# out_mag: (batch_size, target_sources_num, input_channles, time_steps, freq_bins)
|
417 |
+
|
418 |
+
# Calculate Y_{real} and Y_{imag} for ISTFT.
|
419 |
+
out_real = out_mag * out_cos
|
420 |
+
out_imag = out_mag * out_sin
|
421 |
+
# out_real, out_imag: (batch_size, target_sources_num, input_channles, time_steps, freq_bins)
|
422 |
+
|
423 |
+
# Reformat shape to (n, 1, time_steps, freq_bins) for ISTFT.
|
424 |
+
shape = (
|
425 |
+
batch_size * self.target_sources_num * self.input_channels,
|
426 |
+
1,
|
427 |
+
time_steps,
|
428 |
+
freq_bins,
|
429 |
+
)
|
430 |
+
out_real = out_real.reshape(shape)
|
431 |
+
out_imag = out_imag.reshape(shape)
|
432 |
+
|
433 |
+
# ISTFT.
|
434 |
+
x = self.istft(out_real, out_imag, audio_length)
|
435 |
+
# (batch_size * target_sources_num * input_channels, segments_num)
|
436 |
+
|
437 |
+
# Reshape.
|
438 |
+
waveform = x.reshape(
|
439 |
+
batch_size, self.target_sources_num * self.input_channels, audio_length
|
440 |
+
)
|
441 |
+
# (batch_size, target_sources_num * input_channels, segments_num)
|
442 |
+
|
443 |
+
return waveform
|
444 |
+
|
445 |
+
def forward(self, input_dict: Dict) -> Dict:
|
446 |
+
r"""Forward data into the module.
|
447 |
+
|
448 |
+
Args:
|
449 |
+
input_dict: dict, e.g., {
|
450 |
+
waveform: (batch_size, input_channels, segment_samples),
|
451 |
+
...,
|
452 |
+
}
|
453 |
+
|
454 |
+
Outputs:
|
455 |
+
output_dict: dict, e.g., {
|
456 |
+
'waveform': (batch_size, input_channels, segment_samples),
|
457 |
+
...,
|
458 |
+
}
|
459 |
+
"""
|
460 |
+
mixtures = input_dict['waveform']
|
461 |
+
# (batch_size, input_channels, segment_samples)
|
462 |
+
|
463 |
+
mag, cos_in, sin_in = self.wav_to_spectrogram_phase(mixtures)
|
464 |
+
# mag, cos_in, sin_in: (batch_size, input_channels, time_steps, freq_bins)
|
465 |
+
|
466 |
+
# Batch normalize on individual frequency bins.
|
467 |
+
x = mag.transpose(1, 3)
|
468 |
+
x = self.bn0(x)
|
469 |
+
x = x.transpose(1, 3)
|
470 |
+
# x: (batch_size, input_channels, time_steps, freq_bins)
|
471 |
+
|
472 |
+
# Pad spectrogram to be evenly divided by downsample ratio.
|
473 |
+
origin_len = x.shape[2]
|
474 |
+
pad_len = (
|
475 |
+
int(np.ceil(x.shape[2] / self.downsample_ratio)) * self.downsample_ratio
|
476 |
+
- origin_len
|
477 |
+
)
|
478 |
+
x = F.pad(x, pad=(0, 0, 0, pad_len))
|
479 |
+
# x: (batch_size, input_channels, padded_time_steps, freq_bins)
|
480 |
+
|
481 |
+
# Let frequency bins be evenly divided by 2, e.g., 1025 -> 1024
|
482 |
+
x = x[..., 0 : x.shape[-1] - 1] # (bs, input_channels, T, F)
|
483 |
+
|
484 |
+
if self.subbands_num > 1:
|
485 |
+
x = self.subband.analysis(x)
|
486 |
+
# (bs, input_channels, T, F'), where F' = F // subbands_num
|
487 |
+
|
488 |
+
# UNet
|
489 |
+
(x1_pool, x1) = self.encoder_block1(x) # x1_pool: (bs, 32, T / 2, F' / 2)
|
490 |
+
(x2_pool, x2) = self.encoder_block2(x1_pool) # x2_pool: (bs, 64, T / 4, F' / 4)
|
491 |
+
(x3_pool, x3) = self.encoder_block3(
|
492 |
+
x2_pool
|
493 |
+
) # x3_pool: (bs, 128, T / 8, F' / 8)
|
494 |
+
(x4_pool, x4) = self.encoder_block4(
|
495 |
+
x3_pool
|
496 |
+
) # x4_pool: (bs, 256, T / 16, F' / 16)
|
497 |
+
(x5_pool, x5) = self.encoder_block5(
|
498 |
+
x4_pool
|
499 |
+
) # x5_pool: (bs, 384, T / 32, F' / 32)
|
500 |
+
(x6_pool, x6) = self.encoder_block6(
|
501 |
+
x5_pool
|
502 |
+
) # x6_pool: (bs, 384, T / 64, F' / 64)
|
503 |
+
x_center = self.conv_block7(x6_pool) # (bs, 384, T / 64, F' / 64)
|
504 |
+
x7 = self.decoder_block1(x_center, x6) # (bs, 384, T / 32, F' / 32)
|
505 |
+
x8 = self.decoder_block2(x7, x5) # (bs, 384, T / 16, F' / 16)
|
506 |
+
x9 = self.decoder_block3(x8, x4) # (bs, 256, T / 8, F' / 8)
|
507 |
+
x10 = self.decoder_block4(x9, x3) # (bs, 128, T / 4, F' / 4)
|
508 |
+
x11 = self.decoder_block5(x10, x2) # (bs, 64, T / 2, F' / 2)
|
509 |
+
x12 = self.decoder_block6(x11, x1) # (bs, 32, T, F')
|
510 |
+
x = self.after_conv_block1(x12) # (bs, 32, T, F')
|
511 |
+
|
512 |
+
x = self.after_conv2(x)
|
513 |
+
# (batch_size, target_sources_num * input_channles * self.K * subbands_num, T, F')
|
514 |
+
|
515 |
+
if self.subbands_num > 1:
|
516 |
+
x = self.subband.synthesis(x)
|
517 |
+
# (batch_size, target_sources_num * input_channles * self.K, T, F)
|
518 |
+
|
519 |
+
# Recover shape
|
520 |
+
x = F.pad(x, pad=(0, 1)) # Pad frequency, e.g., 1024 -> 1025.
|
521 |
+
|
522 |
+
x = x[:, :, 0:origin_len, :]
|
523 |
+
# (batch_size, target_sources_num * input_channles * self.K, T, F)
|
524 |
+
|
525 |
+
audio_length = mixtures.shape[2]
|
526 |
+
|
527 |
+
separated_audio = self.feature_maps_to_wav(x, mag, sin_in, cos_in, audio_length)
|
528 |
+
# separated_audio: (batch_size, target_sources_num * input_channels, segments_num)
|
529 |
+
|
530 |
+
output_dict = {'waveform': separated_audio}
|
531 |
+
|
532 |
+
return output_dict
|
bytesep/models/unet_subbandtime.py
ADDED
@@ -0,0 +1,389 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from torchlibrosa.stft import ISTFT, STFT, magphase
|
8 |
+
|
9 |
+
from bytesep.models.pytorch_modules import Base, init_bn, init_layer
|
10 |
+
from bytesep.models.subband_tools.pqmf import PQMF
|
11 |
+
from bytesep.models.unet import ConvBlock, DecoderBlock, EncoderBlock
|
12 |
+
|
13 |
+
|
14 |
+
class UNetSubbandTime(nn.Module, Base):
|
15 |
+
def __init__(self, input_channels: int, target_sources_num: int):
|
16 |
+
r"""Subband waveform UNet."""
|
17 |
+
super(UNetSubbandTime, self).__init__()
|
18 |
+
|
19 |
+
self.input_channels = input_channels
|
20 |
+
self.target_sources_num = target_sources_num
|
21 |
+
|
22 |
+
window_size = 512 # 2048 // 4
|
23 |
+
hop_size = 110 # 441 // 4
|
24 |
+
center = True
|
25 |
+
pad_mode = "reflect"
|
26 |
+
window = "hann"
|
27 |
+
activation = "leaky_relu"
|
28 |
+
momentum = 0.01
|
29 |
+
|
30 |
+
self.subbands_num = 4
|
31 |
+
self.K = 3 # outputs: |M|, cos∠M, sin∠M
|
32 |
+
|
33 |
+
self.downsample_ratio = 2 ** 6 # This number equals 2^{#encoder_blcoks}
|
34 |
+
|
35 |
+
self.pqmf = PQMF(
|
36 |
+
N=self.subbands_num,
|
37 |
+
M=64,
|
38 |
+
project_root='bytesep/models/subband_tools/filters',
|
39 |
+
)
|
40 |
+
|
41 |
+
self.stft = STFT(
|
42 |
+
n_fft=window_size,
|
43 |
+
hop_length=hop_size,
|
44 |
+
win_length=window_size,
|
45 |
+
window=window,
|
46 |
+
center=center,
|
47 |
+
pad_mode=pad_mode,
|
48 |
+
freeze_parameters=True,
|
49 |
+
)
|
50 |
+
|
51 |
+
self.istft = ISTFT(
|
52 |
+
n_fft=window_size,
|
53 |
+
hop_length=hop_size,
|
54 |
+
win_length=window_size,
|
55 |
+
window=window,
|
56 |
+
center=center,
|
57 |
+
pad_mode=pad_mode,
|
58 |
+
freeze_parameters=True,
|
59 |
+
)
|
60 |
+
|
61 |
+
self.bn0 = nn.BatchNorm2d(window_size // 2 + 1, momentum=momentum)
|
62 |
+
|
63 |
+
self.encoder_block1 = EncoderBlock(
|
64 |
+
in_channels=input_channels * self.subbands_num,
|
65 |
+
out_channels=32,
|
66 |
+
kernel_size=(3, 3),
|
67 |
+
downsample=(2, 2),
|
68 |
+
activation=activation,
|
69 |
+
momentum=momentum,
|
70 |
+
)
|
71 |
+
self.encoder_block2 = EncoderBlock(
|
72 |
+
in_channels=32,
|
73 |
+
out_channels=64,
|
74 |
+
kernel_size=(3, 3),
|
75 |
+
downsample=(2, 2),
|
76 |
+
activation=activation,
|
77 |
+
momentum=momentum,
|
78 |
+
)
|
79 |
+
self.encoder_block3 = EncoderBlock(
|
80 |
+
in_channels=64,
|
81 |
+
out_channels=128,
|
82 |
+
kernel_size=(3, 3),
|
83 |
+
downsample=(2, 2),
|
84 |
+
activation=activation,
|
85 |
+
momentum=momentum,
|
86 |
+
)
|
87 |
+
self.encoder_block4 = EncoderBlock(
|
88 |
+
in_channels=128,
|
89 |
+
out_channels=256,
|
90 |
+
kernel_size=(3, 3),
|
91 |
+
downsample=(2, 2),
|
92 |
+
activation=activation,
|
93 |
+
momentum=momentum,
|
94 |
+
)
|
95 |
+
self.encoder_block5 = EncoderBlock(
|
96 |
+
in_channels=256,
|
97 |
+
out_channels=384,
|
98 |
+
kernel_size=(3, 3),
|
99 |
+
downsample=(2, 2),
|
100 |
+
activation=activation,
|
101 |
+
momentum=momentum,
|
102 |
+
)
|
103 |
+
self.encoder_block6 = EncoderBlock(
|
104 |
+
in_channels=384,
|
105 |
+
out_channels=384,
|
106 |
+
kernel_size=(3, 3),
|
107 |
+
downsample=(2, 2),
|
108 |
+
activation=activation,
|
109 |
+
momentum=momentum,
|
110 |
+
)
|
111 |
+
self.conv_block7 = ConvBlock(
|
112 |
+
in_channels=384,
|
113 |
+
out_channels=384,
|
114 |
+
kernel_size=(3, 3),
|
115 |
+
activation=activation,
|
116 |
+
momentum=momentum,
|
117 |
+
)
|
118 |
+
self.decoder_block1 = DecoderBlock(
|
119 |
+
in_channels=384,
|
120 |
+
out_channels=384,
|
121 |
+
kernel_size=(3, 3),
|
122 |
+
upsample=(2, 2),
|
123 |
+
activation=activation,
|
124 |
+
momentum=momentum,
|
125 |
+
)
|
126 |
+
self.decoder_block2 = DecoderBlock(
|
127 |
+
in_channels=384,
|
128 |
+
out_channels=384,
|
129 |
+
kernel_size=(3, 3),
|
130 |
+
upsample=(2, 2),
|
131 |
+
activation=activation,
|
132 |
+
momentum=momentum,
|
133 |
+
)
|
134 |
+
self.decoder_block3 = DecoderBlock(
|
135 |
+
in_channels=384,
|
136 |
+
out_channels=256,
|
137 |
+
kernel_size=(3, 3),
|
138 |
+
upsample=(2, 2),
|
139 |
+
activation=activation,
|
140 |
+
momentum=momentum,
|
141 |
+
)
|
142 |
+
self.decoder_block4 = DecoderBlock(
|
143 |
+
in_channels=256,
|
144 |
+
out_channels=128,
|
145 |
+
kernel_size=(3, 3),
|
146 |
+
upsample=(2, 2),
|
147 |
+
activation=activation,
|
148 |
+
momentum=momentum,
|
149 |
+
)
|
150 |
+
self.decoder_block5 = DecoderBlock(
|
151 |
+
in_channels=128,
|
152 |
+
out_channels=64,
|
153 |
+
kernel_size=(3, 3),
|
154 |
+
upsample=(2, 2),
|
155 |
+
activation=activation,
|
156 |
+
momentum=momentum,
|
157 |
+
)
|
158 |
+
|
159 |
+
self.decoder_block6 = DecoderBlock(
|
160 |
+
in_channels=64,
|
161 |
+
out_channels=32,
|
162 |
+
kernel_size=(3, 3),
|
163 |
+
upsample=(2, 2),
|
164 |
+
activation=activation,
|
165 |
+
momentum=momentum,
|
166 |
+
)
|
167 |
+
|
168 |
+
self.after_conv_block1 = ConvBlock(
|
169 |
+
in_channels=32,
|
170 |
+
out_channels=32,
|
171 |
+
kernel_size=(3, 3),
|
172 |
+
activation=activation,
|
173 |
+
momentum=momentum,
|
174 |
+
)
|
175 |
+
|
176 |
+
self.after_conv2 = nn.Conv2d(
|
177 |
+
in_channels=32,
|
178 |
+
out_channels=target_sources_num
|
179 |
+
* input_channels
|
180 |
+
* self.K
|
181 |
+
* self.subbands_num,
|
182 |
+
kernel_size=(1, 1),
|
183 |
+
stride=(1, 1),
|
184 |
+
padding=(0, 0),
|
185 |
+
bias=True,
|
186 |
+
)
|
187 |
+
|
188 |
+
self.init_weights()
|
189 |
+
|
190 |
+
def init_weights(self):
|
191 |
+
r"""Initialize weights."""
|
192 |
+
init_bn(self.bn0)
|
193 |
+
init_layer(self.after_conv2)
|
194 |
+
|
195 |
+
def feature_maps_to_wav(
|
196 |
+
self,
|
197 |
+
input_tensor: torch.Tensor,
|
198 |
+
sp: torch.Tensor,
|
199 |
+
sin_in: torch.Tensor,
|
200 |
+
cos_in: torch.Tensor,
|
201 |
+
audio_length: int,
|
202 |
+
) -> torch.Tensor:
|
203 |
+
r"""Convert feature maps to waveform.
|
204 |
+
|
205 |
+
Args:
|
206 |
+
input_tensor: (batch_size, target_sources_num * input_channels * self.K, time_steps, freq_bins)
|
207 |
+
sp: (batch_size, target_sources_num * input_channels, time_steps, freq_bins)
|
208 |
+
sin_in: (batch_size, target_sources_num * input_channels, time_steps, freq_bins)
|
209 |
+
cos_in: (batch_size, target_sources_num * input_channels, time_steps, freq_bins)
|
210 |
+
|
211 |
+
Outputs:
|
212 |
+
waveform: (batch_size, target_sources_num * input_channels, segment_samples)
|
213 |
+
"""
|
214 |
+
batch_size, _, time_steps, freq_bins = input_tensor.shape
|
215 |
+
|
216 |
+
x = input_tensor.reshape(
|
217 |
+
batch_size,
|
218 |
+
self.target_sources_num,
|
219 |
+
self.input_channels,
|
220 |
+
self.K,
|
221 |
+
time_steps,
|
222 |
+
freq_bins,
|
223 |
+
)
|
224 |
+
# x: (batch_size, target_sources_num, input_channles, K, time_steps, freq_bins)
|
225 |
+
|
226 |
+
mask_mag = torch.sigmoid(x[:, :, :, 0, :, :])
|
227 |
+
_mask_real = torch.tanh(x[:, :, :, 1, :, :])
|
228 |
+
_mask_imag = torch.tanh(x[:, :, :, 2, :, :])
|
229 |
+
_, mask_cos, mask_sin = magphase(_mask_real, _mask_imag)
|
230 |
+
# mask_cos, mask_sin: (batch_size, target_sources_num, input_channles, time_steps, freq_bins)
|
231 |
+
|
232 |
+
# Y = |Y|cos∠Y + j|Y|sin∠Y
|
233 |
+
# = |Y|cos(∠X + ∠M) + j|Y|sin(∠X + ∠M)
|
234 |
+
# = |Y|(cos∠X cos∠M - sin∠X sin∠M) + j|Y|(sin∠X cos∠M + cos∠X sin∠M)
|
235 |
+
out_cos = (
|
236 |
+
cos_in[:, None, :, :, :] * mask_cos - sin_in[:, None, :, :, :] * mask_sin
|
237 |
+
)
|
238 |
+
out_sin = (
|
239 |
+
sin_in[:, None, :, :, :] * mask_cos + cos_in[:, None, :, :, :] * mask_sin
|
240 |
+
)
|
241 |
+
# out_cos: (batch_size, target_sources_num, input_channles, time_steps, freq_bins)
|
242 |
+
# out_sin: (batch_size, target_sources_num, input_channles, time_steps, freq_bins)
|
243 |
+
|
244 |
+
# Calculate |Y|.
|
245 |
+
out_mag = F.relu_(sp[:, None, :, :, :] * mask_mag)
|
246 |
+
# out_mag: (batch_size, target_sources_num, input_channles, time_steps, freq_bins)
|
247 |
+
|
248 |
+
# Calculate Y_{real} and Y_{imag} for ISTFT.
|
249 |
+
out_real = out_mag * out_cos
|
250 |
+
out_imag = out_mag * out_sin
|
251 |
+
# out_real, out_imag: (batch_size, target_sources_num, input_channles, time_steps, freq_bins)
|
252 |
+
|
253 |
+
# Reformat shape to (n, 1, time_steps, freq_bins) for ISTFT.
|
254 |
+
shape = (
|
255 |
+
batch_size * self.target_sources_num * self.input_channels,
|
256 |
+
1,
|
257 |
+
time_steps,
|
258 |
+
freq_bins,
|
259 |
+
)
|
260 |
+
out_real = out_real.reshape(shape)
|
261 |
+
out_imag = out_imag.reshape(shape)
|
262 |
+
|
263 |
+
# ISTFT.
|
264 |
+
x = self.istft(out_real, out_imag, audio_length)
|
265 |
+
# (batch_size * target_sources_num * input_channels, segments_num)
|
266 |
+
|
267 |
+
# Reshape.
|
268 |
+
waveform = x.reshape(
|
269 |
+
batch_size, self.target_sources_num * self.input_channels, audio_length
|
270 |
+
)
|
271 |
+
# (batch_size, target_sources_num * input_channels, segments_num)
|
272 |
+
|
273 |
+
return waveform
|
274 |
+
|
275 |
+
def forward(self, input_dict: Dict) -> Dict:
|
276 |
+
"""Forward data into the module.
|
277 |
+
|
278 |
+
Args:
|
279 |
+
input_dict: dict, e.g., {
|
280 |
+
waveform: (batch_size, input_channels, segment_samples),
|
281 |
+
...,
|
282 |
+
}
|
283 |
+
|
284 |
+
Outputs:
|
285 |
+
output_dict: dict, e.g., {
|
286 |
+
'waveform': (batch_size, input_channels, segment_samples),
|
287 |
+
...,
|
288 |
+
}
|
289 |
+
"""
|
290 |
+
mixtures = input_dict['waveform']
|
291 |
+
# (batch_size, input_channels, segment_samples)
|
292 |
+
|
293 |
+
if self.subbands_num > 1:
|
294 |
+
subband_x = self.pqmf.analysis(mixtures)
|
295 |
+
# -- subband_x: (batch_size, input_channels * subbands_num, segment_samples)
|
296 |
+
# -- subband_x: (batch_size, subbands_num * input_channels, segment_samples)
|
297 |
+
else:
|
298 |
+
subband_x = mixtures
|
299 |
+
|
300 |
+
# from IPython import embed; embed(using=False); os._exit(0)
|
301 |
+
# import soundfile
|
302 |
+
# soundfile.write(file='_zz.wav', data=subband_x.data.cpu().numpy()[0, 2], samplerate=11025)
|
303 |
+
|
304 |
+
mag, cos_in, sin_in = self.wav_to_spectrogram_phase(subband_x)
|
305 |
+
# mag, cos_in, sin_in: (batch_size, input_channels * subbands_num, time_steps, freq_bins)
|
306 |
+
|
307 |
+
# Batch normalize on individual frequency bins.
|
308 |
+
x = mag.transpose(1, 3)
|
309 |
+
x = self.bn0(x)
|
310 |
+
x = x.transpose(1, 3)
|
311 |
+
# (batch_size, input_channels * subbands_num, time_steps, freq_bins)
|
312 |
+
|
313 |
+
# Pad spectrogram to be evenly divided by downsample ratio.
|
314 |
+
origin_len = x.shape[2]
|
315 |
+
pad_len = (
|
316 |
+
int(np.ceil(x.shape[2] / self.downsample_ratio)) * self.downsample_ratio
|
317 |
+
- origin_len
|
318 |
+
)
|
319 |
+
x = F.pad(x, pad=(0, 0, 0, pad_len))
|
320 |
+
# x: (batch_size, input_channels * subbands_num, padded_time_steps, freq_bins)
|
321 |
+
|
322 |
+
# Let frequency bins be evenly divided by 2, e.g., 257 -> 256
|
323 |
+
x = x[..., 0 : x.shape[-1] - 1] # (bs, input_channels, T, F)
|
324 |
+
# x: (batch_size, input_channels * subbands_num, padded_time_steps, freq_bins)
|
325 |
+
|
326 |
+
# UNet
|
327 |
+
(x1_pool, x1) = self.encoder_block1(x) # x1_pool: (bs, 32, T / 2, F' / 2)
|
328 |
+
(x2_pool, x2) = self.encoder_block2(x1_pool) # x2_pool: (bs, 64, T / 4, F' / 4)
|
329 |
+
(x3_pool, x3) = self.encoder_block3(
|
330 |
+
x2_pool
|
331 |
+
) # x3_pool: (bs, 128, T / 8, F' / 8)
|
332 |
+
(x4_pool, x4) = self.encoder_block4(
|
333 |
+
x3_pool
|
334 |
+
) # x4_pool: (bs, 256, T / 16, F' / 16)
|
335 |
+
(x5_pool, x5) = self.encoder_block5(
|
336 |
+
x4_pool
|
337 |
+
) # x5_pool: (bs, 384, T / 32, F' / 32)
|
338 |
+
(x6_pool, x6) = self.encoder_block6(
|
339 |
+
x5_pool
|
340 |
+
) # x6_pool: (bs, 384, T / 64, F' / 64)
|
341 |
+
x_center = self.conv_block7(x6_pool) # (bs, 384, T / 64, F' / 64)
|
342 |
+
x7 = self.decoder_block1(x_center, x6) # (bs, 384, T / 32, F' / 32)
|
343 |
+
x8 = self.decoder_block2(x7, x5) # (bs, 384, T / 16, F' / 16)
|
344 |
+
x9 = self.decoder_block3(x8, x4) # (bs, 256, T / 8, F' / 8)
|
345 |
+
x10 = self.decoder_block4(x9, x3) # (bs, 128, T / 4, F' / 4)
|
346 |
+
x11 = self.decoder_block5(x10, x2) # (bs, 64, T / 2, F' / 2)
|
347 |
+
x12 = self.decoder_block6(x11, x1) # (bs, 32, T, F')
|
348 |
+
x = self.after_conv_block1(x12) # (bs, 32, T, F')
|
349 |
+
|
350 |
+
x = self.after_conv2(x)
|
351 |
+
# (batch_size, subbands_num * target_sources_num * input_channles * self.K, T, F')
|
352 |
+
|
353 |
+
# Recover shape
|
354 |
+
x = F.pad(x, pad=(0, 1)) # Pad frequency, e.g., 256 -> 257.
|
355 |
+
|
356 |
+
x = x[:, :, 0:origin_len, :]
|
357 |
+
# (batch_size, subbands_num * target_sources_num * input_channles * self.K, T, F')
|
358 |
+
|
359 |
+
audio_length = subband_x.shape[2]
|
360 |
+
|
361 |
+
# Recover each subband spectrograms to subband waveforms. Then synthesis
|
362 |
+
# the subband waveforms to a waveform.
|
363 |
+
C1 = x.shape[1] // self.subbands_num
|
364 |
+
C2 = mag.shape[1] // self.subbands_num
|
365 |
+
|
366 |
+
separated_subband_audio = torch.cat(
|
367 |
+
[
|
368 |
+
self.feature_maps_to_wav(
|
369 |
+
input_tensor=x[:, j * C1 : (j + 1) * C1, :, :],
|
370 |
+
sp=mag[:, j * C2 : (j + 1) * C2, :, :],
|
371 |
+
sin_in=sin_in[:, j * C2 : (j + 1) * C2, :, :],
|
372 |
+
cos_in=cos_in[:, j * C2 : (j + 1) * C2, :, :],
|
373 |
+
audio_length=audio_length,
|
374 |
+
)
|
375 |
+
for j in range(self.subbands_num)
|
376 |
+
],
|
377 |
+
dim=1,
|
378 |
+
)
|
379 |
+
# (batch_size, subbands_num * target_sources_num * input_channles, segment_samples)
|
380 |
+
|
381 |
+
if self.subbands_num > 1:
|
382 |
+
separated_audio = self.pqmf.synthesis(separated_subband_audio)
|
383 |
+
# (batch_size, target_sources_num * input_channles, segment_samples)
|
384 |
+
else:
|
385 |
+
separated_audio = separated_subband_audio
|
386 |
+
|
387 |
+
output_dict = {'waveform': separated_audio}
|
388 |
+
|
389 |
+
return output_dict
|
bytesep/optimizers/__init__.py
ADDED
File without changes
|
bytesep/optimizers/lr_schedulers.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
def get_lr_lambda(step, warm_up_steps: int, reduce_lr_steps: int):
|
2 |
+
r"""Get lr_lambda for LambdaLR. E.g.,
|
3 |
+
|
4 |
+
.. code-block: python
|
5 |
+
lr_lambda = lambda step: get_lr_lambda(step, warm_up_steps=1000, reduce_lr_steps=10000)
|
6 |
+
|
7 |
+
from torch.optim.lr_scheduler import LambdaLR
|
8 |
+
LambdaLR(optimizer, lr_lambda)
|
9 |
+
|
10 |
+
Args:
|
11 |
+
warm_up_steps: int, steps for warm up
|
12 |
+
reduce_lr_steps: int, reduce learning rate by 0.9 every #reduce_lr_steps steps
|
13 |
+
|
14 |
+
Returns:
|
15 |
+
learning rate: float
|
16 |
+
"""
|
17 |
+
if step <= warm_up_steps:
|
18 |
+
return step / warm_up_steps
|
19 |
+
else:
|
20 |
+
return 0.9 ** (step // reduce_lr_steps)
|
bytesep/plot_results/__init__.py
ADDED
File without changes
|
bytesep/plot_results/musdb18.py
ADDED
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
import pickle
|
4 |
+
|
5 |
+
import matplotlib.pyplot as plt
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
|
9 |
+
def load_sdrs(workspace, task_name, filename, config, gpus, source_type):
|
10 |
+
|
11 |
+
stat_path = os.path.join(
|
12 |
+
workspace,
|
13 |
+
"statistics",
|
14 |
+
task_name,
|
15 |
+
filename,
|
16 |
+
"config={},gpus={}".format(config, gpus),
|
17 |
+
"statistics.pkl",
|
18 |
+
)
|
19 |
+
|
20 |
+
stat_dict = pickle.load(open(stat_path, 'rb'))
|
21 |
+
|
22 |
+
median_sdrs = [e['median_sdr_dict'][source_type] for e in stat_dict['test']]
|
23 |
+
|
24 |
+
return median_sdrs
|
25 |
+
|
26 |
+
|
27 |
+
def plot_statistics(args):
|
28 |
+
|
29 |
+
# arguments & parameters
|
30 |
+
workspace = args.workspace
|
31 |
+
select = args.select
|
32 |
+
task_name = "musdb18"
|
33 |
+
filename = "train"
|
34 |
+
|
35 |
+
# paths
|
36 |
+
fig_path = os.path.join('results', task_name, "sdr_{}.pdf".format(select))
|
37 |
+
os.makedirs(os.path.dirname(fig_path), exist_ok=True)
|
38 |
+
|
39 |
+
linewidth = 1
|
40 |
+
lines = []
|
41 |
+
fig, ax = plt.subplots(1, 1, figsize=(8, 6))
|
42 |
+
|
43 |
+
if select == '1a':
|
44 |
+
sdrs = load_sdrs(
|
45 |
+
workspace,
|
46 |
+
task_name,
|
47 |
+
filename,
|
48 |
+
config='vocals-accompaniment,unet',
|
49 |
+
gpus=1,
|
50 |
+
source_type="vocals",
|
51 |
+
)
|
52 |
+
(line,) = ax.plot(sdrs, label='UNet,l1_wav', linewidth=linewidth)
|
53 |
+
lines.append(line)
|
54 |
+
ylim = 15
|
55 |
+
|
56 |
+
elif select == '1b':
|
57 |
+
sdrs = load_sdrs(
|
58 |
+
workspace,
|
59 |
+
task_name,
|
60 |
+
filename,
|
61 |
+
config='accompaniment-vocals,unet',
|
62 |
+
gpus=1,
|
63 |
+
source_type="accompaniment",
|
64 |
+
)
|
65 |
+
(line,) = ax.plot(sdrs, label='UNet,l1_wav', linewidth=linewidth)
|
66 |
+
lines.append(line)
|
67 |
+
ylim = 20
|
68 |
+
|
69 |
+
if select == '1c':
|
70 |
+
sdrs = load_sdrs(
|
71 |
+
workspace,
|
72 |
+
task_name,
|
73 |
+
filename,
|
74 |
+
config='vocals-accompaniment,unet',
|
75 |
+
gpus=1,
|
76 |
+
source_type="vocals",
|
77 |
+
)
|
78 |
+
(line,) = ax.plot(sdrs, label='UNet,l1_wav', linewidth=linewidth)
|
79 |
+
lines.append(line)
|
80 |
+
|
81 |
+
sdrs = load_sdrs(
|
82 |
+
workspace,
|
83 |
+
task_name,
|
84 |
+
filename,
|
85 |
+
config='vocals-accompaniment,resunet',
|
86 |
+
gpus=2,
|
87 |
+
source_type="vocals",
|
88 |
+
)
|
89 |
+
(line,) = ax.plot(sdrs, label='ResUNet_ISMIR2021,l1_wav', linewidth=linewidth)
|
90 |
+
lines.append(line)
|
91 |
+
|
92 |
+
sdrs = load_sdrs(
|
93 |
+
workspace,
|
94 |
+
task_name,
|
95 |
+
filename,
|
96 |
+
config='vocals-accompaniment,unet_subbandtime',
|
97 |
+
gpus=1,
|
98 |
+
source_type="vocals",
|
99 |
+
)
|
100 |
+
(line,) = ax.plot(sdrs, label='unet_subband,l1_wav', linewidth=linewidth)
|
101 |
+
lines.append(line)
|
102 |
+
|
103 |
+
sdrs = load_sdrs(
|
104 |
+
workspace,
|
105 |
+
task_name,
|
106 |
+
filename,
|
107 |
+
config='vocals-accompaniment,resunet_subbandtime',
|
108 |
+
gpus=1,
|
109 |
+
source_type="vocals",
|
110 |
+
)
|
111 |
+
(line,) = ax.plot(sdrs, label='resunet_subband,l1_wav', linewidth=linewidth)
|
112 |
+
lines.append(line)
|
113 |
+
|
114 |
+
ylim = 15
|
115 |
+
|
116 |
+
elif select == '1d':
|
117 |
+
sdrs = load_sdrs(
|
118 |
+
workspace,
|
119 |
+
task_name,
|
120 |
+
filename,
|
121 |
+
config='accompaniment-vocals,unet',
|
122 |
+
gpus=1,
|
123 |
+
source_type="accompaniment",
|
124 |
+
)
|
125 |
+
(line,) = ax.plot(sdrs, label='UNet,l1_wav', linewidth=linewidth)
|
126 |
+
lines.append(line)
|
127 |
+
|
128 |
+
sdrs = load_sdrs(
|
129 |
+
workspace,
|
130 |
+
task_name,
|
131 |
+
filename,
|
132 |
+
config='accompaniment-vocals,resunet',
|
133 |
+
gpus=2,
|
134 |
+
source_type="accompaniment",
|
135 |
+
)
|
136 |
+
(line,) = ax.plot(sdrs, label='ResUNet_ISMIR2021,l1_wav', linewidth=linewidth)
|
137 |
+
lines.append(line)
|
138 |
+
|
139 |
+
# sdrs = load_sdrs(
|
140 |
+
# workspace,
|
141 |
+
# task_name,
|
142 |
+
# filename,
|
143 |
+
# config='accompaniment-vocals,unet_subbandtime',
|
144 |
+
# gpus=1,
|
145 |
+
# source_type="accompaniment",
|
146 |
+
# )
|
147 |
+
# (line,) = ax.plot(sdrs, label='UNet_subbtandtime,l1_wav', linewidth=linewidth)
|
148 |
+
# lines.append(line)
|
149 |
+
|
150 |
+
sdrs = load_sdrs(
|
151 |
+
workspace,
|
152 |
+
task_name,
|
153 |
+
filename,
|
154 |
+
config='accompaniment-vocals,resunet_subbandtime',
|
155 |
+
gpus=1,
|
156 |
+
source_type="accompaniment",
|
157 |
+
)
|
158 |
+
(line,) = ax.plot(
|
159 |
+
sdrs, label='ResUNet_subbtandtime,l1_wav', linewidth=linewidth
|
160 |
+
)
|
161 |
+
lines.append(line)
|
162 |
+
|
163 |
+
ylim = 20
|
164 |
+
|
165 |
+
else:
|
166 |
+
raise Exception('Error!')
|
167 |
+
|
168 |
+
eval_every_iterations = 10000
|
169 |
+
total_ticks = 50
|
170 |
+
ticks_freq = 10
|
171 |
+
|
172 |
+
ax.set_ylim(0, ylim)
|
173 |
+
ax.set_xlim(0, total_ticks)
|
174 |
+
ax.xaxis.set_ticks(np.arange(0, total_ticks + 1, ticks_freq))
|
175 |
+
ax.xaxis.set_ticklabels(
|
176 |
+
np.arange(
|
177 |
+
0,
|
178 |
+
total_ticks * eval_every_iterations + 1,
|
179 |
+
ticks_freq * eval_every_iterations,
|
180 |
+
)
|
181 |
+
)
|
182 |
+
ax.yaxis.set_ticks(np.arange(ylim + 1))
|
183 |
+
ax.yaxis.set_ticklabels(np.arange(ylim + 1))
|
184 |
+
ax.grid(color='b', linestyle='solid', linewidth=0.3)
|
185 |
+
plt.legend(handles=lines, loc=4)
|
186 |
+
|
187 |
+
plt.savefig(fig_path)
|
188 |
+
print('Save figure to {}'.format(fig_path))
|
189 |
+
|
190 |
+
|
191 |
+
if __name__ == '__main__':
|
192 |
+
parser = argparse.ArgumentParser()
|
193 |
+
parser.add_argument('--workspace', type=str, required=True)
|
194 |
+
parser.add_argument('--select', type=str, required=True)
|
195 |
+
|
196 |
+
args = parser.parse_args()
|
197 |
+
|
198 |
+
plot_statistics(args)
|
bytesep/plot_results/plot_vctk-musdb18.py
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import numpy as np
|
4 |
+
import argparse
|
5 |
+
import h5py
|
6 |
+
import math
|
7 |
+
import time
|
8 |
+
import logging
|
9 |
+
import pickle
|
10 |
+
import matplotlib.pyplot as plt
|
11 |
+
|
12 |
+
|
13 |
+
def load_sdrs(workspace, task_name, filename, config, gpus):
|
14 |
+
|
15 |
+
stat_path = os.path.join(
|
16 |
+
workspace,
|
17 |
+
"statistics",
|
18 |
+
task_name,
|
19 |
+
filename,
|
20 |
+
"config={},gpus={}".format(config, gpus),
|
21 |
+
"statistics.pkl",
|
22 |
+
)
|
23 |
+
|
24 |
+
stat_dict = pickle.load(open(stat_path, 'rb'))
|
25 |
+
|
26 |
+
median_sdrs = [e['sdr'] for e in stat_dict['test']]
|
27 |
+
|
28 |
+
return median_sdrs
|
29 |
+
|
30 |
+
|
31 |
+
def plot_statistics(args):
|
32 |
+
|
33 |
+
# arguments & parameters
|
34 |
+
workspace = args.workspace
|
35 |
+
select = args.select
|
36 |
+
task_name = "vctk-musdb18"
|
37 |
+
filename = "train"
|
38 |
+
|
39 |
+
# paths
|
40 |
+
fig_path = os.path.join('results', task_name, "sdr_{}.pdf".format(select))
|
41 |
+
os.makedirs(os.path.dirname(fig_path), exist_ok=True)
|
42 |
+
|
43 |
+
linewidth = 1
|
44 |
+
lines = []
|
45 |
+
fig, ax = plt.subplots(1, 1, figsize=(8, 6))
|
46 |
+
ylim = 30
|
47 |
+
expand = 1
|
48 |
+
|
49 |
+
if select == '1a':
|
50 |
+
sdrs = load_sdrs(workspace, task_name, filename, config='unet', gpus=1)
|
51 |
+
(line,) = ax.plot(sdrs, label='UNet,l1_wav', linewidth=linewidth)
|
52 |
+
lines.append(line)
|
53 |
+
|
54 |
+
else:
|
55 |
+
raise Exception('Error!')
|
56 |
+
|
57 |
+
eval_every_iterations = 10000
|
58 |
+
total_ticks = 50
|
59 |
+
ticks_freq = 10
|
60 |
+
|
61 |
+
ax.set_ylim(0, ylim)
|
62 |
+
ax.set_xlim(0, total_ticks)
|
63 |
+
ax.xaxis.set_ticks(np.arange(0, total_ticks + 1, ticks_freq))
|
64 |
+
ax.xaxis.set_ticklabels(
|
65 |
+
np.arange(
|
66 |
+
0,
|
67 |
+
total_ticks * eval_every_iterations + 1,
|
68 |
+
ticks_freq * eval_every_iterations,
|
69 |
+
)
|
70 |
+
)
|
71 |
+
ax.yaxis.set_ticks(np.arange(ylim + 1))
|
72 |
+
ax.yaxis.set_ticklabels(np.arange(ylim + 1))
|
73 |
+
ax.grid(color='b', linestyle='solid', linewidth=0.3)
|
74 |
+
plt.legend(handles=lines, loc=4)
|
75 |
+
|
76 |
+
plt.savefig(fig_path)
|
77 |
+
print('Save figure to {}'.format(fig_path))
|
78 |
+
|
79 |
+
|
80 |
+
if __name__ == '__main__':
|
81 |
+
parser = argparse.ArgumentParser()
|
82 |
+
parser.add_argument('--workspace', type=str, required=True)
|
83 |
+
parser.add_argument('--select', type=str, required=True)
|
84 |
+
|
85 |
+
args = parser.parse_args()
|
86 |
+
|
87 |
+
plot_statistics(args)
|