Spaces:
Running
Running
Upload 131 files
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +20 -0
- audio_separator/__init__.py +0 -0
- audio_separator/model-data.json +22 -0
- audio_separator/models-scores.json +0 -0
- audio_separator/models.json +216 -0
- audio_separator/separator/__init__.py +1 -0
- audio_separator/separator/architectures/__init__.py +0 -0
- audio_separator/separator/architectures/demucs_separator.py +195 -0
- audio_separator/separator/architectures/mdx_separator.py +451 -0
- audio_separator/separator/architectures/mdxc_separator.py +423 -0
- audio_separator/separator/architectures/vr_separator.py +357 -0
- audio_separator/separator/common_separator.py +403 -0
- audio_separator/separator/separator.py +959 -0
- audio_separator/separator/uvr_lib_v5/__init__.py +0 -0
- audio_separator/separator/uvr_lib_v5/demucs/__init__.py +5 -0
- audio_separator/separator/uvr_lib_v5/demucs/__main__.py +212 -0
- audio_separator/separator/uvr_lib_v5/demucs/apply.py +294 -0
- audio_separator/separator/uvr_lib_v5/demucs/demucs.py +453 -0
- audio_separator/separator/uvr_lib_v5/demucs/filtering.py +451 -0
- audio_separator/separator/uvr_lib_v5/demucs/hdemucs.py +783 -0
- audio_separator/separator/uvr_lib_v5/demucs/htdemucs.py +620 -0
- audio_separator/separator/uvr_lib_v5/demucs/model.py +204 -0
- audio_separator/separator/uvr_lib_v5/demucs/model_v2.py +222 -0
- audio_separator/separator/uvr_lib_v5/demucs/pretrained.py +181 -0
- audio_separator/separator/uvr_lib_v5/demucs/repo.py +146 -0
- audio_separator/separator/uvr_lib_v5/demucs/spec.py +38 -0
- audio_separator/separator/uvr_lib_v5/demucs/states.py +131 -0
- audio_separator/separator/uvr_lib_v5/demucs/tasnet.py +401 -0
- audio_separator/separator/uvr_lib_v5/demucs/tasnet_v2.py +404 -0
- audio_separator/separator/uvr_lib_v5/demucs/transformer.py +675 -0
- audio_separator/separator/uvr_lib_v5/demucs/utils.py +496 -0
- audio_separator/separator/uvr_lib_v5/mdxnet.py +136 -0
- audio_separator/separator/uvr_lib_v5/mixer.ckpt +3 -0
- audio_separator/separator/uvr_lib_v5/modules.py +74 -0
- audio_separator/separator/uvr_lib_v5/playsound.py +241 -0
- audio_separator/separator/uvr_lib_v5/pyrb.py +92 -0
- audio_separator/separator/uvr_lib_v5/results.py +48 -0
- audio_separator/separator/uvr_lib_v5/roformer/attend.py +112 -0
- audio_separator/separator/uvr_lib_v5/roformer/bs_roformer.py +535 -0
- audio_separator/separator/uvr_lib_v5/roformer/mel_band_roformer.py +445 -0
- audio_separator/separator/uvr_lib_v5/spec_utils.py +1327 -0
- audio_separator/separator/uvr_lib_v5/stft.py +126 -0
- audio_separator/separator/uvr_lib_v5/tfc_tdf_v3.py +253 -0
- audio_separator/separator/uvr_lib_v5/vr_network/__init__.py +1 -0
- audio_separator/separator/uvr_lib_v5/vr_network/layers.py +294 -0
- audio_separator/separator/uvr_lib_v5/vr_network/layers_new.py +149 -0
- audio_separator/separator/uvr_lib_v5/vr_network/model_param_init.py +71 -0
- audio_separator/separator/uvr_lib_v5/vr_network/modelparams/1band_sr16000_hl512.json +19 -0
- audio_separator/separator/uvr_lib_v5/vr_network/modelparams/1band_sr32000_hl512.json +19 -0
- audio_separator/separator/uvr_lib_v5/vr_network/modelparams/1band_sr33075_hl384.json +19 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,23 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
tests/inputs/mardy20s.flac filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
tests/inputs/reference/expected_mardy20s_(Bass)_htdemucs_6s_spectrogram.png filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
tests/inputs/reference/expected_mardy20s_(Drum-Bass)_model_bs_roformer_ep_937_sdr_10_spectrogram.png filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
tests/inputs/reference/expected_mardy20s_(Drums)_htdemucs_6s_spectrogram.png filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
tests/inputs/reference/expected_mardy20s_(Guitar)_htdemucs_6s_spectrogram.png filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
tests/inputs/reference/expected_mardy20s_(Instrumental)_2_HP-UVR_spectrogram.png filter=lfs diff=lfs merge=lfs -text
|
| 42 |
+
tests/inputs/reference/expected_mardy20s_(Instrumental)_kuielab_b_vocals_spectrogram.png filter=lfs diff=lfs merge=lfs -text
|
| 43 |
+
tests/inputs/reference/expected_mardy20s_(Instrumental)_MGM_MAIN_v4_spectrogram.png filter=lfs diff=lfs merge=lfs -text
|
| 44 |
+
tests/inputs/reference/expected_mardy20s_(Instrumental)_model_bs_roformer_ep_317_sdr_12_spectrogram.png filter=lfs diff=lfs merge=lfs -text
|
| 45 |
+
tests/inputs/reference/expected_mardy20s_(Instrumental)_UVR-MDX-NET-Inst_HQ_4_spectrogram.png filter=lfs diff=lfs merge=lfs -text
|
| 46 |
+
tests/inputs/reference/expected_mardy20s_(No[[:space:]]Drum-Bass)_model_bs_roformer_ep_937_sdr_10_spectrogram.png filter=lfs diff=lfs merge=lfs -text
|
| 47 |
+
tests/inputs/reference/expected_mardy20s_(Other)_htdemucs_6s_spectrogram.png filter=lfs diff=lfs merge=lfs -text
|
| 48 |
+
tests/inputs/reference/expected_mardy20s_(Piano)_htdemucs_6s_spectrogram.png filter=lfs diff=lfs merge=lfs -text
|
| 49 |
+
tests/inputs/reference/expected_mardy20s_(Vocals)_2_HP-UVR_spectrogram.png filter=lfs diff=lfs merge=lfs -text
|
| 50 |
+
tests/inputs/reference/expected_mardy20s_(Vocals)_htdemucs_6s_spectrogram.png filter=lfs diff=lfs merge=lfs -text
|
| 51 |
+
tests/inputs/reference/expected_mardy20s_(Vocals)_kuielab_b_vocals_spectrogram.png filter=lfs diff=lfs merge=lfs -text
|
| 52 |
+
tests/inputs/reference/expected_mardy20s_(Vocals)_MGM_MAIN_v4_spectrogram.png filter=lfs diff=lfs merge=lfs -text
|
| 53 |
+
tests/inputs/reference/expected_mardy20s_(Vocals)_model_bs_roformer_ep_317_sdr_12_spectrogram.png filter=lfs diff=lfs merge=lfs -text
|
| 54 |
+
tests/inputs/reference/expected_mardy20s_(Vocals)_UVR-MDX-NET-Inst_HQ_4_spectrogram.png filter=lfs diff=lfs merge=lfs -text
|
| 55 |
+
tests/inputs/reference/expected_mardy20s_spectrogram.png filter=lfs diff=lfs merge=lfs -text
|
audio_separator/__init__.py
ADDED
|
File without changes
|
audio_separator/model-data.json
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"vr_model_data": {
|
| 3 |
+
"97dc361a7a88b2c4542f68364b32c7f6": {
|
| 4 |
+
"vr_model_param": "4band_v4_ms_fullband",
|
| 5 |
+
"primary_stem": "Dry",
|
| 6 |
+
"nout": 32,
|
| 7 |
+
"nout_lstm": 128,
|
| 8 |
+
"is_karaoke": false,
|
| 9 |
+
"is_bv_model": false,
|
| 10 |
+
"is_bv_model_rebalanced": 0.0
|
| 11 |
+
}
|
| 12 |
+
},
|
| 13 |
+
"mdx_model_data": {
|
| 14 |
+
"cb790d0c913647ced70fc6b38f5bea1a": {
|
| 15 |
+
"compensate": 1.010,
|
| 16 |
+
"mdx_dim_f_set": 2560,
|
| 17 |
+
"mdx_dim_t_set": 8,
|
| 18 |
+
"mdx_n_fft_scale_set": 5120,
|
| 19 |
+
"primary_stem": "Instrumental"
|
| 20 |
+
}
|
| 21 |
+
}
|
| 22 |
+
}
|
audio_separator/models-scores.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
audio_separator/models.json
ADDED
|
@@ -0,0 +1,216 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"vr_download_list": {
|
| 3 |
+
"VR Arch Single Model v4: UVR-De-Reverb by aufr33-jarredou": "UVR-De-Reverb-aufr33-jarredou.pth"
|
| 4 |
+
},
|
| 5 |
+
"mdx_download_list": {
|
| 6 |
+
"MDX-Net Model: UVR-MDX-NET Inst HQ 5": "UVR-MDX-NET-Inst_HQ_5.onnx"
|
| 7 |
+
},
|
| 8 |
+
"mdx23c_download_list": {
|
| 9 |
+
"MDX23C Model: MDX23C De-Reverb by aufr33-jarredou": {
|
| 10 |
+
"MDX23C-De-Reverb-aufr33-jarredou.ckpt": "config_dereverb_mdx23c.yaml"
|
| 11 |
+
},
|
| 12 |
+
"MDX23C Model: MDX23C DrumSep by aufr33-jarredou": {
|
| 13 |
+
"MDX23C-DrumSep-aufr33-jarredou.ckpt": "config_drumsep_mdx23c.yaml"
|
| 14 |
+
}
|
| 15 |
+
},
|
| 16 |
+
"roformer_download_list": {
|
| 17 |
+
"Roformer Model: Mel-Roformer-Karaoke-Aufr33-Viperx": {
|
| 18 |
+
"mel_band_roformer_karaoke_aufr33_viperx_sdr_10.1956.ckpt": "mel_band_roformer_karaoke_aufr33_viperx_sdr_10.1956_config.yaml"
|
| 19 |
+
},
|
| 20 |
+
"Roformer Model: MelBand Roformer | Karaoke by Gabox": {
|
| 21 |
+
"mel_band_roformer_karaoke_gabox.ckpt": "mel_band_roformer_karaoke_aufr33_viperx_sdr_10.1956_config.yaml"
|
| 22 |
+
},
|
| 23 |
+
"Roformer Model: MelBand Roformer | Karaoke by becruily": {
|
| 24 |
+
"mel_band_roformer_karaoke_becruily.ckpt": "config_mel_band_roformer_karaoke_becruily.yaml"
|
| 25 |
+
},
|
| 26 |
+
"Roformer Model: Mel-Roformer-Denoise-Aufr33": {
|
| 27 |
+
"denoise_mel_band_roformer_aufr33_sdr_27.9959.ckpt": "denoise_mel_band_roformer_aufr33_sdr_27.9959_config.yaml"
|
| 28 |
+
},
|
| 29 |
+
"Roformer Model: Mel-Roformer-Denoise-Aufr33-Aggr": {
|
| 30 |
+
"denoise_mel_band_roformer_aufr33_aggr_sdr_27.9768.ckpt": "denoise_mel_band_roformer_aufr33_aggr_sdr_27.9768_config.yaml"
|
| 31 |
+
},
|
| 32 |
+
"Roformer Model: MelBand Roformer | Denoise-Debleed by Gabox": {
|
| 33 |
+
"mel_band_roformer_denoise_debleed_gabox.ckpt": "config_mel_band_roformer_instrumental_gabox.yaml"
|
| 34 |
+
},
|
| 35 |
+
"Roformer Model: Mel-Roformer-Crowd-Aufr33-Viperx": {
|
| 36 |
+
"mel_band_roformer_crowd_aufr33_viperx_sdr_8.7144.ckpt": "mel_band_roformer_crowd_aufr33_viperx_sdr_8.7144_config.yaml"
|
| 37 |
+
},
|
| 38 |
+
"Roformer Model: BS-Roformer-De-Reverb": {
|
| 39 |
+
"deverb_bs_roformer_8_384dim_10depth.ckpt": "deverb_bs_roformer_8_384dim_10depth_config.yaml"
|
| 40 |
+
},
|
| 41 |
+
"Roformer Model: MelBand Roformer | Vocals by Kimberley Jensen": {
|
| 42 |
+
"vocals_mel_band_roformer.ckpt": "vocals_mel_band_roformer.yaml"
|
| 43 |
+
},
|
| 44 |
+
"Roformer Model: MelBand Roformer Kim | FT by unwa": {
|
| 45 |
+
"mel_band_roformer_kim_ft_unwa.ckpt": "config_mel_band_roformer_kim_ft_unwa.yaml"
|
| 46 |
+
},
|
| 47 |
+
"Roformer Model: MelBand Roformer Kim | FT 2 by unwa": {
|
| 48 |
+
"mel_band_roformer_kim_ft2_unwa.ckpt": "config_mel_band_roformer_kim_ft_unwa.yaml"
|
| 49 |
+
},
|
| 50 |
+
"Roformer Model: MelBand Roformer Kim | FT 2 Bleedless by unwa": {
|
| 51 |
+
"mel_band_roformer_kim_ft2_bleedless_unwa.ckpt": "config_mel_band_roformer_kim_ft_unwa.yaml"
|
| 52 |
+
},
|
| 53 |
+
"Roformer Model: MelBand Roformer Kim | FT 3 by unwa": {
|
| 54 |
+
"mel_band_roformer_kim_ft3_unwa.ckpt": "config_mel_band_roformer_kim_ft_unwa.yaml"
|
| 55 |
+
},
|
| 56 |
+
"Roformer Model: MelBand Roformer Kim | Inst V1 Plus by Unwa": {
|
| 57 |
+
"melband_roformer_inst_v1_plus.ckpt": "config_melbandroformer_inst.yaml"
|
| 58 |
+
},
|
| 59 |
+
"Roformer Model: MelBand Roformer Kim | Inst V1 (E) by Unwa": {
|
| 60 |
+
"melband_roformer_inst_v1e.ckpt": "config_melbandroformer_inst.yaml"
|
| 61 |
+
},
|
| 62 |
+
"Roformer Model: MelBand Roformer Kim | Inst V1 (E) Plus by Unwa": {
|
| 63 |
+
"melband_roformer_inst_v1e_plus.ckpt": "config_melbandroformer_inst.yaml"
|
| 64 |
+
},
|
| 65 |
+
"Roformer Model: MelBand Roformer | Vocals by becruily": {
|
| 66 |
+
"mel_band_roformer_vocals_becruily.ckpt": "config_mel_band_roformer_vocals_becruily.yaml"
|
| 67 |
+
},
|
| 68 |
+
"Roformer Model: MelBand Roformer | Instrumental by becruily": {
|
| 69 |
+
"mel_band_roformer_instrumental_becruily.ckpt": "config_mel_band_roformer_instrumental_becruily.yaml"
|
| 70 |
+
},
|
| 71 |
+
"Roformer Model: MelBand Roformer | Vocals Fullness by Aname": {
|
| 72 |
+
"mel_band_roformer_vocal_fullness_aname.ckpt": "config_mel_band_roformer_vocal_fullness_aname.yaml"
|
| 73 |
+
},
|
| 74 |
+
"Roformer Model: BS Roformer | Vocals by Gabox": {
|
| 75 |
+
"bs_roformer_vocals_gabox.ckpt": "config_bs_roformer_vocals_gabox.yaml"
|
| 76 |
+
},
|
| 77 |
+
"Roformer Model: MelBand Roformer | Vocals by Gabox": {
|
| 78 |
+
"mel_band_roformer_vocals_gabox.ckpt": "config_mel_band_roformer_vocals_gabox.yaml"
|
| 79 |
+
},
|
| 80 |
+
"Roformer Model: MelBand Roformer | Vocals FV1 by Gabox": {
|
| 81 |
+
"mel_band_roformer_vocals_fv1_gabox.ckpt": "config_mel_band_roformer_vocals_gabox.yaml"
|
| 82 |
+
},
|
| 83 |
+
"Roformer Model: MelBand Roformer | Vocals FV2 by Gabox": {
|
| 84 |
+
"mel_band_roformer_vocals_fv2_gabox.ckpt": "config_mel_band_roformer_vocals_gabox.yaml"
|
| 85 |
+
},
|
| 86 |
+
"Roformer Model: MelBand Roformer | Vocals FV3 by Gabox": {
|
| 87 |
+
"mel_band_roformer_vocals_fv3_gabox.ckpt": "config_mel_band_roformer_vocals_gabox.yaml"
|
| 88 |
+
},
|
| 89 |
+
"Roformer Model: MelBand Roformer | Vocals FV4 by Gabox": {
|
| 90 |
+
"mel_band_roformer_vocals_fv4_gabox.ckpt": "config_mel_band_roformer_vocals_gabox.yaml"
|
| 91 |
+
},
|
| 92 |
+
"Roformer Model: MelBand Roformer | Instrumental by Gabox": {
|
| 93 |
+
"mel_band_roformer_instrumental_gabox.ckpt": "config_mel_band_roformer_instrumental_gabox.yaml"
|
| 94 |
+
},
|
| 95 |
+
"Roformer Model: MelBand Roformer | Instrumental 2 by Gabox": {
|
| 96 |
+
"mel_band_roformer_instrumental_2_gabox.ckpt": "config_mel_band_roformer_instrumental_gabox.yaml"
|
| 97 |
+
},
|
| 98 |
+
"Roformer Model: MelBand Roformer | Instrumental 3 by Gabox": {
|
| 99 |
+
"mel_band_roformer_instrumental_3_gabox.ckpt": "config_mel_band_roformer_instrumental_gabox.yaml"
|
| 100 |
+
},
|
| 101 |
+
"Roformer Model: MelBand Roformer | Instrumental Bleedless V1 by Gabox": {
|
| 102 |
+
"mel_band_roformer_instrumental_bleedless_v1_gabox.ckpt": "config_mel_band_roformer_instrumental_gabox.yaml"
|
| 103 |
+
},
|
| 104 |
+
"Roformer Model: MelBand Roformer | Instrumental Bleedless V2 by Gabox": {
|
| 105 |
+
"mel_band_roformer_instrumental_bleedless_v2_gabox.ckpt": "config_mel_band_roformer_instrumental_gabox.yaml"
|
| 106 |
+
},
|
| 107 |
+
"Roformer Model: MelBand Roformer | Instrumental Bleedless V3 by Gabox": {
|
| 108 |
+
"mel_band_roformer_instrumental_bleedless_v3_gabox.ckpt": "config_mel_band_roformer_instrumental_gabox.yaml"
|
| 109 |
+
},
|
| 110 |
+
"Roformer Model: MelBand Roformer | Instrumental Fullness V1 by Gabox": {
|
| 111 |
+
"mel_band_roformer_instrumental_fullness_v1_gabox.ckpt": "config_mel_band_roformer_instrumental_gabox.yaml"
|
| 112 |
+
},
|
| 113 |
+
"Roformer Model: MelBand Roformer | Instrumental Fullness V2 by Gabox": {
|
| 114 |
+
"mel_band_roformer_instrumental_fullness_v2_gabox.ckpt": "config_mel_band_roformer_instrumental_gabox.yaml"
|
| 115 |
+
},
|
| 116 |
+
"Roformer Model: MelBand Roformer | Instrumental Fullness V3 by Gabox": {
|
| 117 |
+
"mel_band_roformer_instrumental_fullness_v3_gabox.ckpt": "config_mel_band_roformer_instrumental_gabox.yaml"
|
| 118 |
+
},
|
| 119 |
+
"Roformer Model: MelBand Roformer | Instrumental Fullness Noisy V4 by Gabox": {
|
| 120 |
+
"mel_band_roformer_instrumental_fullness_noise_v4_gabox.ckpt": "config_mel_band_roformer_instrumental_gabox.yaml"
|
| 121 |
+
},
|
| 122 |
+
"Roformer Model: MelBand Roformer | INSTV5 by Gabox": {
|
| 123 |
+
"mel_band_roformer_instrumental_instv5_gabox.ckpt": "config_mel_band_roformer_instrumental_gabox.yaml"
|
| 124 |
+
},
|
| 125 |
+
"Roformer Model: MelBand Roformer | INSTV5N by Gabox": {
|
| 126 |
+
"mel_band_roformer_instrumental_instv5n_gabox.ckpt": "config_mel_band_roformer_instrumental_gabox.yaml"
|
| 127 |
+
},
|
| 128 |
+
"Roformer Model: MelBand Roformer | INSTV6 by Gabox": {
|
| 129 |
+
"mel_band_roformer_instrumental_instv6_gabox.ckpt": "config_mel_band_roformer_instrumental_gabox.yaml"
|
| 130 |
+
},
|
| 131 |
+
"Roformer Model: MelBand Roformer | INSTV6N by Gabox": {
|
| 132 |
+
"mel_band_roformer_instrumental_instv6n_gabox.ckpt": "config_mel_band_roformer_instrumental_gabox.yaml"
|
| 133 |
+
},
|
| 134 |
+
"Roformer Model: MelBand Roformer | INSTV7 by Gabox": {
|
| 135 |
+
"mel_band_roformer_instrumental_instv7_gabox.ckpt": "config_mel_band_roformer_instrumental_gabox.yaml"
|
| 136 |
+
},
|
| 137 |
+
"Roformer Model: MelBand Roformer | INSTV7N by Gabox": {
|
| 138 |
+
"mel_band_roformer_instrumental_instv7n_gabox.ckpt": "config_mel_band_roformer_instrumental_gabox.yaml"
|
| 139 |
+
},
|
| 140 |
+
"Roformer Model: MelBand Roformer | INSTV8 by Gabox": {
|
| 141 |
+
"mel_band_roformer_instrumental_instv8_gabox.ckpt": "config_mel_band_roformer_instrumental_gabox.yaml"
|
| 142 |
+
},
|
| 143 |
+
"Roformer Model: MelBand Roformer | INSTV8N by Gabox": {
|
| 144 |
+
"mel_band_roformer_instrumental_instv8n_gabox.ckpt": "config_mel_band_roformer_instrumental_gabox.yaml"
|
| 145 |
+
},
|
| 146 |
+
"Roformer Model: MelBand Roformer | FVX by Gabox": {
|
| 147 |
+
"mel_band_roformer_instrumental_fvx_gabox.ckpt": "config_mel_band_roformer_instrumental_gabox.yaml"
|
| 148 |
+
},
|
| 149 |
+
"Roformer Model: MelBand Roformer | De-Reverb by anvuew": {
|
| 150 |
+
"dereverb_mel_band_roformer_anvuew_sdr_19.1729.ckpt": "dereverb_mel_band_roformer_anvuew.yaml"
|
| 151 |
+
},
|
| 152 |
+
"Roformer Model: MelBand Roformer | De-Reverb Less Aggressive by anvuew": {
|
| 153 |
+
"dereverb_mel_band_roformer_less_aggressive_anvuew_sdr_18.8050.ckpt": "dereverb_mel_band_roformer_anvuew.yaml"
|
| 154 |
+
},
|
| 155 |
+
"Roformer Model: MelBand Roformer | De-Reverb Mono by anvuew": {
|
| 156 |
+
"dereverb_mel_band_roformer_mono_anvuew.ckpt": "dereverb_mel_band_roformer_anvuew.yaml"
|
| 157 |
+
},
|
| 158 |
+
"Roformer Model: MelBand Roformer | De-Reverb Big by Sucial": {
|
| 159 |
+
"dereverb_big_mbr_ep_362.ckpt": "config_dereverb_echo_mel_band_roformer_v2.yaml"
|
| 160 |
+
},
|
| 161 |
+
"Roformer Model: MelBand Roformer | De-Reverb Super Big by Sucial": {
|
| 162 |
+
"dereverb_super_big_mbr_ep_346.ckpt": "config_dereverb_echo_mel_band_roformer_v2.yaml"
|
| 163 |
+
},
|
| 164 |
+
"Roformer Model: MelBand Roformer | De-Reverb-Echo by Sucial": {
|
| 165 |
+
"dereverb-echo_mel_band_roformer_sdr_10.0169.ckpt": "config_dereverb-echo_mel_band_roformer.yaml"
|
| 166 |
+
},
|
| 167 |
+
"Roformer Model: MelBand Roformer | De-Reverb-Echo V2 by Sucial": {
|
| 168 |
+
"dereverb-echo_mel_band_roformer_sdr_13.4843_v2.ckpt": "config_dereverb-echo_mel_band_roformer_sdr_13.4843_v2.yaml"
|
| 169 |
+
},
|
| 170 |
+
"Roformer Model: MelBand Roformer | De-Reverb-Echo Fused by Sucial": {
|
| 171 |
+
"dereverb_echo_mbr_fused.ckpt": "config_dereverb_echo_mel_band_roformer_v2.yaml"
|
| 172 |
+
},
|
| 173 |
+
"Roformer Model: MelBand Roformer Kim | SYHFT by SYH99999": {
|
| 174 |
+
"MelBandRoformerSYHFT.ckpt": "config_vocals_mel_band_roformer_ft.yaml"
|
| 175 |
+
},
|
| 176 |
+
"Roformer Model: MelBand Roformer Kim | SYHFT V2 by SYH99999": {
|
| 177 |
+
"MelBandRoformerSYHFTV2.ckpt": "config_vocals_mel_band_roformer_ft.yaml"
|
| 178 |
+
},
|
| 179 |
+
"Roformer Model: MelBand Roformer Kim | SYHFT V2.5 by SYH99999": {
|
| 180 |
+
"MelBandRoformerSYHFTV2.5.ckpt": "config_vocals_mel_band_roformer_ft.yaml"
|
| 181 |
+
},
|
| 182 |
+
"Roformer Model: MelBand Roformer Kim | SYHFT V3 by SYH99999": {
|
| 183 |
+
"MelBandRoformerSYHFTV3Epsilon.ckpt": "config_vocals_mel_band_roformer_ft.yaml"
|
| 184 |
+
},
|
| 185 |
+
"Roformer Model: MelBand Roformer Kim | Big SYHFT V1 by SYH99999": {
|
| 186 |
+
"MelBandRoformerBigSYHFTV1.ckpt": "config_vocals_mel_band_roformer_big_v1_ft.yaml"
|
| 187 |
+
},
|
| 188 |
+
"Roformer Model: MelBand Roformer Kim | Big Beta 4 FT by unwa": {
|
| 189 |
+
"melband_roformer_big_beta4.ckpt": "config_melbandroformer_big_beta4.yaml"
|
| 190 |
+
},
|
| 191 |
+
"Roformer Model: MelBand Roformer Kim | Big Beta 5e FT by unwa": {
|
| 192 |
+
"melband_roformer_big_beta5e.ckpt": "config_melband_roformer_big_beta5e.yaml"
|
| 193 |
+
},
|
| 194 |
+
"Roformer Model: MelBand Roformer | Big Beta 6 by unwa": {
|
| 195 |
+
"melband_roformer_big_beta6.ckpt": "config_melbandroformer_big_beta6.yaml"
|
| 196 |
+
},
|
| 197 |
+
"Roformer Model: MelBand Roformer | Big Beta 6X by unwa": {
|
| 198 |
+
"melband_roformer_big_beta6x.ckpt": "config_melbandroformer_big_beta6x.yaml"
|
| 199 |
+
},
|
| 200 |
+
"Roformer Model: BS Roformer | Chorus Male-Female by Sucial": {
|
| 201 |
+
"model_chorus_bs_roformer_ep_267_sdr_24.1275.ckpt": "config_chorus_male_female_bs_roformer.yaml"
|
| 202 |
+
},
|
| 203 |
+
"Roformer Model: BS Roformer | Male-Female by aufr33": {
|
| 204 |
+
"bs_roformer_male_female_by_aufr33_sdr_7.2889.ckpt": "config_chorus_male_female_bs_roformer.yaml"
|
| 205 |
+
},
|
| 206 |
+
"Roformer Model: MelBand Roformer | Aspiration by Sucial": {
|
| 207 |
+
"aspiration_mel_band_roformer_sdr_18.9845.ckpt": "config_aspiration_mel_band_roformer.yaml"
|
| 208 |
+
},
|
| 209 |
+
"Roformer Model: MelBand Roformer | Aspiration Less Aggressive by Sucial": {
|
| 210 |
+
"aspiration_mel_band_roformer_less_aggr_sdr_18.1201.ckpt": "config_aspiration_mel_band_roformer.yaml"
|
| 211 |
+
},
|
| 212 |
+
"Roformer Model: MelBand Roformer | Bleed Suppressor V1 by unwa-97chris": {
|
| 213 |
+
"mel_band_roformer_bleed_suppressor_v1.ckpt": "config_mel_band_roformer_bleed_suppressor_v1.yaml"
|
| 214 |
+
}
|
| 215 |
+
}
|
| 216 |
+
}
|
audio_separator/separator/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .separator import Separator
|
audio_separator/separator/architectures/__init__.py
ADDED
|
File without changes
|
audio_separator/separator/architectures/demucs_separator.py
ADDED
|
@@ -0,0 +1,195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
import torch
|
| 5 |
+
import numpy as np
|
| 6 |
+
from audio_separator.separator.common_separator import CommonSeparator
|
| 7 |
+
from audio_separator.separator.uvr_lib_v5.demucs.apply import apply_model, demucs_segments
|
| 8 |
+
from audio_separator.separator.uvr_lib_v5.demucs.hdemucs import HDemucs
|
| 9 |
+
from audio_separator.separator.uvr_lib_v5.demucs.pretrained import get_model as get_demucs_model
|
| 10 |
+
from audio_separator.separator.uvr_lib_v5 import spec_utils
|
| 11 |
+
|
| 12 |
+
DEMUCS_4_SOURCE = ["drums", "bass", "other", "vocals"]
|
| 13 |
+
|
| 14 |
+
DEMUCS_2_SOURCE_MAPPER = {CommonSeparator.INST_STEM: 0, CommonSeparator.VOCAL_STEM: 1}
|
| 15 |
+
DEMUCS_4_SOURCE_MAPPER = {CommonSeparator.BASS_STEM: 0, CommonSeparator.DRUM_STEM: 1, CommonSeparator.OTHER_STEM: 2, CommonSeparator.VOCAL_STEM: 3}
|
| 16 |
+
DEMUCS_6_SOURCE_MAPPER = {
|
| 17 |
+
CommonSeparator.BASS_STEM: 0,
|
| 18 |
+
CommonSeparator.DRUM_STEM: 1,
|
| 19 |
+
CommonSeparator.OTHER_STEM: 2,
|
| 20 |
+
CommonSeparator.VOCAL_STEM: 3,
|
| 21 |
+
CommonSeparator.GUITAR_STEM: 4,
|
| 22 |
+
CommonSeparator.PIANO_STEM: 5,
|
| 23 |
+
}
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class DemucsSeparator(CommonSeparator):
|
| 27 |
+
"""
|
| 28 |
+
DemucsSeparator is responsible for separating audio sources using Demucs models.
|
| 29 |
+
It initializes with configuration parameters and prepares the model for separation tasks.
|
| 30 |
+
"""
|
| 31 |
+
|
| 32 |
+
def __init__(self, common_config, arch_config):
|
| 33 |
+
# Any configuration values which can be shared between architectures should be set already in CommonSeparator,
|
| 34 |
+
# e.g. user-specified functionality choices (self.output_single_stem) or common model parameters (self.primary_stem_name)
|
| 35 |
+
super().__init__(config=common_config)
|
| 36 |
+
|
| 37 |
+
# Initializing user-configurable parameters, passed through with an mdx_from the CLI or Separator instance
|
| 38 |
+
|
| 39 |
+
# Adjust segments to manage RAM or V-RAM usage:
|
| 40 |
+
# - Smaller sizes consume less resources.
|
| 41 |
+
# - Bigger sizes consume more resources, but may provide better results.
|
| 42 |
+
# - "Default" picks the optimal size.
|
| 43 |
+
# DEMUCS_SEGMENTS = (DEF_OPT, '1', '5', '10', '15', '20',
|
| 44 |
+
# '25', '30', '35', '40', '45', '50',
|
| 45 |
+
# '55', '60', '65', '70', '75', '80',
|
| 46 |
+
# '85', '90', '95', '100')
|
| 47 |
+
self.segment_size = arch_config.get("segment_size", "Default")
|
| 48 |
+
|
| 49 |
+
# Performs multiple predictions with random shifts of the input and averages them.
|
| 50 |
+
# The higher number of shifts, the longer the prediction will take.
|
| 51 |
+
# Not recommended unless you have a GPU.
|
| 52 |
+
# DEMUCS_SHIFTS = (0, 1, 2, 3, 4, 5,
|
| 53 |
+
# 6, 7, 8, 9, 10, 11,
|
| 54 |
+
# 12, 13, 14, 15, 16, 17,
|
| 55 |
+
# 18, 19, 20)
|
| 56 |
+
self.shifts = arch_config.get("shifts", 2)
|
| 57 |
+
|
| 58 |
+
# This option controls the amount of overlap between prediction windows.
|
| 59 |
+
# - Higher values can provide better results, but will lead to longer processing times.
|
| 60 |
+
# - You can choose between 0.001-0.999
|
| 61 |
+
# DEMUCS_OVERLAP = (0.25, 0.50, 0.75, 0.99)
|
| 62 |
+
self.overlap = arch_config.get("overlap", 0.25)
|
| 63 |
+
|
| 64 |
+
# Enables "Segments". Deselecting this option is only recommended for those with powerful PCs.
|
| 65 |
+
self.segments_enabled = arch_config.get("segments_enabled", True)
|
| 66 |
+
|
| 67 |
+
self.logger.debug(f"Demucs arch params: segment_size={self.segment_size}, segments_enabled={self.segments_enabled}")
|
| 68 |
+
self.logger.debug(f"Demucs arch params: shifts={self.shifts}, overlap={self.overlap}")
|
| 69 |
+
|
| 70 |
+
self.demucs_source_map = DEMUCS_4_SOURCE_MAPPER
|
| 71 |
+
|
| 72 |
+
self.audio_file_path = None
|
| 73 |
+
self.audio_file_base = None
|
| 74 |
+
self.demucs_model_instance = None
|
| 75 |
+
|
| 76 |
+
# Add uvr_lib_v5 folder to system path so pytorch serialization can find the demucs module
|
| 77 |
+
current_dir = os.path.dirname(__file__)
|
| 78 |
+
uvr_lib_v5_path = os.path.join(current_dir, "..", "uvr_lib_v5")
|
| 79 |
+
sys.path.insert(0, uvr_lib_v5_path)
|
| 80 |
+
|
| 81 |
+
self.logger.info("Demucs Separator initialisation complete")
|
| 82 |
+
|
| 83 |
+
def separate(self, audio_file_path, custom_output_names=None):
|
| 84 |
+
"""
|
| 85 |
+
Separates the audio file into its component stems using the Demucs model.
|
| 86 |
+
|
| 87 |
+
Args:
|
| 88 |
+
audio_file_path (str): The path to the audio file to be processed.
|
| 89 |
+
custom_output_names (dict, optional): Custom names for the output files. Defaults to None.
|
| 90 |
+
|
| 91 |
+
Returns:
|
| 92 |
+
list: A list of paths to the output files generated by the separation process.
|
| 93 |
+
"""
|
| 94 |
+
self.logger.debug("Starting separation process...")
|
| 95 |
+
source = None
|
| 96 |
+
stem_source = None
|
| 97 |
+
inst_source = {}
|
| 98 |
+
|
| 99 |
+
self.audio_file_path = audio_file_path
|
| 100 |
+
self.audio_file_base = os.path.splitext(os.path.basename(audio_file_path))[0]
|
| 101 |
+
|
| 102 |
+
# Prepare the mix for processing
|
| 103 |
+
self.logger.debug("Preparing mix...")
|
| 104 |
+
mix = self.prepare_mix(self.audio_file_path)
|
| 105 |
+
|
| 106 |
+
self.logger.debug(f"Mix prepared for demixing. Shape: {mix.shape}")
|
| 107 |
+
|
| 108 |
+
self.logger.debug("Loading model for demixing...")
|
| 109 |
+
|
| 110 |
+
self.demucs_model_instance = HDemucs(sources=DEMUCS_4_SOURCE)
|
| 111 |
+
self.demucs_model_instance = get_demucs_model(name=os.path.splitext(os.path.basename(self.model_path))[0], repo=Path(os.path.dirname(self.model_path)))
|
| 112 |
+
self.demucs_model_instance = demucs_segments(self.segment_size, self.demucs_model_instance)
|
| 113 |
+
self.demucs_model_instance.to(self.torch_device)
|
| 114 |
+
self.demucs_model_instance.eval()
|
| 115 |
+
|
| 116 |
+
self.logger.debug("Model loaded and set to evaluation mode.")
|
| 117 |
+
|
| 118 |
+
source = self.demix_demucs(mix)
|
| 119 |
+
|
| 120 |
+
del self.demucs_model_instance
|
| 121 |
+
self.clear_gpu_cache()
|
| 122 |
+
self.logger.debug("Model and GPU cache cleared after demixing.")
|
| 123 |
+
|
| 124 |
+
output_files = []
|
| 125 |
+
self.logger.debug("Processing output files...")
|
| 126 |
+
|
| 127 |
+
if isinstance(inst_source, np.ndarray):
|
| 128 |
+
self.logger.debug("Processing instance source...")
|
| 129 |
+
source_reshape = spec_utils.reshape_sources(inst_source[self.demucs_source_map[CommonSeparator.VOCAL_STEM]], source[self.demucs_source_map[CommonSeparator.VOCAL_STEM]])
|
| 130 |
+
inst_source[self.demucs_source_map[CommonSeparator.VOCAL_STEM]] = source_reshape
|
| 131 |
+
source = inst_source
|
| 132 |
+
|
| 133 |
+
if isinstance(source, np.ndarray):
|
| 134 |
+
source_length = len(source)
|
| 135 |
+
self.logger.debug(f"Processing source array, source length is {source_length}")
|
| 136 |
+
match source_length:
|
| 137 |
+
case 2:
|
| 138 |
+
self.logger.debug("Setting source map to 2-stem...")
|
| 139 |
+
self.demucs_source_map = DEMUCS_2_SOURCE_MAPPER
|
| 140 |
+
case 6:
|
| 141 |
+
self.logger.debug("Setting source map to 6-stem...")
|
| 142 |
+
self.demucs_source_map = DEMUCS_6_SOURCE_MAPPER
|
| 143 |
+
case _:
|
| 144 |
+
self.logger.debug("Setting source map to 4-stem...")
|
| 145 |
+
self.demucs_source_map = DEMUCS_4_SOURCE_MAPPER
|
| 146 |
+
|
| 147 |
+
self.logger.debug("Processing for all stems...")
|
| 148 |
+
for stem_name, stem_value in self.demucs_source_map.items():
|
| 149 |
+
if self.output_single_stem is not None:
|
| 150 |
+
if stem_name.lower() != self.output_single_stem.lower():
|
| 151 |
+
self.logger.debug(f"Skipping writing stem {stem_name} as output_single_stem is set to {self.output_single_stem}...")
|
| 152 |
+
continue
|
| 153 |
+
|
| 154 |
+
stem_path = self.get_stem_output_path(stem_name, custom_output_names)
|
| 155 |
+
stem_source = source[stem_value].T
|
| 156 |
+
|
| 157 |
+
self.final_process(stem_path, stem_source, stem_name)
|
| 158 |
+
output_files.append(stem_path)
|
| 159 |
+
|
| 160 |
+
return output_files
|
| 161 |
+
|
| 162 |
+
def demix_demucs(self, mix):
|
| 163 |
+
"""
|
| 164 |
+
Demixes the input mix using the demucs model.
|
| 165 |
+
"""
|
| 166 |
+
self.logger.debug("Starting demixing process in demix_demucs...")
|
| 167 |
+
|
| 168 |
+
processed = {}
|
| 169 |
+
mix = torch.tensor(mix, dtype=torch.float32)
|
| 170 |
+
ref = mix.mean(0)
|
| 171 |
+
mix = (mix - ref.mean()) / ref.std()
|
| 172 |
+
mix_infer = mix
|
| 173 |
+
|
| 174 |
+
with torch.no_grad():
|
| 175 |
+
self.logger.debug("Running model inference...")
|
| 176 |
+
sources = apply_model(
|
| 177 |
+
model=self.demucs_model_instance,
|
| 178 |
+
mix=mix_infer[None],
|
| 179 |
+
shifts=self.shifts,
|
| 180 |
+
split=self.segments_enabled,
|
| 181 |
+
overlap=self.overlap,
|
| 182 |
+
static_shifts=1 if self.shifts == 0 else self.shifts,
|
| 183 |
+
set_progress_bar=None,
|
| 184 |
+
device=self.torch_device,
|
| 185 |
+
progress=True,
|
| 186 |
+
)[0]
|
| 187 |
+
|
| 188 |
+
sources = (sources * ref.std() + ref.mean()).cpu().numpy()
|
| 189 |
+
sources[[0, 1]] = sources[[1, 0]]
|
| 190 |
+
processed[mix] = sources[:, :, 0:None].copy()
|
| 191 |
+
sources = list(processed.values())
|
| 192 |
+
sources = [s[:, :, 0:None] for s in sources]
|
| 193 |
+
sources = np.concatenate(sources, axis=-1)
|
| 194 |
+
|
| 195 |
+
return sources
|
audio_separator/separator/architectures/mdx_separator.py
ADDED
|
@@ -0,0 +1,451 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Module for separating audio sources using MDX architecture models."""
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import platform
|
| 5 |
+
import torch
|
| 6 |
+
import onnx
|
| 7 |
+
import onnxruntime as ort
|
| 8 |
+
import numpy as np
|
| 9 |
+
import onnx2torch
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
from audio_separator.separator.uvr_lib_v5 import spec_utils
|
| 12 |
+
from audio_separator.separator.uvr_lib_v5.stft import STFT
|
| 13 |
+
from audio_separator.separator.common_separator import CommonSeparator
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class MDXSeparator(CommonSeparator):
|
| 17 |
+
"""
|
| 18 |
+
MDXSeparator is responsible for separating audio sources using MDX models.
|
| 19 |
+
It initializes with configuration parameters and prepares the model for separation tasks.
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
    def __init__(self, common_config, arch_config):
        """
        Initialise the MDX separator.

        Args:
            common_config (dict): Configuration shared across all architectures,
                consumed by CommonSeparator (logging, output options, stem names, ...).
            arch_config (dict): MDX-specific user options (segment size, overlap,
                batch size, hop length, denoise flag).
        """
        # Shared configuration (e.g. output_single_stem, primary_stem_name) is set in CommonSeparator.
        super().__init__(config=common_config)

        # --- User-configurable parameters, passed through from the CLI or Separator instance ---

        # Segment size balances speed, resource use, and quality:
        # smaller sizes are cheaper, bigger sizes may give better results.
        self.segment_size = arch_config.get("segment_size")

        # Overlap between prediction windows (0.001-0.999 for non-MDX23C models).
        # Higher values can improve results but lengthen processing time.
        self.overlap = arch_config.get("overlap")

        # Number of batches processed at a time: higher uses more RAM and is slightly
        # faster, lower uses less RAM and is slightly slower; no effect on output quality.
        self.batch_size = arch_config.get("batch_size", 1)

        # hop_length is the STFT hop, equivalent to "stride" in CNN terminology:
        # how far the analysis window advances at each step.
        self.hop_length = arch_config.get("hop_length")

        # If enabled, the model is run twice (negated + original spectrum) to reduce output noise.
        self.enable_denoise = arch_config.get("enable_denoise")

        self.logger.debug(f"MDX arch params: batch_size={self.batch_size}, segment_size={self.segment_size}")
        self.logger.debug(f"MDX arch params: overlap={self.overlap}, hop_length={self.hop_length}, enable_denoise={self.enable_denoise}")

        # --- Model-specific parameters from the model_data JSON ---
        self.compensate = self.model_data["compensate"]
        self.dim_f = self.model_data["mdx_dim_f_set"]
        self.dim_t = 2 ** self.model_data["mdx_dim_t_set"]
        self.n_fft = self.model_data["mdx_n_fft_scale_set"]
        self.config_yaml = self.model_data.get("config_yaml", None)

        self.logger.debug(f"MDX arch params: compensate={self.compensate}, dim_f={self.dim_f}, dim_t={self.dim_t}, n_fft={self.n_fft}")
        self.logger.debug(f"MDX arch params: config_yaml={self.config_yaml}")

        # UVR also sets chunks/adjust/hop/margin/dim_c, but those are either unused,
        # hard-coded no-ops, or replaced by audio-separator parameters (sample_rate,
        # hop_length), so they are deliberately not reproduced here.

        self.load_model()

        # Spectral processing parameters; populated later by initialize_model_settings().
        self.n_bins = 0
        self.trim = 0
        self.chunk_size = 0
        self.gen_size = 0
        self.stft = None

        # Per-run state, set during separate().
        self.primary_source = None
        self.secondary_source = None
        self.audio_file_path = None
        self.audio_file_base = None
|
| 107 |
+
|
| 108 |
+
def load_model(self):
|
| 109 |
+
"""
|
| 110 |
+
Load the model into memory from file on disk, initialize it with config from the model data,
|
| 111 |
+
and prepare for inferencing using hardware accelerated Torch device.
|
| 112 |
+
"""
|
| 113 |
+
self.logger.debug("Loading ONNX model for inference...")
|
| 114 |
+
|
| 115 |
+
if self.segment_size == self.dim_t:
|
| 116 |
+
ort_session_options = ort.SessionOptions()
|
| 117 |
+
if self.log_level > 10:
|
| 118 |
+
ort_session_options.log_severity_level = 3
|
| 119 |
+
else:
|
| 120 |
+
ort_session_options.log_severity_level = 0
|
| 121 |
+
|
| 122 |
+
ort_inference_session = ort.InferenceSession(self.model_path, providers=self.onnx_execution_provider, sess_options=ort_session_options)
|
| 123 |
+
self.model_run = lambda spek: ort_inference_session.run(None, {"input": spek.cpu().numpy()})[0]
|
| 124 |
+
self.logger.debug("Model loaded successfully using ONNXruntime inferencing session.")
|
| 125 |
+
else:
|
| 126 |
+
if platform.system() == 'Windows':
|
| 127 |
+
onnx_model = onnx.load(self.model_path)
|
| 128 |
+
self.model_run = onnx2torch.convert(onnx_model)
|
| 129 |
+
else:
|
| 130 |
+
self.model_run = onnx2torch.convert(self.model_path)
|
| 131 |
+
|
| 132 |
+
self.model_run.to(self.torch_device).eval()
|
| 133 |
+
self.logger.warning("Model converted from onnx to pytorch due to segment size not matching dim_t, processing may be slower.")
|
| 134 |
+
|
| 135 |
+
    def separate(self, audio_file_path, custom_output_names=None):
        """
        Separate the audio file into primary and secondary stems.

        Prepares and normalizes the mix, demixes it with the MDX model, derives the
        secondary stem (spectral inversion or simple subtraction), and writes the
        requested stem files.

        Args:
            audio_file_path (str): The path to the audio file to be processed.
            custom_output_names (dict, optional): Custom names for the output files. Defaults to None.

        Returns:
            list: A list of paths to the output files generated by the separation process.
        """
        self.audio_file_path = audio_file_path
        self.audio_file_base = os.path.splitext(os.path.basename(audio_file_path))[0]

        # Prepare the mix for processing
        self.logger.debug(f"Preparing mix for input audio file {self.audio_file_path}...")
        mix = self.prepare_mix(self.audio_file_path)

        self.logger.debug("Normalizing mix before demixing...")
        mix = spec_utils.normalize(wave=mix, max_peak=self.normalization_threshold, min_peak=self.amplification_threshold)

        # Start the demixing process
        source = self.demix(mix)
        self.logger.debug("Demixing completed.")

        # In UVR, the source is cached here if it's a vocal split model, but we're not supporting that yet

        # Initialize the list for output files
        output_files = []
        self.logger.debug("Processing output files...")

        # Normalize and transpose the primary source if it's not already an array
        if not isinstance(self.primary_source, np.ndarray):
            self.logger.debug("Normalizing primary source...")
            self.primary_source = spec_utils.normalize(wave=source, max_peak=self.normalization_threshold, min_peak=self.amplification_threshold).T

        # Derive the secondary source if not already an array
        if not isinstance(self.secondary_source, np.ndarray):
            self.logger.debug("Producing secondary source: demixing in match_mix mode")
            # A second demix pass in match-mix mode reproduces the mix for inversion.
            raw_mix = self.demix(mix, is_match_mix=True)

            if self.invert_using_spec:
                self.logger.debug("Inverting secondary stem using spectogram as invert_using_spec is set to True")
                self.secondary_source = spec_utils.invert_stem(raw_mix, source)
            else:
                self.logger.debug("Inverting secondary stem by subtracting of transposed demixed stem from transposed original mix")
                self.secondary_source = mix.T - source.T

        # Save the secondary stem when requested (or when no single stem was selected)
        if not self.output_single_stem or self.output_single_stem.lower() == self.secondary_stem_name.lower():
            self.secondary_stem_output_path = self.get_stem_output_path(self.secondary_stem_name, custom_output_names)

            self.logger.info(f"Saving {self.secondary_stem_name} stem to {self.secondary_stem_output_path}...")
            self.final_process(self.secondary_stem_output_path, self.secondary_source, self.secondary_stem_name)
            output_files.append(self.secondary_stem_output_path)

        # Save the primary stem when requested (or when no single stem was selected)
        if not self.output_single_stem or self.output_single_stem.lower() == self.primary_stem_name.lower():
            self.primary_stem_output_path = self.get_stem_output_path(self.primary_stem_name, custom_output_names)

            # NOTE(review): primary_source was already set to an ndarray above, so this
            # fallback appears unreachable here; kept to mirror the original UVR flow — confirm.
            if not isinstance(self.primary_source, np.ndarray):
                self.primary_source = source.T

            self.logger.info(f"Saving {self.primary_stem_name} stem to {self.primary_stem_output_path}...")
            self.final_process(self.primary_stem_output_path, self.primary_source, self.primary_stem_name)
            output_files.append(self.primary_stem_output_path)

        # Not yet implemented from UVR features:
        # self.process_vocal_split_chain(secondary_sources)
        # self.logger.debug("Vocal split chain processed.")

        return output_files
|
| 208 |
+
|
| 209 |
+
def initialize_model_settings(self):
|
| 210 |
+
"""
|
| 211 |
+
This function sets up the necessary parameters for the model, like the number of frequency bins (n_bins), the trimming size (trim),
|
| 212 |
+
the size of each audio chunk (chunk_size), and the window function for spectral transformations (window).
|
| 213 |
+
It ensures that the model is configured with the correct settings for processing the audio data.
|
| 214 |
+
"""
|
| 215 |
+
self.logger.debug("Initializing model settings...")
|
| 216 |
+
|
| 217 |
+
# n_bins is half the FFT size plus one (self.n_fft // 2 + 1).
|
| 218 |
+
self.n_bins = self.n_fft // 2 + 1
|
| 219 |
+
|
| 220 |
+
# trim is half the FFT size (self.n_fft // 2).
|
| 221 |
+
self.trim = self.n_fft // 2
|
| 222 |
+
|
| 223 |
+
# chunk_size is the hop_length size times the segment size minus one
|
| 224 |
+
self.chunk_size = self.hop_length * (self.segment_size - 1)
|
| 225 |
+
|
| 226 |
+
# gen_size is the chunk size minus twice the trim size
|
| 227 |
+
self.gen_size = self.chunk_size - 2 * self.trim
|
| 228 |
+
|
| 229 |
+
self.stft = STFT(self.logger, self.n_fft, self.hop_length, self.dim_f, self.torch_device)
|
| 230 |
+
|
| 231 |
+
self.logger.debug(f"Model input params: n_fft={self.n_fft} hop_length={self.hop_length} dim_f={self.dim_f}")
|
| 232 |
+
self.logger.debug(f"Model settings: n_bins={self.n_bins}, trim={self.trim}, chunk_size={self.chunk_size}, gen_size={self.gen_size}")
|
| 233 |
+
|
| 234 |
+
def initialize_mix(self, mix, is_ckpt=False):
|
| 235 |
+
"""
|
| 236 |
+
After prepare_mix segments the audio, initialize_mix further processes each segment.
|
| 237 |
+
It ensures each audio segment is in the correct format for the model, applies necessary padding,
|
| 238 |
+
and converts the segments into tensors for processing with the model.
|
| 239 |
+
This step is essential for preparing the audio data in a format that the neural network can process.
|
| 240 |
+
"""
|
| 241 |
+
# Log the initialization of the mix and whether checkpoint mode is used
|
| 242 |
+
self.logger.debug(f"Initializing mix with is_ckpt={is_ckpt}. Initial mix shape: {mix.shape}")
|
| 243 |
+
|
| 244 |
+
# Ensure the mix is a 2-channel (stereo) audio signal
|
| 245 |
+
if mix.shape[0] != 2:
|
| 246 |
+
error_message = f"Expected a 2-channel audio signal, but got {mix.shape[0]} channels"
|
| 247 |
+
self.logger.error(error_message)
|
| 248 |
+
raise ValueError(error_message)
|
| 249 |
+
|
| 250 |
+
# If in checkpoint mode, process the mix differently
|
| 251 |
+
if is_ckpt:
|
| 252 |
+
self.logger.debug("Processing in checkpoint mode...")
|
| 253 |
+
# Calculate padding based on the generation size and trim
|
| 254 |
+
pad = self.gen_size + self.trim - (mix.shape[-1] % self.gen_size)
|
| 255 |
+
self.logger.debug(f"Padding calculated: {pad}")
|
| 256 |
+
# Add padding at the beginning and the end of the mix
|
| 257 |
+
mixture = np.concatenate((np.zeros((2, self.trim), dtype="float32"), mix, np.zeros((2, pad), dtype="float32")), 1)
|
| 258 |
+
# Determine the number of chunks based on the mixture's length
|
| 259 |
+
num_chunks = mixture.shape[-1] // self.gen_size
|
| 260 |
+
self.logger.debug(f"Mixture shape after padding: {mixture.shape}, Number of chunks: {num_chunks}")
|
| 261 |
+
# Split the mixture into chunks
|
| 262 |
+
mix_waves = [mixture[:, i * self.gen_size : i * self.gen_size + self.chunk_size] for i in range(num_chunks)]
|
| 263 |
+
else:
|
| 264 |
+
# If not in checkpoint mode, process normally
|
| 265 |
+
self.logger.debug("Processing in non-checkpoint mode...")
|
| 266 |
+
mix_waves = []
|
| 267 |
+
n_sample = mix.shape[1]
|
| 268 |
+
# Calculate necessary padding to make the total length divisible by the generation size
|
| 269 |
+
pad = self.gen_size - n_sample % self.gen_size
|
| 270 |
+
self.logger.debug(f"Number of samples: {n_sample}, Padding calculated: {pad}")
|
| 271 |
+
# Apply padding to the mix
|
| 272 |
+
mix_p = np.concatenate((np.zeros((2, self.trim)), mix, np.zeros((2, pad)), np.zeros((2, self.trim))), 1)
|
| 273 |
+
self.logger.debug(f"Shape of mix after padding: {mix_p.shape}")
|
| 274 |
+
|
| 275 |
+
# Process the mix in chunks
|
| 276 |
+
i = 0
|
| 277 |
+
while i < n_sample + pad:
|
| 278 |
+
waves = np.array(mix_p[:, i : i + self.chunk_size])
|
| 279 |
+
mix_waves.append(waves)
|
| 280 |
+
self.logger.debug(f"Processed chunk {len(mix_waves)}: Start {i}, End {i + self.chunk_size}")
|
| 281 |
+
i += self.gen_size
|
| 282 |
+
|
| 283 |
+
# Convert the list of wave chunks into a tensor for processing on the specified device
|
| 284 |
+
mix_waves_tensor = torch.tensor(mix_waves, dtype=torch.float32).to(self.torch_device)
|
| 285 |
+
self.logger.debug(f"Converted mix_waves to tensor. Tensor shape: {mix_waves_tensor.shape}")
|
| 286 |
+
|
| 287 |
+
return mix_waves_tensor, pad
|
| 288 |
+
|
| 289 |
+
def demix(self, mix, is_match_mix=False):
|
| 290 |
+
"""
|
| 291 |
+
Demixes the input mix into its constituent sources. If is_match_mix is True, the function adjusts the processing
|
| 292 |
+
to better match the mix, affecting chunk sizes and overlaps. The demixing process involves padding the mix,
|
| 293 |
+
processing it in chunks, applying windowing for overlaps, and accumulating the results to separate the sources.
|
| 294 |
+
"""
|
| 295 |
+
self.logger.debug(f"Starting demixing process with is_match_mix: {is_match_mix}...")
|
| 296 |
+
self.initialize_model_settings()
|
| 297 |
+
|
| 298 |
+
# Preserves the original mix for later use.
|
| 299 |
+
# In UVR, this is used for the pitch fix and VR denoise processes, which aren't yet implemented here.
|
| 300 |
+
org_mix = mix
|
| 301 |
+
self.logger.debug(f"Original mix stored. Shape: {org_mix.shape}")
|
| 302 |
+
|
| 303 |
+
# Initializes a list to store the separated waveforms.
|
| 304 |
+
tar_waves_ = []
|
| 305 |
+
|
| 306 |
+
# Handling different chunk sizes and overlaps based on the matching requirement.
|
| 307 |
+
if is_match_mix:
|
| 308 |
+
# Sets a smaller chunk size specifically for matching the mix.
|
| 309 |
+
chunk_size = self.hop_length * (self.segment_size - 1)
|
| 310 |
+
# Sets a small overlap for the chunks.
|
| 311 |
+
overlap = 0.02
|
| 312 |
+
self.logger.debug(f"Chunk size for matching mix: {chunk_size}, Overlap: {overlap}")
|
| 313 |
+
else:
|
| 314 |
+
# Uses the regular chunk size defined in model settings.
|
| 315 |
+
chunk_size = self.chunk_size
|
| 316 |
+
# Uses the overlap specified in the model settings.
|
| 317 |
+
overlap = self.overlap
|
| 318 |
+
self.logger.debug(f"Standard chunk size: {chunk_size}, Overlap: {overlap}")
|
| 319 |
+
|
| 320 |
+
# Calculates the generated size after subtracting the trim from both ends of the chunk.
|
| 321 |
+
gen_size = chunk_size - 2 * self.trim
|
| 322 |
+
self.logger.debug(f"Generated size calculated: {gen_size}")
|
| 323 |
+
|
| 324 |
+
# Calculates padding to make the mix length a multiple of the generated size.
|
| 325 |
+
pad = gen_size + self.trim - ((mix.shape[-1]) % gen_size)
|
| 326 |
+
# Prepares the mixture with padding at the beginning and the end.
|
| 327 |
+
mixture = np.concatenate((np.zeros((2, self.trim), dtype="float32"), mix, np.zeros((2, pad), dtype="float32")), 1)
|
| 328 |
+
self.logger.debug(f"Mixture prepared with padding. Mixture shape: {mixture.shape}")
|
| 329 |
+
|
| 330 |
+
# Calculates the step size for processing chunks based on the overlap.
|
| 331 |
+
step = int((1 - overlap) * chunk_size)
|
| 332 |
+
self.logger.debug(f"Step size for processing chunks: {step} as overlap is set to {overlap}.")
|
| 333 |
+
|
| 334 |
+
# Initializes arrays to store the results and to account for overlap.
|
| 335 |
+
result = np.zeros((1, 2, mixture.shape[-1]), dtype=np.float32)
|
| 336 |
+
divider = np.zeros((1, 2, mixture.shape[-1]), dtype=np.float32)
|
| 337 |
+
|
| 338 |
+
# Initializes counters for processing chunks.
|
| 339 |
+
total = 0
|
| 340 |
+
total_chunks = (mixture.shape[-1] + step - 1) // step
|
| 341 |
+
self.logger.debug(f"Total chunks to process: {total_chunks}")
|
| 342 |
+
|
| 343 |
+
# Processes each chunk of the mixture.
|
| 344 |
+
for i in tqdm(range(0, mixture.shape[-1], step)):
|
| 345 |
+
total += 1
|
| 346 |
+
start = i
|
| 347 |
+
end = min(i + chunk_size, mixture.shape[-1])
|
| 348 |
+
self.logger.debug(f"Processing chunk {total}/{total_chunks}: Start {start}, End {end}")
|
| 349 |
+
|
| 350 |
+
# Handles windowing for overlapping chunks.
|
| 351 |
+
chunk_size_actual = end - start
|
| 352 |
+
window = None
|
| 353 |
+
if overlap != 0:
|
| 354 |
+
window = np.hanning(chunk_size_actual)
|
| 355 |
+
window = np.tile(window[None, None, :], (1, 2, 1))
|
| 356 |
+
self.logger.debug("Window applied to the chunk.")
|
| 357 |
+
|
| 358 |
+
# Zero-pad the chunk to prepare it for processing.
|
| 359 |
+
mix_part_ = mixture[:, start:end]
|
| 360 |
+
if end != i + chunk_size:
|
| 361 |
+
pad_size = (i + chunk_size) - end
|
| 362 |
+
mix_part_ = np.concatenate((mix_part_, np.zeros((2, pad_size), dtype="float32")), axis=-1)
|
| 363 |
+
|
| 364 |
+
# Converts the chunk to a tensor for processing.
|
| 365 |
+
mix_part = torch.tensor([mix_part_], dtype=torch.float32).to(self.torch_device)
|
| 366 |
+
# Splits the chunk into smaller batches if necessary.
|
| 367 |
+
mix_waves = mix_part.split(self.batch_size)
|
| 368 |
+
total_batches = len(mix_waves)
|
| 369 |
+
self.logger.debug(f"Mix part split into batches. Number of batches: {total_batches}")
|
| 370 |
+
|
| 371 |
+
with torch.no_grad():
|
| 372 |
+
# Processes each batch in the chunk.
|
| 373 |
+
batches_processed = 0
|
| 374 |
+
for mix_wave in mix_waves:
|
| 375 |
+
batches_processed += 1
|
| 376 |
+
self.logger.debug(f"Processing mix_wave batch {batches_processed}/{total_batches}")
|
| 377 |
+
|
| 378 |
+
# Runs the model to separate the sources.
|
| 379 |
+
tar_waves = self.run_model(mix_wave, is_match_mix=is_match_mix)
|
| 380 |
+
|
| 381 |
+
# Applies windowing if needed and accumulates the results.
|
| 382 |
+
if window is not None:
|
| 383 |
+
tar_waves[..., :chunk_size_actual] *= window
|
| 384 |
+
divider[..., start:end] += window
|
| 385 |
+
else:
|
| 386 |
+
divider[..., start:end] += 1
|
| 387 |
+
|
| 388 |
+
result[..., start:end] += tar_waves[..., : end - start]
|
| 389 |
+
|
| 390 |
+
# Normalizes the results by the divider to account for overlap.
|
| 391 |
+
self.logger.debug("Normalizing result by dividing result by divider.")
|
| 392 |
+
tar_waves = result / divider
|
| 393 |
+
tar_waves_.append(tar_waves)
|
| 394 |
+
|
| 395 |
+
# Reshapes the results to match the original dimensions.
|
| 396 |
+
tar_waves_ = np.vstack(tar_waves_)[:, :, self.trim : -self.trim]
|
| 397 |
+
tar_waves = np.concatenate(tar_waves_, axis=-1)[:, : mix.shape[-1]]
|
| 398 |
+
|
| 399 |
+
# Extracts the source from the results.
|
| 400 |
+
source = tar_waves[:, 0:None]
|
| 401 |
+
self.logger.debug(f"Concatenated tar_waves. Shape: {tar_waves.shape}")
|
| 402 |
+
|
| 403 |
+
# TODO: In UVR, pitch changing happens here. Consider implementing this as a feature.
|
| 404 |
+
|
| 405 |
+
# Compensates the source if not matching the mix.
|
| 406 |
+
if not is_match_mix:
|
| 407 |
+
source *= self.compensate
|
| 408 |
+
self.logger.debug("Match mix mode; compensate multiplier applied.")
|
| 409 |
+
|
| 410 |
+
# TODO: In UVR, VR denoise model gets applied here. Consider implementing this as a feature.
|
| 411 |
+
|
| 412 |
+
self.logger.debug("Demixing process completed.")
|
| 413 |
+
return source
|
| 414 |
+
|
| 415 |
+
def run_model(self, mix, is_match_mix=False):
|
| 416 |
+
"""
|
| 417 |
+
Processes the input mix through the model to separate the sources.
|
| 418 |
+
Applies STFT, handles spectrum modifications, and runs the model for source separation.
|
| 419 |
+
"""
|
| 420 |
+
# Applying the STFT to the mix. The mix is moved to the specified device (e.g., GPU) before processing.
|
| 421 |
+
# self.logger.debug(f"Running STFT on the mix. Mix shape before STFT: {mix.shape}")
|
| 422 |
+
spek = self.stft(mix.to(self.torch_device))
|
| 423 |
+
self.logger.debug(f"STFT applied on mix. Spectrum shape: {spek.shape}")
|
| 424 |
+
|
| 425 |
+
# Zeroing out the first 3 bins of the spectrum. This is often done to reduce low-frequency noise.
|
| 426 |
+
spek[:, :, :3, :] *= 0
|
| 427 |
+
# self.logger.debug("First 3 bins of the spectrum zeroed out.")
|
| 428 |
+
|
| 429 |
+
# Handling the case where the mix needs to be matched (is_match_mix = True)
|
| 430 |
+
if is_match_mix:
|
| 431 |
+
# self.logger.debug("Match mix mode is enabled. Converting spectrum to NumPy array.")
|
| 432 |
+
spec_pred = spek.cpu().numpy()
|
| 433 |
+
self.logger.debug("is_match_mix: spectrum prediction obtained directly from STFT output.")
|
| 434 |
+
else:
|
| 435 |
+
# If denoising is enabled, the model is run on both the negative and positive spectrums.
|
| 436 |
+
if self.enable_denoise:
|
| 437 |
+
# Assuming spek is a tensor and self.model_run can process it directly
|
| 438 |
+
spec_pred_neg = self.model_run(-spek) # Ensure this line correctly negates spek and runs the model
|
| 439 |
+
spec_pred_pos = self.model_run(spek)
|
| 440 |
+
# Ensure both spec_pred_neg and spec_pred_pos are tensors before applying operations
|
| 441 |
+
spec_pred = (spec_pred_neg * -0.5) + (spec_pred_pos * 0.5) # [invalid-unary-operand-type]
|
| 442 |
+
self.logger.debug("Model run on both negative and positive spectrums for denoising.")
|
| 443 |
+
else:
|
| 444 |
+
spec_pred = self.model_run(spek)
|
| 445 |
+
self.logger.debug("Model run on the spectrum without denoising.")
|
| 446 |
+
|
| 447 |
+
# Applying the inverse STFT to convert the spectrum back to the time domain.
|
| 448 |
+
result = self.stft.inverse(torch.tensor(spec_pred).to(self.torch_device)).cpu().detach().numpy()
|
| 449 |
+
self.logger.debug(f"Inverse STFT applied. Returning result with shape: {result.shape}")
|
| 450 |
+
|
| 451 |
+
return result
|
audio_separator/separator/architectures/mdxc_separator.py
ADDED
|
@@ -0,0 +1,423 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import numpy as np
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
from ml_collections import ConfigDict
|
| 8 |
+
from scipy import signal
|
| 9 |
+
|
| 10 |
+
from audio_separator.separator.common_separator import CommonSeparator
|
| 11 |
+
from audio_separator.separator.uvr_lib_v5 import spec_utils
|
| 12 |
+
from audio_separator.separator.uvr_lib_v5.tfc_tdf_v3 import TFC_TDF_net
|
| 13 |
+
from audio_separator.separator.uvr_lib_v5.roformer.mel_band_roformer import MelBandRoformer
|
| 14 |
+
from audio_separator.separator.uvr_lib_v5.roformer.bs_roformer import BSRoformer
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class MDXCSeparator(CommonSeparator):
    """
    Separator for MDXC-architecture models (TFC-TDF v3 and the Roformer family).

    Builds on CommonSeparator for shared configuration and output handling, and
    adds MDXC-specific inference: chunked overlap-add demixing, optional pitch
    shifting during processing, and single- or multi-stem output.
    """

    def __init__(self, common_config, arch_config):
        # Configuration shared across architectures (output_single_stem,
        # primary_stem_name, logger, thresholds, ...) is set by CommonSeparator.
        super().__init__(config=common_config)

        # self.model_data is the overview metadata for the chosen model
        # (primary stem, karaoke flag, ...), loaded from model_data_new.json in
        # Separator.load_model and passed through via CommonSeparator.
        self.logger.debug(f"Model data: {self.model_data}")

        # MDXC-specific user options, settable via the Separator class or CLI.
        # Names overlap with other architectures (e.g. batch_size) but defaults
        # are deliberately tuned per architecture.
        self.segment_size = arch_config.get("segment_size", 256)

        # When True, self.segment_size above is used; when False, the segment
        # size from the model's own yaml config (inference.dim_t) is used.
        self.override_model_segment_size = arch_config.get("override_model_segment_size", False)

        self.overlap = arch_config.get("overlap", 8)
        self.batch_size = arch_config.get("batch_size", 1)

        # Semitone shift applied to the input before inference and undone
        # afterwards, so output pitch is unchanged:
        # - whole numbers are semitones;
        # - raising pitch may cut upper bandwidth but can help deep vocals;
        # - lowering pitch is slower but can help high-pitched vocals.
        self.pitch_shift = arch_config.get("pitch_shift", 0)

        self.process_all_stems = arch_config.get("process_all_stems", True)

        self.logger.debug(f"MDXC arch params: batch_size={self.batch_size}, segment_size={self.segment_size}, overlap={self.overlap}")
        self.logger.debug(f"MDXC arch params: override_model_segment_size={self.override_model_segment_size}, pitch_shift={self.pitch_shift}")
        self.logger.debug(f"MDXC multi-stem params: process_all_stems={self.process_all_stems}")

        self.is_roformer = "is_roformer" in self.model_data

        # load_model() also sets self.model_data_cfgdict, used below.
        self.load_model()

        self.primary_source = None
        self.secondary_source = None
        self.audio_file_path = None
        self.audio_file_base = None

        # True when the model targets Vocals directly or separates more than
        # one instrument; demix() then derives the secondary stem by
        # subtracting the primary stem from the original mix.
        training_cfg = self.model_data_cfgdict.training
        self.is_primary_stem_main_target = bool(training_cfg.target_instrument == "Vocals" or len(training_cfg.instruments) > 1)

        self.logger.debug(f"is_primary_stem_main_target: {self.is_primary_stem_main_target}")

        self.logger.info("MDXC Separator initialisation complete")

    def load_model(self):
        """
        Load the model checkpoint from disk, construct the network described by
        the model's config, and move it to the Torch device ready for inference.

        Sets self.model_data_cfgdict and self.model_run. Exits the process if
        the checkpoint file cannot be loaded (typically a corrupt download).
        """
        self.logger.debug("Loading checkpoint model for inference...")

        self.model_data_cfgdict = ConfigDict(self.model_data)

        try:
            if self.is_roformer:
                self.logger.debug("Loading Roformer model...")

                # The model section of the config distinguishes the two
                # Roformer variants by which band parameter it carries.
                if "num_bands" in self.model_data_cfgdict.model:
                    self.logger.debug("Loading MelBandRoformer model...")
                    model = MelBandRoformer(**self.model_data_cfgdict.model)
                elif "freqs_per_bands" in self.model_data_cfgdict.model:
                    self.logger.debug("Loading BSRoformer model...")
                    model = BSRoformer(**self.model_data_cfgdict.model)
                else:
                    raise ValueError("Unknown Roformer model type in the configuration.")

                # Restore trained weights, then switch to eval mode on the
                # target device.
                checkpoint = torch.load(self.model_path, map_location="cpu", weights_only=True)
                self.model_run = model.module if isinstance(model, torch.nn.DataParallel) else model
                self.model_run.load_state_dict(checkpoint)
                self.model_run.to(self.torch_device).eval()
            else:
                self.logger.debug("Loading TFC_TDF_net model...")
                self.model_run = TFC_TDF_net(self.model_data_cfgdict, device=self.torch_device)
                self.logger.debug("Loading model onto cpu")
                # Loading the state dict directly onto an accelerated device
                # causes issues, so load on CPU first and move afterwards.
                self.model_run.load_state_dict(torch.load(self.model_path, map_location="cpu"))
                self.model_run.to(self.torch_device).eval()
        except RuntimeError as e:
            self.logger.error(f"Error: {e}")
            self.logger.error("An error occurred while loading the model file. This often occurs when the model file is corrupt or incomplete.")
            self.logger.error(f"Please try deleting the model file from {self.model_path} and run audio-separator again to re-download it.")
            sys.exit(1)

    def separate(self, audio_file_path, custom_output_names=None):
        """
        Separates the audio file into primary and secondary sources based on the model's configuration.
        It processes the mix, demixes it into sources, normalizes the sources, and saves the output files.

        Args:
            audio_file_path (str): The path to the audio file to be processed.
            custom_output_names (dict, optional): Custom names for the output files. Defaults to None.

        Returns:
            list: A list of paths to the output files generated by the separation process.
        """
        self.primary_source = None
        self.secondary_source = None

        self.audio_file_path = audio_file_path
        self.audio_file_base = os.path.splitext(os.path.basename(audio_file_path))[0]

        self.logger.debug(f"Preparing mix for input audio file {self.audio_file_path}...")
        mix = self.prepare_mix(self.audio_file_path)

        self.logger.debug("Normalizing mix before demixing...")
        mix = spec_utils.normalize(wave=mix, max_peak=self.normalization_threshold, min_peak=self.amplification_threshold)

        source = self.demix(mix=mix)
        self.logger.debug("Demixing completed.")

        output_files = []
        self.logger.debug("Processing output files...")

        if isinstance(source, dict):
            self.logger.debug("Source is a dict, processing each stem...")

            # A configured target instrument means the model emits one stem;
            # otherwise every instrument from the training config is available.
            if self.model_data_cfgdict.training.target_instrument:
                stem_list = [self.model_data_cfgdict.training.target_instrument]
            else:
                stem_list = self.model_data_cfgdict.training.instruments

            self.logger.debug(f"Available stems: {stem_list}")

            is_multi_stem_model = len(stem_list) > 2
            should_process_all_stems = self.process_all_stems and is_multi_stem_model

            if should_process_all_stems:
                self.logger.debug("Processing all stems from multi-stem model...")
                for stem_name in stem_list:
                    stem_output_path = self.get_stem_output_path(stem_name, custom_output_names)
                    stem_source = spec_utils.normalize(wave=source[stem_name], max_peak=self.normalization_threshold, min_peak=self.amplification_threshold).T

                    self.logger.info(f"Saving {stem_name} stem to {stem_output_path}...")
                    self.final_process(stem_output_path, stem_source, stem_name)
                    output_files.append(stem_output_path)
            else:
                # Two-stem flow: normalize both sources, then write whichever
                # stems the user requested (both when output_single_stem is unset).
                if not isinstance(self.primary_source, np.ndarray):
                    self.logger.debug(f"Normalizing primary source for primary stem {self.primary_stem_name}...")
                    self.primary_source = spec_utils.normalize(wave=source[self.primary_stem_name], max_peak=self.normalization_threshold, min_peak=self.amplification_threshold).T

                if not isinstance(self.secondary_source, np.ndarray):
                    self.logger.debug(f"Normalizing secondary source for secondary stem {self.secondary_stem_name}...")
                    self.secondary_source = spec_utils.normalize(wave=source[self.secondary_stem_name], max_peak=self.normalization_threshold, min_peak=self.amplification_threshold).T

                if not self.output_single_stem or self.output_single_stem.lower() == self.secondary_stem_name.lower():
                    self.secondary_stem_output_path = self.get_stem_output_path(self.secondary_stem_name, custom_output_names)

                    self.logger.info(f"Saving {self.secondary_stem_name} stem to {self.secondary_stem_output_path}...")
                    self.final_process(self.secondary_stem_output_path, self.secondary_source, self.secondary_stem_name)
                    output_files.append(self.secondary_stem_output_path)

                if not self.output_single_stem or self.output_single_stem.lower() == self.primary_stem_name.lower():
                    self.primary_stem_output_path = self.get_stem_output_path(self.primary_stem_name, custom_output_names)

                    self.logger.info(f"Saving {self.primary_stem_name} stem to {self.primary_stem_output_path}...")
                    self.final_process(self.primary_stem_output_path, self.primary_source, self.primary_stem_name)
                    output_files.append(self.primary_stem_output_path)
        else:
            # Single-array result (single-source model): only the primary stem exists.
            if not self.output_single_stem or self.output_single_stem.lower() == self.primary_stem_name.lower():
                self.primary_stem_output_path = self.get_stem_output_path(self.primary_stem_name, custom_output_names)

                if not isinstance(self.primary_source, np.ndarray):
                    self.primary_source = source.T

                self.logger.info(f"Saving {self.primary_stem_name} stem to {self.primary_stem_output_path}...")
                self.final_process(self.primary_stem_output_path, self.primary_source, self.primary_stem_name)
                output_files.append(self.primary_stem_output_path)

        return output_files

    def pitch_fix(self, source, sr_pitched, orig_mix):
        """
        Undo the pitch shift applied before inference.

        Args:
            source (np.ndarray): The source audio to be pitch-shifted.
            sr_pitched (int): The sample rate of the pitch-shifted audio.
            orig_mix (np.ndarray): The original mix, used to match the shape of the pitch-shifted audio.

        Returns:
            np.ndarray: The pitch-shifted source audio, shape-matched to orig_mix.
        """
        shifted = spec_utils.change_pitch_semitones(source, sr_pitched, semitone_shift=self.pitch_shift)[0]
        return spec_utils.match_array_shapes(shifted, orig_mix)

    def overlap_add(self, result, x, weights, start, length):
        """
        Accumulate a windowed chunk `x` into `result` at offset `start`,
        weighting the first `length` samples by `weights`. Returns `result`.
        """
        result[..., start : start + length] += x[..., :length] * weights[:length]
        return result

    def demix(self, mix: np.ndarray) -> dict:
        """
        Run the model over the mix in overlapping chunks and reassemble the
        separated sources.

        Args:
            mix (np.ndarray): The mix to be demixed.
        Returns:
            dict: A dictionary containing the demixed sources.
        """
        orig_mix = mix

        # Shift the input down by pitch_shift semitones; pitch_fix() restores
        # the original pitch after inference. sample_rate is only defined on
        # this path, which is fine because it is only read when pitch_shift != 0.
        if self.pitch_shift != 0:
            self.logger.debug(f"Shifting pitch by -{self.pitch_shift} semitones...")
            mix, sample_rate = spec_utils.change_pitch_semitones(mix, self.sample_rate, semitone_shift=-self.pitch_shift)

        if self.is_roformer:
            # Note: Currently, for Roformer models, `batch_size` is not utilized due to negligible performance improvements.

            mix = torch.tensor(mix, dtype=torch.float32)

            if self.override_model_segment_size:
                mdx_segment_size = self.segment_size
                self.logger.debug(f"Using configured segment size: {mdx_segment_size}")
            else:
                mdx_segment_size = self.model_data_cfgdict.inference.dim_t
                self.logger.debug(f"Using model default segment size: {mdx_segment_size}")

            # num_stems aka "S" in UVR.
            num_stems = 1 if self.model_data_cfgdict.training.target_instrument else len(self.model_data_cfgdict.training.instruments)
            self.logger.debug(f"Number of stems: {num_stems}")

            # chunk_size aka "C" in UVR.
            chunk_size = self.model_data_cfgdict.audio.hop_length * (mdx_segment_size - 1)
            self.logger.debug(f"Chunk size: {chunk_size}")

            # Stride between chunk starts, in samples.
            step = int(self.overlap * self.model_data_cfgdict.audio.sample_rate)
            self.logger.debug(f"Step: {step}")

            # Hamming window used as overlap-add weights.
            window = torch.tensor(signal.windows.hamming(chunk_size), dtype=torch.float32)

            device = next(self.model_run.parameters()).device

            with torch.no_grad():
                req_shape = (len(self.model_data_cfgdict.training.instruments),) + tuple(mix.shape)
                result = torch.zeros(req_shape, dtype=torch.float32)
                counter = torch.zeros(req_shape, dtype=torch.float32)

                for i in tqdm(range(0, mix.shape[1], step)):
                    part = mix[:, i : i + chunk_size]
                    length = part.shape[-1]
                    # For the tail, re-read a full chunk ending at the last
                    # sample so the model always sees chunk_size samples.
                    if i + chunk_size > mix.shape[1]:
                        part = mix[:, -chunk_size:]
                        length = chunk_size
                    part = part.to(device)
                    x = self.model_run(part.unsqueeze(0))[0]
                    x = x.cpu()
                    # Overlap-add on CPU; the tail chunk is accumulated at the
                    # end of the output tensor rather than at offset i.
                    if i + chunk_size > mix.shape[1]:
                        result = self.overlap_add(result, x, window, result.shape[-1] - chunk_size, length)
                        counter[..., result.shape[-1] - chunk_size :] += window[:length]
                    else:
                        result = self.overlap_add(result, x, window, i, length)
                        counter[..., i : i + length] += window[:length]

                # Normalize by the accumulated window weights (clamped to avoid
                # division by zero at uncovered samples).
                inferenced_outputs = result / counter.clamp(min=1e-10)

        else:
            mix = torch.tensor(mix, dtype=torch.float32)

            # DataParallel wraps the model, so the attribute may live on .module.
            try:
                num_stems = self.model_run.num_target_instruments
            except AttributeError:
                num_stems = self.model_run.module.num_target_instruments
            self.logger.debug(f"Number of stems: {num_stems}")

            if self.override_model_segment_size:
                mdx_segment_size = self.segment_size
                self.logger.debug(f"Using configured segment size: {mdx_segment_size}")
            else:
                mdx_segment_size = self.model_data_cfgdict.inference.dim_t
                self.logger.debug(f"Using model default segment size: {mdx_segment_size}")

            chunk_size = self.model_data_cfgdict.audio.hop_length * (mdx_segment_size - 1)
            self.logger.debug(f"Chunk size: {chunk_size}")

            hop_size = chunk_size // self.overlap
            self.logger.debug(f"Hop size: {hop_size}")

            mix_shape = mix.shape[1]
            pad_size = hop_size - (mix_shape - chunk_size) % hop_size
            self.logger.debug(f"Pad size: {pad_size}")

            # Zero-pad both ends so every sample is covered by `overlap` chunks.
            mix = torch.cat([torch.zeros(2, chunk_size - hop_size), mix, torch.zeros(2, pad_size + chunk_size - hop_size)], 1)
            self.logger.debug(f"Mix shape: {mix.shape}")

            chunks = mix.unfold(1, chunk_size, hop_size).transpose(0, 1)
            self.logger.debug(f"Chunks length: {len(chunks)} and shape: {chunks.shape}")

            batches = [chunks[i : i + self.batch_size] for i in range(0, len(chunks), self.batch_size)]
            self.logger.debug(f"Batch size: {self.batch_size}, number of batches: {len(batches)}")

            # accumulated_outputs collects the model output of every chunk,
            # added in place at each chunk's offset; after trimming the padding
            # and dividing by the overlap count it holds the separated sources.
            accumulated_outputs = torch.zeros(num_stems, *mix.shape) if num_stems > 1 else torch.zeros_like(mix)

            with torch.no_grad():
                count = 0
                for batch in tqdm(batches):
                    # single_batch_result holds the model's output for this
                    # batch before it is folded into accumulated_outputs.
                    single_batch_result = self.model_run(batch.to(self.torch_device))

                    # Accumulate each chunk's output tensor (on CPU) at the
                    # offset corresponding to its position in the padded mix.
                    for individual_output in single_batch_result:
                        individual_output_cpu = individual_output.cpu()
                        accumulated_outputs[..., count * hop_size : count * hop_size + chunk_size] += individual_output_cpu
                        count += 1

            self.logger.debug("Calculating inferenced outputs based on accumulated outputs and overlap")
            inferenced_outputs = accumulated_outputs[..., chunk_size - hop_size : -(pad_size + chunk_size - hop_size)] / self.overlap
            self.logger.debug("Deleting accumulated outputs to free up memory")
            del accumulated_outputs

        if num_stems > 1 or self.is_primary_stem_main_target:
            self.logger.debug("Number of stems is greater than 1 or vocals are main target, detaching individual sources and correcting pitch if necessary...")

            sources = {}

            # Pair each instrument from the training config with its separated
            # source (converted to NumPy), undoing the pitch shift if one was
            # applied before inference.
            for key, value in zip(self.model_data_cfgdict.training.instruments, inferenced_outputs.cpu().detach().numpy()):
                self.logger.debug(f"Processing instrument: {key}")
                if self.pitch_shift != 0:
                    self.logger.debug(f"Applying pitch correction for {key}")
                    sources[key] = self.pitch_fix(value, sample_rate, orig_mix)
                else:
                    sources[key] = value

            if self.is_primary_stem_main_target:
                self.logger.debug(f"Primary stem: {self.primary_stem_name} is main target, detaching and matching array shapes if necessary...")
                if sources[self.primary_stem_name].shape[1] != orig_mix.shape[1]:
                    sources[self.primary_stem_name] = spec_utils.match_array_shapes(sources[self.primary_stem_name], orig_mix)
                # Secondary stem is the residual of the mix minus the primary.
                sources[self.secondary_stem_name] = orig_mix - sources[self.primary_stem_name]

            self.logger.debug("Deleting inferenced outputs to free up memory")
            del inferenced_outputs

            self.logger.debug("Returning separated sources")
            return sources

        else:
            self.logger.debug("Processing single source...")

            if self.is_roformer:
                sources = {k: v.cpu().detach().numpy() for k, v in zip([self.model_data_cfgdict.training.target_instrument], inferenced_outputs)}
                inferenced_output = sources[self.model_data_cfgdict.training.target_instrument]
            else:
                inferenced_output = inferenced_outputs.cpu().detach().numpy()

            self.logger.debug("Demix process completed for single source.")

            self.logger.debug("Deleting inferenced outputs to free up memory")
            del inferenced_outputs

            if self.pitch_shift != 0:
                self.logger.debug("Applying pitch correction for single instrument")
                return self.pitch_fix(inferenced_output, sample_rate, orig_mix)
            else:
                self.logger.debug("Returning inferenced output for single instrument")
                return inferenced_output
|
audio_separator/separator/architectures/vr_separator.py
ADDED
|
@@ -0,0 +1,357 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Module for separating audio sources using VR architecture models."""
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import math
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import librosa
|
| 8 |
+
import numpy as np
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
+
|
| 11 |
+
# Check if we really need the rerun_mp3 function, remove if not
|
| 12 |
+
import audioread
|
| 13 |
+
|
| 14 |
+
from audio_separator.separator.common_separator import CommonSeparator
|
| 15 |
+
from audio_separator.separator.uvr_lib_v5 import spec_utils
|
| 16 |
+
from audio_separator.separator.uvr_lib_v5.vr_network import nets
|
| 17 |
+
from audio_separator.separator.uvr_lib_v5.vr_network import nets_new
|
| 18 |
+
from audio_separator.separator.uvr_lib_v5.vr_network.model_param_init import ModelParameters
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class VRSeparator(CommonSeparator):
    """
    VRSeparator is responsible for separating audio sources using VR models.
    It initializes with configuration parameters and prepares the model for separation tasks.
    """

    def __init__(self, common_config, arch_config: dict):
        # Configuration shared across architectures (output_single_stem,
        # primary_stem_name, logger, ...) is handled by CommonSeparator.
        super().__init__(config=common_config)

        # self.model_data is the overview metadata for the chosen model
        # (primary stem, karaoke flag, ...), loaded from model_data_new.json in
        # Separator.load_model and passed through via CommonSeparator.
        self.logger.debug(f"Model data: {self.model_data}")

        # Most VR models share the same output channel counts; VR 5.1 models
        # carry explicit values in their model_data JSON.
        self.model_capacity = 32, 128
        self.is_vr_51_model = False

        if "nout" in self.model_data and "nout_lstm" in self.model_data:
            self.model_capacity = self.model_data["nout"], self.model_data["nout_lstm"]
            self.is_vr_51_model = True

        # Technical model parameters live in JSON files under
        # separator/uvr_lib_v5/vr_network/modelparams/, selected by the
        # model_data["vr_model_param"] value.
        package_root_filepath = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        vr_params_json_dir = os.path.join(package_root_filepath, "uvr_lib_v5", "vr_network", "modelparams")
        vr_params_json_filename = f"{self.model_data['vr_model_param']}.json"
        vr_params_json_filepath = os.path.join(vr_params_json_dir, vr_params_json_filename)
        self.model_params = ModelParameters(vr_params_json_filepath)

        self.logger.debug(f"Model params: {self.model_params.param}")

        # VR-specific user options, settable via the Separator class or CLI.
        # Defaults are deliberately architecture-specific even where names
        # overlap with other architectures (e.g. batch_size).

        # Test-Time-Augmentation: improves separation quality at the cost of a
        # slower conversion.
        self.enable_tta = arch_config.get("enable_tta", False)

        # Attempts to remove leftover instrumental artifacts from vocal
        # outputs; can hurt some tracks, so treat it as a last resort.
        self.enable_post_process = arch_config.get("enable_post_process", False)

        # Typical post_process_threshold values: 0.1, 0.2, 0.3.
        self.post_process_threshold = arch_config.get("post_process_threshold", 0.2)

        # Batches processed at a time: higher = more RAM, slightly faster;
        # lower = less RAM, slightly slower. No effect on output quality.
        # Andrew note: for some reason, lower batch sizes seem to cause broken
        # output for VR arch; need to investigate why.
        self.batch_size = arch_config.get("batch_size", 1)

        # STFT window size trades speed for quality:
        # 1024 = fast/lower quality, 512 = medium, 320 = slow/possibly better.
        self.window_size = arch_config.get("window_size", 512)

        # Mirror the missing upper frequency range of the output when enabled.
        self.high_end_process = arch_config.get("high_end_process", False)
        self.input_high_end_h = None
        self.input_high_end = None

        # Primary stem extraction intensity, taken in -100..100 and scaled to
        # a -1.0..1.0 fraction; 5 is typical for vocals/instrumentals, larger
        # values can muddy non-vocal models.
        self.aggression = float(int(arch_config.get("aggression", 5)) / 100)

        self.aggressiveness = {"value": self.aggression, "split_bin": self.model_params.param["band"][1]["crop_stop"], "aggr_correction": self.model_params.param.get("aggr_correction")}

        self.model_samplerate = self.model_params.param["sr"]

        self.logger.debug(f"VR arch params: enable_tta={self.enable_tta}, enable_post_process={self.enable_post_process}, post_process_threshold={self.post_process_threshold}")
        self.logger.debug(f"VR arch params: batch_size={self.batch_size}, window_size={self.window_size}")
        self.logger.debug(f"VR arch params: high_end_process={self.high_end_process}, aggression={self.aggression}")
        self.logger.debug(f"VR arch params: is_vr_51_model={self.is_vr_51_model}, model_samplerate={self.model_samplerate}, model_capacity={self.model_capacity}")

        # Placeholder until separate() builds the real network; calling it
        # before then just logs an error.
        self.model_run = lambda *args, **kwargs: self.logger.error("Model run method is not initialised yet.")

        # This should go away once we refactor to remove soundfile.write and
        # replace with pydub like we did for the MDX rewrite.
        self.wav_subtype = "PCM_16"

        self.logger.info("VR Separator initialisation complete")
|
| 113 |
+
def separate(self, audio_file_path, custom_output_names=None):
    """
    Separates the audio file into primary and secondary sources based on the model's configuration.
    It processes the mix, demixes it into sources, normalizes the sources, and saves the output files.

    Args:
        audio_file_path (str): The path to the audio file to be processed.
        custom_output_names (dict, optional): Custom names for the output files. Defaults to None.

    Returns:
        list: A list of paths to the output files generated by the separation process.
    """
    # Reset per-file state so a previous run's sources are never reused
    self.primary_source = None
    self.secondary_source = None

    self.audio_file_path = audio_file_path
    self.audio_file_base = os.path.splitext(os.path.basename(audio_file_path))[0]

    self.logger.debug(f"Starting separation for input audio file {self.audio_file_path}...")

    # Known on-disk sizes (KiB) of supported VR network weights; the closest
    # match to this model file's size selects the network architecture.
    nn_arch_sizes = [31191, 33966, 56817, 123821, 123812, 129605, 218409, 537238, 537227]  # default
    vr_5_1_models = [56817, 218409]
    model_size = math.ceil(os.stat(self.model_path).st_size / 1024)
    nn_arch_size = min(nn_arch_sizes, key=lambda x: abs(x - model_size))
    self.logger.debug(f"Model size determined: {model_size}, NN architecture size: {nn_arch_size}")

    if nn_arch_size in vr_5_1_models or self.is_vr_51_model:
        self.logger.debug("Using CascadedNet for VR 5.1 model...")
        self.model_run = nets_new.CascadedNet(self.model_params.param["bins"] * 2, nn_arch_size, nout=self.model_capacity[0], nout_lstm=self.model_capacity[1])
        self.is_vr_51_model = True
    else:
        self.logger.debug("Determining model capacity...")
        self.model_run = nets.determine_model_capacity(self.model_params.param["bins"] * 2, nn_arch_size)

    # Weights are loaded on CPU first, then the model is moved to the target device
    self.model_run.load_state_dict(torch.load(self.model_path, map_location="cpu"))
    self.model_run.to(self.torch_device)
    self.logger.debug("Model loaded and moved to device.")

    y_spec, v_spec = self.inference_vr(self.loading_mix(), self.torch_device, self.aggressiveness)
    self.logger.debug("Inference completed.")

    # Sanitize y_spec and v_spec to replace NaN and infinite values
    y_spec = np.nan_to_num(y_spec, nan=0.0, posinf=0.0, neginf=0.0)
    v_spec = np.nan_to_num(v_spec, nan=0.0, posinf=0.0, neginf=0.0)

    self.logger.debug("Sanitization completed. Replaced NaN and infinite values in y_spec and v_spec.")

    # After inference_vr call
    self.logger.debug(f"Inference VR completed. y_spec shape: {y_spec.shape}, v_spec shape: {v_spec.shape}")
    self.logger.debug(f"y_spec stats - min: {np.min(y_spec)}, max: {np.max(y_spec)}, isnan: {np.isnan(y_spec).any()}, isinf: {np.isinf(y_spec).any()}")
    self.logger.debug(f"v_spec stats - min: {np.min(v_spec)}, max: {np.max(v_spec)}, isnan: {np.isnan(v_spec).any()}, isinf: {np.isinf(v_spec).any()}")

    # Not yet implemented from UVR features:
    #
    # if not self.is_vocal_split_model:
    #     self.cache_source((y_spec, v_spec))

    # if self.is_secondary_model_activated and self.secondary_model:
    #     self.logger.debug("Processing secondary model...")
    #     self.secondary_source_primary, self.secondary_source_secondary = process_secondary_model(
    #         self.secondary_model, self.process_data, main_process_method=self.process_method, main_model_primary=self.primary_stem
    #     )

    # Initialize the list for output files
    output_files = []
    self.logger.debug("Processing output files...")

    # Note: logic similar to the following should probably be added to the other architectures
    # Check if output_single_stem is set to a value that would result in no output files
    if self.output_single_stem and (self.output_single_stem.lower() != self.primary_stem_name.lower() and self.output_single_stem.lower() != self.secondary_stem_name.lower()):
        # BUGFIX: log the offending value BEFORE clearing it, so the warning
        # shows what the user actually configured rather than 'None'.
        self.logger.warning(f"The output_single_stem setting '{self.output_single_stem}' does not match any of the output files: '{self.primary_stem_name}' and '{self.secondary_stem_name}'. For this model '{self.model_name}', the output_single_stem setting will be ignored and all output files will be saved.")
        # Reset output_single_stem to None to save both stems
        self.output_single_stem = None

    # Save and process the primary stem if needed
    if not self.output_single_stem or self.output_single_stem.lower() == self.primary_stem_name.lower():
        self.logger.debug(f"Processing primary stem: {self.primary_stem_name}")
        if not isinstance(self.primary_source, np.ndarray):
            self.logger.debug(f"Preparing to convert spectrogram to waveform. Spec shape: {y_spec.shape}")

            self.primary_source = self.spec_to_wav(y_spec).T
            self.logger.debug("Converting primary source spectrogram to waveform.")
            if not self.model_samplerate == 44100:
                # Outputs are standardised to 44.1 kHz regardless of the model's native rate
                self.primary_source = librosa.resample(self.primary_source.T, orig_sr=self.model_samplerate, target_sr=44100).T
                self.logger.debug("Resampling primary source to 44100Hz.")

        self.primary_stem_output_path = self.get_stem_output_path(self.primary_stem_name, custom_output_names)

        self.logger.info(f"Saving {self.primary_stem_name} stem to {self.primary_stem_output_path}...")
        self.final_process(self.primary_stem_output_path, self.primary_source, self.primary_stem_name)
        output_files.append(self.primary_stem_output_path)

    # Save and process the secondary stem if needed
    if not self.output_single_stem or self.output_single_stem.lower() == self.secondary_stem_name.lower():
        self.logger.debug(f"Processing secondary stem: {self.secondary_stem_name}")
        if not isinstance(self.secondary_source, np.ndarray):
            self.logger.debug(f"Preparing to convert spectrogram to waveform. Spec shape: {v_spec.shape}")

            self.secondary_source = self.spec_to_wav(v_spec).T
            self.logger.debug("Converting secondary source spectrogram to waveform.")
            if not self.model_samplerate == 44100:
                self.secondary_source = librosa.resample(self.secondary_source.T, orig_sr=self.model_samplerate, target_sr=44100).T
                self.logger.debug("Resampling secondary source to 44100Hz.")

        self.secondary_stem_output_path = self.get_stem_output_path(self.secondary_stem_name, custom_output_names)

        self.logger.info(f"Saving {self.secondary_stem_name} stem to {self.secondary_stem_output_path}...")
        self.final_process(self.secondary_stem_output_path, self.secondary_source, self.secondary_stem_name)
        output_files.append(self.secondary_stem_output_path)

    # Not yet implemented from UVR features:
    # self.process_vocal_split_chain(secondary_sources)
    # self.logger.debug("Vocal split chain processed.")

    return output_files
|
| 228 |
+
|
| 229 |
+
def loading_mix(self):
    """Load the input audio and build the multi-band combined spectrogram the VR model expects.

    Iterates the model's frequency bands from the highest (full sample rate,
    loaded from disk) down to the lowest (each produced by resampling the
    band above it), converts each band's waveform to a spectrogram, and
    combines them into a single input spectrogram.

    Returns:
        The combined spectrogram produced by ``spec_utils.combine_spectrograms``.
    """
    X_wave, X_spec_s = {}, {}

    bands_n = len(self.model_params.param["band"])

    audio_file = spec_utils.write_array_to_mem(self.audio_file_path, subtype=self.wav_subtype)
    # MP3 detection only applies when the helper returned a path rather than an in-memory buffer
    is_mp3 = audio_file.endswith(".mp3") if isinstance(audio_file, str) else False

    self.logger.debug(f"loading_mix iteraring through {bands_n} bands")
    # Bands are 1-indexed in the model params; iterate from the top band down,
    # because each lower band is derived from the one above it.
    for d in tqdm(range(bands_n, 0, -1)):
        bp = self.model_params.param["band"][d]

        wav_resolution = bp["res_type"]

        # On Apple MPS devices the configured resampler is overridden with
        # "polyphase" — presumably for compatibility/stability; TODO confirm.
        if self.torch_device_mps is not None:
            wav_resolution = "polyphase"

        if d == bands_n:  # high-end band
            # Top band: load the audio from disk at this band's sample rate
            X_wave[d], _ = librosa.load(audio_file, sr=bp["sr"], mono=False, dtype=np.float32, res_type=wav_resolution)
            X_spec_s[d] = spec_utils.wave_to_spectrogram(X_wave[d], bp["hl"], bp["n_fft"], self.model_params, band=d, is_v51_model=self.is_vr_51_model)

            # Some MP3 decodes come back silent; retry via the audioread-based fallback
            if not np.any(X_wave[d]) and is_mp3:
                X_wave[d] = rerun_mp3(audio_file, bp["sr"])

            # Duplicate a mono channel so downstream code always sees stereo
            if X_wave[d].ndim == 1:
                X_wave[d] = np.asarray([X_wave[d], X_wave[d]])
        else:  # lower bands
            # Derive this band by resampling the band above it down to this band's rate
            X_wave[d] = librosa.resample(X_wave[d + 1], orig_sr=self.model_params.param["band"][d + 1]["sr"], target_sr=bp["sr"], res_type=wav_resolution)
            X_spec_s[d] = spec_utils.wave_to_spectrogram(X_wave[d], bp["hl"], bp["n_fft"], self.model_params, band=d, is_v51_model=self.is_vr_51_model)

        # Capture the frequency range above the model's crop so it can be
        # mirrored back into the output later (see spec_to_wav).
        if d == bands_n and self.high_end_process:
            self.input_high_end_h = (bp["n_fft"] // 2 - bp["crop_stop"]) + (self.model_params.param["pre_filter_stop"] - self.model_params.param["pre_filter_start"])
            self.input_high_end = X_spec_s[d][:, bp["n_fft"] // 2 - self.input_high_end_h : bp["n_fft"] // 2, :]

    X_spec = spec_utils.combine_spectrograms(X_spec_s, self.model_params, is_v51_model=self.is_vr_51_model)

    # Free the per-band intermediates; only the combined spectrogram is needed
    del X_wave, X_spec_s, audio_file

    return X_spec
|
| 268 |
+
|
| 269 |
+
def inference_vr(self, X_spec, device, aggressiveness):
    """Run the VR network over a spectrogram and return the (primary, secondary) spectrograms.

    Args:
        X_spec: Complex input spectrogram (as produced by loading_mix).
        device: Torch device to run inference on.
        aggressiveness (dict): Mask-biasing settings ("value", "split_bin", "aggr_correction").

    Returns:
        Tuple (y_spec, v_spec): the masked primary spectrogram and its complement.
    """
    def _execute(X_mag_pad, roi_size):
        # Slide a window of self.window_size across the padded magnitude
        # spectrogram, predict a mask for each batch of windows, and stitch
        # the per-window masks back together along the time axis.
        X_dataset = []
        patches = (X_mag_pad.shape[2] - 2 * self.model_run.offset) // roi_size

        self.logger.debug(f"inference_vr appending to X_dataset for each of {patches} patches")
        for i in tqdm(range(patches)):
            start = i * roi_size
            X_mag_window = X_mag_pad[:, :, start : start + self.window_size]
            X_dataset.append(X_mag_window)

        total_iterations = patches // self.batch_size if not self.enable_tta else (patches // self.batch_size) * 2
        self.logger.debug(f"inference_vr iterating through {total_iterations} batches, batch_size = {self.batch_size}")

        X_dataset = np.asarray(X_dataset)
        self.model_run.eval()
        with torch.no_grad():
            mask = []

            for i in tqdm(range(0, patches, self.batch_size)):

                X_batch = X_dataset[i : i + self.batch_size]
                X_batch = torch.from_numpy(X_batch).to(device)
                pred = self.model_run.predict_mask(X_batch)
                # An empty time dimension means the window size is incompatible
                # with this model's internal shapes.
                if not pred.size()[3] > 0:
                    raise ValueError(f"Window size error: h1_shape[3] must be greater than h2_shape[3]")
                pred = pred.detach().cpu().numpy()
                # Merge the batch dimension into the time axis
                pred = np.concatenate(pred, axis=2)
                mask.append(pred)
            if len(mask) == 0:
                raise ValueError(f"Window size error: h1_shape[3] must be greater than h2_shape[3]")

            mask = np.concatenate(mask, axis=2)
        return mask

    def postprocess(mask, X_mag, X_phase):
        # Bias the mask per the aggression setting, optionally clean artifacts,
        # then apply it to magnitude+phase to get two complementary spectrograms.
        is_non_accom_stem = False
        for stem in CommonSeparator.NON_ACCOM_STEMS:
            if stem == self.primary_stem_name:
                is_non_accom_stem = True

        mask = spec_utils.adjust_aggr(mask, is_non_accom_stem, aggressiveness)

        if self.enable_post_process:
            mask = spec_utils.merge_artifacts(mask, thres=self.post_process_threshold)

        # Primary stem keeps the masked energy; secondary gets the complement
        y_spec = mask * X_mag * np.exp(1.0j * X_phase)
        v_spec = (1 - mask) * X_mag * np.exp(1.0j * X_phase)

        return y_spec, v_spec

    X_mag, X_phase = spec_utils.preprocess(X_spec)
    n_frame = X_mag.shape[2]
    pad_l, pad_r, roi_size = spec_utils.make_padding(n_frame, self.window_size, self.model_run.offset)
    X_mag_pad = np.pad(X_mag, ((0, 0), (0, 0), (pad_l, pad_r)), mode="constant")
    # Normalise magnitudes to [0, 1] before feeding the network
    X_mag_pad /= X_mag_pad.max()
    mask = _execute(X_mag_pad, roi_size)

    if self.enable_tta:
        # Test-time augmentation: run a second pass shifted by half a window
        # and average the two masks for a smoother result.
        pad_l += roi_size // 2
        pad_r += roi_size // 2
        X_mag_pad = np.pad(X_mag, ((0, 0), (0, 0), (pad_l, pad_r)), mode="constant")
        X_mag_pad /= X_mag_pad.max()
        mask_tta = _execute(X_mag_pad, roi_size)
        # Drop the half-window lead-in so the two passes line up
        mask_tta = mask_tta[:, :, roi_size // 2 :]
        mask = (mask[:, :, :n_frame] + mask_tta[:, :, :n_frame]) * 0.5
    else:
        # Trim the padding back off so the mask matches the input frame count
        mask = mask[:, :, :n_frame]

    y_spec, v_spec = postprocess(mask, X_mag, X_phase)

    return y_spec, v_spec
|
| 341 |
+
|
| 342 |
+
def spec_to_wav(self, spec):
    """Convert a combined spectrogram back into a waveform.

    When high-end processing is active and a stored high-frequency band is
    available, the missing high end is mirrored back in before synthesis.
    """
    use_stored_high_end = self.high_end_process and isinstance(self.input_high_end, np.ndarray) and self.input_high_end_h
    if not use_stored_high_end:
        return spec_utils.cmb_spectrogram_to_wave(spec, self.model_params, is_v51_model=self.is_vr_51_model)

    mirrored_high_end = spec_utils.mirroring("mirroring", spec, self.input_high_end, self.model_params)
    return spec_utils.cmb_spectrogram_to_wave(spec, self.model_params, self.input_high_end_h, mirrored_high_end, is_v51_model=self.is_vr_51_model)
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
# Check if we really need the rerun_mp3 function, refactor or remove if not
def rerun_mp3(audio_file, sample_rate=44100):
    """Reload an MP3 whose first decode produced silence.

    Reads the track duration via audioread, then re-loads the file with
    librosa capped to that duration.
    """
    with audioread.audio_open(audio_file) as decoder:
        whole_seconds = int(decoder.duration)

    audio_data, _ = librosa.load(audio_file, duration=whole_seconds, mono=False, sr=sample_rate)
    return audio_data
|
audio_separator/separator/common_separator.py
ADDED
|
@@ -0,0 +1,403 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
""" This file contains the CommonSeparator class, common to all architecture-specific Separator classes. """
|
| 2 |
+
|
| 3 |
+
from logging import Logger
|
| 4 |
+
import os
|
| 5 |
+
import re
|
| 6 |
+
import gc
|
| 7 |
+
import numpy as np
|
| 8 |
+
import librosa
|
| 9 |
+
import torch
|
| 10 |
+
from pydub import AudioSegment
|
| 11 |
+
import soundfile as sf
|
| 12 |
+
from audio_separator.separator.uvr_lib_v5 import spec_utils
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class CommonSeparator:
    """
    This class contains the common methods and attributes common to all architecture-specific Separator classes.
    """

    # --- Stem name constants shared by all architectures ---
    ALL_STEMS = "All Stems"
    VOCAL_STEM = "Vocals"
    INST_STEM = "Instrumental"
    OTHER_STEM = "Other"
    BASS_STEM = "Bass"
    DRUM_STEM = "Drums"
    GUITAR_STEM = "Guitar"
    PIANO_STEM = "Piano"
    SYNTH_STEM = "Synthesizer"
    STRINGS_STEM = "Strings"
    WOODWINDS_STEM = "Woodwinds"
    BRASS_STEM = "Brass"
    WIND_INST_STEM = "Wind Inst"
    # "No X" variants: everything in the mix except stem X
    NO_OTHER_STEM = "No Other"
    NO_BASS_STEM = "No Bass"
    NO_DRUM_STEM = "No Drums"
    NO_GUITAR_STEM = "No Guitar"
    NO_PIANO_STEM = "No Piano"
    NO_SYNTH_STEM = "No Synthesizer"
    NO_STRINGS_STEM = "No Strings"
    NO_WOODWINDS_STEM = "No Woodwinds"
    NO_WIND_INST_STEM = "No Wind Inst"
    NO_BRASS_STEM = "No Brass"
    PRIMARY_STEM = "Primary Stem"
    SECONDARY_STEM = "Secondary Stem"
    # Lead / backing vocal split identifiers and their display labels
    LEAD_VOCAL_STEM = "lead_only"
    BV_VOCAL_STEM = "backing_only"
    LEAD_VOCAL_STEM_I = "with_lead_vocals"
    BV_VOCAL_STEM_I = "with_backing_vocals"
    LEAD_VOCAL_STEM_LABEL = "Lead Vocals"
    BV_VOCAL_STEM_LABEL = "Backing Vocals"
    # Prefix used to derive a complement stem name (see secondary_stem())
    NO_STEM = "No "

    # Explicit primary -> secondary pairings; anything not listed here falls
    # back to the "No " prefix toggle in secondary_stem().
    STEM_PAIR_MAPPER = {VOCAL_STEM: INST_STEM, INST_STEM: VOCAL_STEM, LEAD_VOCAL_STEM: BV_VOCAL_STEM, BV_VOCAL_STEM: LEAD_VOCAL_STEM, PRIMARY_STEM: SECONDARY_STEM}

    # Stems that are a single instrument/voice rather than an accompaniment mix
    NON_ACCOM_STEMS = (VOCAL_STEM, OTHER_STEM, BASS_STEM, DRUM_STEM, GUITAR_STEM, PIANO_STEM, SYNTH_STEM, STRINGS_STEM, WOODWINDS_STEM, BRASS_STEM, WIND_INST_STEM)
|
| 56 |
+
|
| 57 |
+
def __init__(self, config):
    """Initialise state shared by all architecture-specific separators.

    Args:
        config (dict): Settings assembled by the top-level Separator; keys read
            here include the logger, torch/ONNX device handles, model identity
            and data, output location/format, and common tuning options.
    """

    self.logger: Logger = config.get("logger")
    self.log_level: int = config.get("log_level")

    # Inferencing device / acceleration config
    self.torch_device = config.get("torch_device")
    self.torch_device_cpu = config.get("torch_device_cpu")
    self.torch_device_mps = config.get("torch_device_mps")
    self.onnx_execution_provider = config.get("onnx_execution_provider")

    # Model data
    self.model_name = config.get("model_name")
    self.model_path = config.get("model_path")
    self.model_data = config.get("model_data")

    # Output directory and format
    self.output_dir = config.get("output_dir")
    self.output_format = config.get("output_format")
    self.output_bitrate = config.get("output_bitrate")

    # Functional options which are applicable to all architectures and the user may tweak to affect the output
    self.normalization_threshold = config.get("normalization_threshold")
    self.amplification_threshold = config.get("amplification_threshold")
    self.enable_denoise = config.get("enable_denoise")
    self.output_single_stem = config.get("output_single_stem")
    self.invert_using_spec = config.get("invert_using_spec")
    self.sample_rate = config.get("sample_rate")
    self.use_soundfile = config.get("use_soundfile")

    # Model specific properties

    # Check if model_data has a "training" key with "instruments" list;
    # when present, the first instrument is the primary stem and the second
    # (if any) is the secondary.
    self.primary_stem_name = None
    self.secondary_stem_name = None

    if "training" in self.model_data and "instruments" in self.model_data["training"]:
        instruments = self.model_data["training"]["instruments"]
        if instruments:
            self.primary_stem_name = instruments[0]
            self.secondary_stem_name = instruments[1] if len(instruments) > 1 else self.secondary_stem(self.primary_stem_name)

    # Fall back to the flat "primary_stem" key, defaulting to Vocals
    if self.primary_stem_name is None:
        self.primary_stem_name = self.model_data.get("primary_stem", "Vocals")
        self.secondary_stem_name = self.secondary_stem(self.primary_stem_name)

    self.is_karaoke = self.model_data.get("is_karaoke", False)
    self.is_bv_model = self.model_data.get("is_bv_model", False)
    self.bv_model_rebalance = self.model_data.get("is_bv_model_rebalanced", 0)

    self.logger.debug(f"Common params: model_name={self.model_name}, model_path={self.model_path}")
    self.logger.debug(f"Common params: output_dir={self.output_dir}, output_format={self.output_format}")
    self.logger.debug(f"Common params: normalization_threshold={self.normalization_threshold}, amplification_threshold={self.amplification_threshold}")
    self.logger.debug(f"Common params: enable_denoise={self.enable_denoise}, output_single_stem={self.output_single_stem}")
    self.logger.debug(f"Common params: invert_using_spec={self.invert_using_spec}, sample_rate={self.sample_rate}")

    self.logger.debug(f"Common params: primary_stem_name={self.primary_stem_name}, secondary_stem_name={self.secondary_stem_name}")
    self.logger.debug(f"Common params: is_karaoke={self.is_karaoke}, is_bv_model={self.is_bv_model}, bv_model_rebalance={self.bv_model_rebalance}")

    # File-specific variables which need to be cleared between processing different audio inputs
    self.audio_file_path = None
    self.audio_file_base = None

    self.primary_source = None
    self.secondary_source = None

    self.primary_stem_output_path = None
    self.secondary_stem_output_path = None

    self.cached_sources_map = {}
|
| 127 |
+
|
| 128 |
+
def secondary_stem(self, primary_stem: str):
|
| 129 |
+
"""Determines secondary stem name based on the primary stem name."""
|
| 130 |
+
primary_stem = primary_stem if primary_stem else self.NO_STEM
|
| 131 |
+
|
| 132 |
+
if primary_stem in self.STEM_PAIR_MAPPER:
|
| 133 |
+
secondary_stem = self.STEM_PAIR_MAPPER[primary_stem]
|
| 134 |
+
else:
|
| 135 |
+
secondary_stem = primary_stem.replace(self.NO_STEM, "") if self.NO_STEM in primary_stem else f"{self.NO_STEM}{primary_stem}"
|
| 136 |
+
|
| 137 |
+
return secondary_stem
|
| 138 |
+
|
| 139 |
+
def separate(self, audio_file_path):
|
| 140 |
+
"""
|
| 141 |
+
Placeholder method for separating audio sources. Should be overridden by subclasses.
|
| 142 |
+
"""
|
| 143 |
+
raise NotImplementedError("This method should be overridden by subclasses.")
|
| 144 |
+
|
| 145 |
+
def final_process(self, stem_path, source, stem_name):
|
| 146 |
+
"""
|
| 147 |
+
Finalizes the processing of a stem by writing the audio to a file and returning the processed source.
|
| 148 |
+
"""
|
| 149 |
+
self.logger.debug(f"Finalizing {stem_name} stem processing and writing audio...")
|
| 150 |
+
self.write_audio(stem_path, source)
|
| 151 |
+
|
| 152 |
+
return {stem_name: source}
|
| 153 |
+
|
| 154 |
+
def cached_sources_clear(self):
|
| 155 |
+
"""
|
| 156 |
+
Clears the cache dictionaries for VR, MDX, and Demucs models.
|
| 157 |
+
|
| 158 |
+
This function is essential for ensuring that the cache does not hold outdated or irrelevant data
|
| 159 |
+
between different processing sessions or when a new batch of audio files is processed.
|
| 160 |
+
It helps in managing memory efficiently and prevents potential errors due to stale data.
|
| 161 |
+
"""
|
| 162 |
+
self.cached_sources_map = {}
|
| 163 |
+
|
| 164 |
+
def cached_source_callback(self, model_architecture, model_name=None):
|
| 165 |
+
"""
|
| 166 |
+
Retrieves the model and sources from the cache based on the processing method and model name.
|
| 167 |
+
|
| 168 |
+
Args:
|
| 169 |
+
model_architecture: The architecture type (VR, MDX, or Demucs) being used for processing.
|
| 170 |
+
model_name: The specific model name within the architecture type, if applicable.
|
| 171 |
+
|
| 172 |
+
Returns:
|
| 173 |
+
A tuple containing the model and its sources if found in the cache; otherwise, None.
|
| 174 |
+
|
| 175 |
+
This function is crucial for optimizing performance by avoiding redundant processing.
|
| 176 |
+
If the requested model and its sources are already in the cache, they can be reused directly,
|
| 177 |
+
saving time and computational resources.
|
| 178 |
+
"""
|
| 179 |
+
model, sources = None, None
|
| 180 |
+
|
| 181 |
+
mapper = self.cached_sources_map[model_architecture]
|
| 182 |
+
|
| 183 |
+
for key, value in mapper.items():
|
| 184 |
+
if model_name in key:
|
| 185 |
+
model = key
|
| 186 |
+
sources = value
|
| 187 |
+
|
| 188 |
+
return model, sources
|
| 189 |
+
|
| 190 |
+
def cached_model_source_holder(self, model_architecture, sources, model_name=None):
|
| 191 |
+
"""
|
| 192 |
+
Update the dictionary for the given model_architecture with the new model name and its sources.
|
| 193 |
+
Use the model_architecture as a key to access the corresponding cache source mapper dictionary.
|
| 194 |
+
"""
|
| 195 |
+
self.cached_sources_map[model_architecture] = {**self.cached_sources_map.get(model_architecture, {}), **{model_name: sources}}
|
| 196 |
+
|
| 197 |
+
def prepare_mix(self, mix):
|
| 198 |
+
"""
|
| 199 |
+
Prepares the mix for processing. This includes loading the audio from a file if necessary,
|
| 200 |
+
ensuring the mix is in the correct format, and converting mono to stereo if needed.
|
| 201 |
+
"""
|
| 202 |
+
# Store the original path or the mix itself for later checks
|
| 203 |
+
audio_path = mix
|
| 204 |
+
|
| 205 |
+
# Check if the input is a file path (string) and needs to be loaded
|
| 206 |
+
if not isinstance(mix, np.ndarray):
|
| 207 |
+
self.logger.debug(f"Loading audio from file: {mix}")
|
| 208 |
+
mix, sr = librosa.load(mix, mono=False, sr=self.sample_rate)
|
| 209 |
+
self.logger.debug(f"Audio loaded. Sample rate: {sr}, Audio shape: {mix.shape}")
|
| 210 |
+
else:
|
| 211 |
+
# Transpose the mix if it's already an ndarray (expected shape: [channels, samples])
|
| 212 |
+
self.logger.debug("Transposing the provided mix array.")
|
| 213 |
+
mix = mix.T
|
| 214 |
+
self.logger.debug(f"Transposed mix shape: {mix.shape}")
|
| 215 |
+
|
| 216 |
+
# If the original input was a filepath, check if the loaded mix is empty
|
| 217 |
+
if isinstance(audio_path, str):
|
| 218 |
+
if not np.any(mix):
|
| 219 |
+
error_msg = f"Audio file {audio_path} is empty or not valid"
|
| 220 |
+
self.logger.error(error_msg)
|
| 221 |
+
raise ValueError(error_msg)
|
| 222 |
+
else:
|
| 223 |
+
self.logger.debug("Audio file is valid and contains data.")
|
| 224 |
+
|
| 225 |
+
# Ensure the mix is in stereo format
|
| 226 |
+
if mix.ndim == 1:
|
| 227 |
+
self.logger.debug("Mix is mono. Converting to stereo.")
|
| 228 |
+
mix = np.asfortranarray([mix, mix])
|
| 229 |
+
self.logger.debug("Converted to stereo mix.")
|
| 230 |
+
|
| 231 |
+
# Final log indicating successful preparation of the mix
|
| 232 |
+
self.logger.debug("Mix preparation completed.")
|
| 233 |
+
return mix
|
| 234 |
+
|
| 235 |
+
def write_audio(self, stem_path: str, stem_source):
    """
    Writes the separated audio source to a file using pydub or soundfile
    Pydub supports a much wider range of audio formats and produces better encoded lossy files for some formats.
    Soundfile is used for very large files (longer than 1 hour), as pydub has memory issues with large files:
    https://github.com/jiaaro/pydub/issues/135
    """
    # Report how long the input is; long inputs are the reason soundfile support exists
    duration_seconds = librosa.get_duration(filename=self.audio_file_path)
    duration_hours = duration_seconds / 3600
    self.logger.info(f"Audio duration is {duration_hours:.2f} hours ({duration_seconds:.2f} seconds).")

    # Dispatch to the configured writer backend
    if self.use_soundfile:
        self.logger.warning(f"Using soundfile for writing.")
        writer = self.write_audio_soundfile
    else:
        self.logger.info(f"Using pydub for writing.")
        writer = self.write_audio_pydub
    writer(stem_path, stem_source)
|
| 253 |
+
|
| 254 |
+
def write_audio_pydub(self, stem_path: str, stem_source):
    """
    Writes the separated audio source to a file using pydub (ffmpeg)

    Normalizes, converts to int16, interleaves the stereo channels, and
    exports via an AudioSegment in the format implied by the file extension.
    NOTE(review): assumes stem_source is a 2-column (stereo) array — a mono
    input would fail at the channel-indexing step below; confirm callers.
    """
    self.logger.debug(f"Entering write_audio_pydub with stem_path: {stem_path}")

    stem_source = spec_utils.normalize(wave=stem_source, max_peak=self.normalization_threshold, min_peak=self.amplification_threshold)

    # Check if the numpy array is empty or contains very low values;
    # skip writing entirely in that case rather than emitting a silent file
    if np.max(np.abs(stem_source)) < 1e-6:
        self.logger.warning("Warning: stem_source array is near-silent or empty.")
        return

    # If output_dir is specified, create it and join it with stem_path
    if self.output_dir:
        os.makedirs(self.output_dir, exist_ok=True)
        stem_path = os.path.join(self.output_dir, stem_path)

    self.logger.debug(f"Audio data shape before processing: {stem_source.shape}")
    self.logger.debug(f"Data type before conversion: {stem_source.dtype}")

    # Ensure the audio data is in the correct format (e.g., int16);
    # floats are assumed to be in [-1, 1] and scaled to the int16 range
    if stem_source.dtype != np.int16:
        stem_source = (stem_source * 32767).astype(np.int16)
        self.logger.debug("Converted stem_source to int16.")

    # Correctly interleave stereo channels (L R L R ...) as raw PCM expects
    stem_source_interleaved = np.empty((2 * stem_source.shape[0],), dtype=np.int16)
    stem_source_interleaved[0::2] = stem_source[:, 0]  # Left channel
    stem_source_interleaved[1::2] = stem_source[:, 1]  # Right channel

    self.logger.debug(f"Interleaved audio data shape: {stem_source_interleaved.shape}")

    # Create a pydub AudioSegment from the raw interleaved PCM bytes
    try:
        audio_segment = AudioSegment(stem_source_interleaved.tobytes(), frame_rate=self.sample_rate, sample_width=stem_source.dtype.itemsize, channels=2)
        self.logger.debug("Created AudioSegment successfully.")
    except (IOError, ValueError) as e:
        self.logger.error(f"Specific error creating AudioSegment: {e}")
        return

    # Determine file format based on the file extension
    file_format = stem_path.lower().split(".")[-1]

    # For m4a files, specify mp4 as the container format as the extension doesn't match the format name
    if file_format == "m4a":
        file_format = "mp4"
    elif file_format == "mka":
        file_format = "matroska"

    # Set the bitrate to 320k for mp3 files if output_bitrate is not specified
    bitrate = "320k" if file_format == "mp3" and self.output_bitrate is None else self.output_bitrate

    # Export using the determined format; failures are logged, not raised
    try:
        audio_segment.export(stem_path, format=file_format, bitrate=bitrate)
        self.logger.debug(f"Exported audio file successfully to {stem_path}")
    except (IOError, ValueError) as e:
        self.logger.error(f"Error exporting audio file: {e}")
|
| 313 |
+
|
| 314 |
+
def write_audio_soundfile(self, stem_path: str, stem_source):
    """
    Writes the separated audio source to a file using the soundfile library.

    Parameters:
        stem_path: Destination path for the audio file; the extension determines the format.
        stem_source: 2-D numpy array of audio samples, assumed shape (num_frames, num_channels).

    Errors during writing are logged rather than raised.
    """
    self.logger.debug(f"Entering write_audio_soundfile with stem_path: {stem_path}")

    # For stereo sources, make sure the samples are laid out in C (row-major) order so
    # the channels are interleaved correctly in memory when written out.
    #
    # Bug fix: the previous implementation manually interleaved C-contiguous input into
    # a 1-D buffer hard-coded to int16, which truncated float samples to garbage and
    # caused soundfile to interpret the flattened data as mono at double the length.
    # np.ascontiguousarray preserves both the dtype and the (frames, channels) shape,
    # which soundfile handles natively.
    if stem_source.shape[1] == 2 and not stem_source.flags["C_CONTIGUOUS"]:
        stem_source = np.ascontiguousarray(stem_source)

    self.logger.debug(f"Interleaved audio data shape: {stem_source.shape}")

    # Save audio using soundfile; the subtype/sample width is inferred from the dtype.
    try:
        sf.write(stem_path, stem_source, self.sample_rate)
        self.logger.debug(f"Exported audio file successfully to {stem_path}")
    except Exception as e:
        self.logger.error(f"Error exporting audio file: {e}")
def clear_gpu_cache(self):
    """Free accelerator memory by running Python GC and emptying device caches."""
    self.logger.debug("Running garbage collection...")
    gc.collect()

    # Only the cache belonging to the active device needs emptying.
    active_device = self.torch_device
    if active_device == torch.device("mps"):
        self.logger.debug("Clearing MPS cache...")
        torch.mps.empty_cache()
    if active_device == torch.device("cuda"):
        self.logger.debug("Clearing CUDA cache...")
        torch.cuda.empty_cache()
def clear_file_specific_paths(self):
    """Reset the per-input state (paths, sources and stem outputs) between separations."""
    self.logger.info("Clearing input audio file paths, sources and stems...")

    # Null out every attribute tied to the previously processed input file.
    for attribute in (
        "audio_file_path",
        "audio_file_base",
        "primary_source",
        "secondary_source",
        "primary_stem_output_path",
        "secondary_stem_output_path",
    ):
        setattr(self, attribute, None)
def sanitize_filename(self, filename):
    """
    Returns a filesystem-safe version of *filename*: characters that are invalid on
    common filesystems are replaced with underscores, runs of underscores are
    collapsed to one, and leading/trailing underscores, dots and spaces are trimmed.
    """
    # Swap each reserved character for an underscore, then collapse any runs.
    cleaned = re.sub(r'_+', '_', re.sub(r'[<>:"/\\|?*]', '_', filename))
    return cleaned.strip('_. ')
def get_stem_output_path(self, stem_name, custom_output_names):
    """
    Builds the output filename for a stem.

    If *custom_output_names* contains an entry for *stem_name* (matched
    case-insensitively), that custom name is used as-is; otherwise the filename is
    derived from the input file base name, the stem name and the model name. Every
    component is passed through sanitize_filename and the configured output format
    supplies the extension.
    """
    extension = self.output_format.lower()

    # Case-insensitive lookup of a user-supplied custom name for this stem.
    if custom_output_names:
        lowered = {key.lower(): value for key, value in custom_output_names.items()}
        if stem_name.lower() in lowered:
            custom = self.sanitize_filename(lowered[stem_name.lower()])
            return os.path.join(f"{custom}.{extension}")

    # Default naming scheme: <input>_(<stem>)_<model>.<ext>
    base = self.sanitize_filename(self.audio_file_base)
    stem = self.sanitize_filename(stem_name)
    model = self.sanitize_filename(self.model_name)
    return os.path.join(f"{base}_({stem})_{model}.{extension}")
audio_separator/separator/separator.py
ADDED
|
@@ -0,0 +1,959 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
""" This file contains the Separator class, to facilitate the separation of stems from audio. """
|
| 2 |
+
|
| 3 |
+
from importlib import metadata, resources
|
| 4 |
+
import os
|
| 5 |
+
import sys
|
| 6 |
+
import platform
|
| 7 |
+
import subprocess
|
| 8 |
+
import time
|
| 9 |
+
import logging
|
| 10 |
+
import warnings
|
| 11 |
+
import importlib
|
| 12 |
+
import io
|
| 13 |
+
from typing import Optional
|
| 14 |
+
|
| 15 |
+
import hashlib
|
| 16 |
+
import json
|
| 17 |
+
import yaml
|
| 18 |
+
import requests
|
| 19 |
+
import torch
|
| 20 |
+
import torch.amp.autocast_mode as autocast_mode
|
| 21 |
+
import onnxruntime as ort
|
| 22 |
+
from tqdm import tqdm
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class Separator:
|
| 26 |
+
"""
|
| 27 |
+
The Separator class is designed to facilitate the separation of audio sources from a given audio file.
|
| 28 |
+
It supports various separation architectures and models, including MDX, VR, and Demucs. The class provides
|
| 29 |
+
functionalities to configure separation parameters, load models, and perform audio source separation.
|
| 30 |
+
It also handles logging, normalization, and output formatting of the separated audio stems.
|
| 31 |
+
|
| 32 |
+
The actual separation task is handled by one of the architecture-specific classes in the `architectures` module;
|
| 33 |
+
this class is responsible for initialising logging, configuring hardware acceleration, loading the model,
|
| 34 |
+
initiating the separation process and passing outputs back to the caller.
|
| 35 |
+
|
| 36 |
+
Common Attributes:
|
| 37 |
+
log_level (int): The logging level.
|
| 38 |
+
log_formatter (logging.Formatter): The logging formatter.
|
| 39 |
+
model_file_dir (str): The directory where model files are stored.
|
| 40 |
+
output_dir (str): The directory where output files will be saved.
|
| 41 |
+
output_format (str): The format of the output audio file.
|
| 42 |
+
output_bitrate (str): The bitrate of the output audio file.
|
| 43 |
+
amplification_threshold (float): The threshold for audio amplification.
|
| 44 |
+
normalization_threshold (float): The threshold for audio normalization.
|
| 45 |
+
output_single_stem (str): Option to output a single stem.
|
| 46 |
+
invert_using_spec (bool): Flag to invert using spectrogram.
|
| 47 |
+
sample_rate (int): The sample rate of the audio.
|
| 48 |
+
use_soundfile (bool): Use soundfile for audio writing, can solve OOM issues.
|
| 49 |
+
use_autocast (bool): Flag to use PyTorch autocast for faster inference.
|
| 50 |
+
|
| 51 |
+
MDX Architecture Specific Attributes:
|
| 52 |
+
hop_length (int): The hop length for STFT.
|
| 53 |
+
segment_size (int): The segment size for processing.
|
| 54 |
+
overlap (float): The overlap between segments.
|
| 55 |
+
batch_size (int): The batch size for processing.
|
| 56 |
+
enable_denoise (bool): Flag to enable or disable denoising.
|
| 57 |
+
|
| 58 |
+
VR Architecture Specific Attributes & Defaults:
|
| 59 |
+
batch_size: 16
|
| 60 |
+
window_size: 512
|
| 61 |
+
aggression: 5
|
| 62 |
+
enable_tta: False
|
| 63 |
+
enable_post_process: False
|
| 64 |
+
post_process_threshold: 0.2
|
| 65 |
+
high_end_process: False
|
| 66 |
+
|
| 67 |
+
Demucs Architecture Specific Attributes & Defaults:
|
| 68 |
+
segment_size: "Default"
|
| 69 |
+
shifts: 2
|
| 70 |
+
overlap: 0.25
|
| 71 |
+
segments_enabled: True
|
| 72 |
+
|
| 73 |
+
MDXC Architecture Specific Attributes & Defaults:
|
| 74 |
+
segment_size: 256
|
| 75 |
+
override_model_segment_size: False
|
| 76 |
+
batch_size: 1
|
| 77 |
+
overlap: 8
|
| 78 |
+
pitch_shift: 0
|
| 79 |
+
"""
|
| 80 |
+
|
| 81 |
+
def __init__(
    self,
    log_level=logging.INFO,
    log_formatter=None,
    model_file_dir="/tmp/audio-separator-models/",
    output_dir=None,
    output_format="WAV",
    output_bitrate=None,
    normalization_threshold=0.9,
    amplification_threshold=0.0,
    output_single_stem=None,
    invert_using_spec=False,
    sample_rate=44100,
    use_soundfile=False,
    use_autocast=False,
    use_directml=False,
    mdx_params=None,
    vr_params=None,
    demucs_params=None,
    mdxc_params=None,
    info_only=False,
):
    """
    Initialize the separator.

    Parameters mirror the class docstring. The architecture parameter dicts
    (mdx_params, vr_params, demucs_params, mdxc_params) default to None and are
    replaced with the documented per-architecture defaults below. Passing an
    explicit dict behaves exactly as before.

    Raises:
        FileNotFoundError: if AUDIO_SEPARATOR_MODEL_DIR points at a missing directory.
        ValueError: for out-of-range normalization/amplification thresholds or an
            invalid sample rate.
    """
    # Bug fix: these were previously mutable default arguments, so a single dict
    # was shared between every Separator instance (and any later mutation leaked
    # across instances). Build a fresh dict per call instead.
    if mdx_params is None:
        mdx_params = {"hop_length": 1024, "segment_size": 256, "overlap": 0.25, "batch_size": 1, "enable_denoise": False}
    if vr_params is None:
        vr_params = {"batch_size": 1, "window_size": 512, "aggression": 5, "enable_tta": False, "enable_post_process": False, "post_process_threshold": 0.2, "high_end_process": False}
    if demucs_params is None:
        demucs_params = {"segment_size": "Default", "shifts": 2, "overlap": 0.25, "segments_enabled": True}
    if mdxc_params is None:
        mdxc_params = {"segment_size": 256, "override_model_segment_size": False, "batch_size": 1, "overlap": 8, "pitch_shift": 0}

    self.logger = logging.getLogger(__name__)
    self.logger.setLevel(log_level)
    self.log_level = log_level
    self.log_formatter = log_formatter

    self.log_handler = logging.StreamHandler()

    if self.log_formatter is None:
        self.log_formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(module)s - %(message)s")

    self.log_handler.setFormatter(self.log_formatter)

    # Avoid attaching duplicate handlers when multiple Separators are created.
    if not self.logger.hasHandlers():
        self.logger.addHandler(self.log_handler)

    # Filter out noisy warnings from PyTorch for users who don't care about them
    if log_level > logging.DEBUG:
        warnings.filterwarnings("ignore")

    # Skip initialization logs if info_only is True
    if not info_only:
        package_version = self.get_package_distribution("audio-separator").version
        self.logger.info(f"Separator version {package_version} instantiating with output_dir: {output_dir}, output_format: {output_format}")

    if output_dir is None:
        output_dir = os.getcwd()
        if not info_only:
            self.logger.info("Output directory not specified. Using current working directory.")

    self.output_dir = output_dir

    # The AUDIO_SEPARATOR_MODEL_DIR environment variable overrides model_file_dir.
    env_model_dir = os.environ.get("AUDIO_SEPARATOR_MODEL_DIR")
    if env_model_dir:
        self.model_file_dir = env_model_dir
        self.logger.info(f"Using model directory from AUDIO_SEPARATOR_MODEL_DIR env var: {self.model_file_dir}")
        if not os.path.exists(self.model_file_dir):
            raise FileNotFoundError(f"The specified model directory does not exist: {self.model_file_dir}")
    else:
        self.logger.info(f"Using model directory from model_file_dir parameter: {model_file_dir}")
        self.model_file_dir = model_file_dir

    # Create the model and output directories if they do not exist
    os.makedirs(self.model_file_dir, exist_ok=True)
    os.makedirs(self.output_dir, exist_ok=True)

    self.output_format = output_format
    self.output_bitrate = output_bitrate

    if self.output_format is None:
        self.output_format = "WAV"

    self.normalization_threshold = normalization_threshold
    if normalization_threshold <= 0 or normalization_threshold > 1:
        raise ValueError("The normalization_threshold must be greater than 0 and less than or equal to 1.")

    self.amplification_threshold = amplification_threshold
    if amplification_threshold < 0 or amplification_threshold > 1:
        raise ValueError("The amplification_threshold must be greater than or equal to 0 and less than or equal to 1.")

    self.output_single_stem = output_single_stem
    if output_single_stem is not None:
        self.logger.debug(f"Single stem output requested, so only one output file ({output_single_stem}) will be written")

    self.invert_using_spec = invert_using_spec
    if self.invert_using_spec:
        self.logger.debug("Secondary step will be inverted using spectogram rather than waveform. This may improve quality but is slightly slower.")

    # Coerce and range-check the sample rate; int() failures and the explicit
    # range errors are both surfaced as the same user-facing ValueError.
    try:
        self.sample_rate = int(sample_rate)
        if self.sample_rate <= 0:
            raise ValueError(f"The sample rate setting is {self.sample_rate} but it must be a non-zero whole number.")
        if self.sample_rate > 12800000:
            raise ValueError(f"The sample rate setting is {self.sample_rate}. Enter something less ambitious.")
    except ValueError:
        raise ValueError("The sample rate must be a non-zero whole number. Please provide a valid integer.")

    self.use_soundfile = use_soundfile
    self.use_autocast = use_autocast
    self.use_directml = use_directml

    # These are parameters which users may want to configure so we expose them to the top-level Separator class,
    # even though they are specific to a single model architecture
    self.arch_specific_params = {"MDX": mdx_params, "VR": vr_params, "Demucs": demucs_params, "MDXC": mdxc_params}

    # Device/session state; populated by setup_accelerated_inferencing_device.
    self.torch_device = None
    self.torch_device_cpu = None
    self.torch_device_mps = None

    self.onnx_execution_provider = None
    self.model_instance = None

    self.model_is_uvr_vip = False
    self.model_friendly_name = None

    if not info_only:
        self.setup_accelerated_inferencing_device()
def setup_accelerated_inferencing_device(self):
    """
    Configures the PyTorch and/or ONNX Runtime inferencing device, using GPU
    hardware acceleration when it is available on this system.
    """
    # Gather host details first so device selection can take them into account.
    host_details = self.get_system_info()
    self.check_ffmpeg_installed()
    self.log_onnxruntime_packages()
    self.setup_torch_device(host_details)
def get_system_info(self):
    """
    Logs details about the host environment (operating system, CPU architecture,
    Python and PyTorch versions) and returns the platform.uname() result.
    """
    self.logger.info(f"Operating System: {platform.system()} {platform.version()}")

    uname = platform.uname()
    self.logger.info(f"System: {uname.system} Node: {uname.node} Release: {uname.release} Machine: {uname.machine} Proc: {uname.processor}")

    self.logger.info(f"Python Version: {platform.python_version()}")
    self.logger.info(f"PyTorch Version: {torch.__version__}")
    return uname
def check_ffmpeg_installed(self):
    """
    Verifies that ffmpeg is available on the PATH and logs its version banner.

    Raises the underlying FileNotFoundError if ffmpeg is missing - except while
    running under pytest, where ffmpeg is not required.
    """
    try:
        version_banner = subprocess.check_output(["ffmpeg", "-version"], text=True).splitlines()[0]
    except FileNotFoundError:
        self.logger.error("FFmpeg is not installed. Please install FFmpeg to use this package.")
        # ffmpeg is required for pydub to write audio, so surface the failure to
        # real users - but unit test runs in CI have no reason to throw.
        if "PYTEST_CURRENT_TEST" not in os.environ:
            raise
    else:
        self.logger.info(f"FFmpeg installed: {version_banner}")
def log_onnxruntime_packages(self):
    """
    Logs the installed version of each ONNX Runtime variant (GPU, Apple Silicon,
    CPU and DirectML); variants that are not installed are skipped.
    """
    # Query every variant first, then log, keeping the original query/log order.
    variants = [
        ("GPU", self.get_package_distribution("onnxruntime-gpu")),
        ("Silicon", self.get_package_distribution("onnxruntime-silicon")),
        ("CPU", self.get_package_distribution("onnxruntime")),
        ("DirectML", self.get_package_distribution("onnxruntime-directml")),
    ]
    for label, dist in variants:
        if dist is not None:
            self.logger.info(f"ONNX Runtime {label} package installed with version: {dist.version}")
def setup_torch_device(self, system_info):
    """
    Selects the Torch device and ONNX Runtime execution provider, preferring
    CUDA, then Apple Silicon MPS (on ARM), then DirectML (when requested and
    installed), and finally falling back to the CPU.
    """
    available_providers = ort.get_available_providers()
    dml_package_installed = self.get_package_distribution("torch_directml")

    self.torch_device_cpu = torch.device("cpu")

    accelerated = False
    if torch.cuda.is_available():
        self.configure_cuda(available_providers)
        accelerated = True
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available() and system_info.processor == "arm":
        self.configure_mps(available_providers)
        accelerated = True
    elif self.use_directml and dml_package_installed:
        # torch_directml is optional, so only import it when it may be used.
        import torch_directml

        if torch_directml.is_available():
            self.configure_dml(available_providers)
            accelerated = True

    if not accelerated:
        self.logger.info("No hardware acceleration could be configured, running in CPU mode")
        self.torch_device = self.torch_device_cpu
        self.onnx_execution_provider = ["CPUExecutionProvider"]
def configure_cuda(self, ort_providers):
    """
    Points Torch at the CUDA device and, when ONNX Runtime supports it, enables
    the CUDA execution provider for ONNX inference.
    """
    self.logger.info("CUDA is available in Torch, setting Torch device to CUDA")
    self.torch_device = torch.device("cuda")

    # Guard clause: without the provider, ONNX inference stays unaccelerated.
    if "CUDAExecutionProvider" not in ort_providers:
        self.logger.warning("CUDAExecutionProvider not available in ONNXruntime, so acceleration will NOT be enabled")
        return
    self.logger.info("ONNXruntime has CUDAExecutionProvider available, enabling acceleration")
    self.onnx_execution_provider = ["CUDAExecutionProvider"]
def configure_mps(self, ort_providers):
    """
    Points Torch at the Apple Silicon MPS device and, when ONNX Runtime supports
    it, enables the CoreML execution provider for ONNX inference.
    """
    self.logger.info("Apple Silicon MPS/CoreML is available in Torch and processor is ARM, setting Torch device to MPS")
    self.torch_device_mps = torch.device("mps")
    self.torch_device = self.torch_device_mps

    # Guard clause: without the provider, ONNX inference stays unaccelerated.
    if "CoreMLExecutionProvider" not in ort_providers:
        self.logger.warning("CoreMLExecutionProvider not available in ONNXruntime, so acceleration will NOT be enabled")
        return
    self.logger.info("ONNXruntime has CoreMLExecutionProvider available, enabling acceleration")
    self.onnx_execution_provider = ["CoreMLExecutionProvider"]
def configure_dml(self, ort_providers):
    """
    Points Torch at the DirectML device and, when ONNX Runtime supports it,
    enables the DirectML execution provider for ONNX inference.
    """
    # torch_directml is an optional dependency; import only when configuring it.
    import torch_directml

    self.logger.info("DirectML is available in Torch, setting Torch device to DirectML")
    self.torch_device_dml = torch_directml.device()
    self.torch_device = self.torch_device_dml

    # Guard clause: without the provider, ONNX inference stays unaccelerated.
    if "DmlExecutionProvider" not in ort_providers:
        self.logger.warning("DmlExecutionProvider not available in ONNXruntime, so acceleration will NOT be enabled")
        return
    self.logger.info("ONNXruntime has DmlExecutionProvider available, enabling acceleration")
    self.onnx_execution_provider = ["DmlExecutionProvider"]
def get_package_distribution(self, package_name):
    """
    Looks up the installed distribution for *package_name*, returning None (and
    logging at debug level) when the package is not installed.
    """
    dist = None
    try:
        dist = metadata.distribution(package_name)
    except metadata.PackageNotFoundError:
        self.logger.debug(f"Python package: {package_name} not installed")
    return dist
def get_model_hash(self, model_path):
    """
    Returns the MD5 hex digest of the final ~10MB of a model file (or of the
    whole file when it is smaller), matching the identification scheme used by
    UVR to recognise models. MD5 is used as a fingerprint here, not for security.

    Raises:
        FileNotFoundError: if the model file does not exist.
        Exception: any other error encountered while reading the file.
    """
    self.logger.debug(f"Calculating hash of model file {model_path}")
    # Specific byte count carried over from the original UVR logic.
    tail_bytes = 10000 * 1024  # 10,240,000 bytes

    try:
        total_size = os.path.getsize(model_path)

        with open(model_path, "rb") as model_file:
            if total_size >= tail_bytes:
                # Skip ahead so only the trailing tail_bytes bytes are hashed.
                start = total_size - tail_bytes
                self.logger.debug(f"File size {total_size} >= {tail_bytes}, seeking to {start} and hashing remaining bytes.")
                model_file.seek(start, io.SEEK_SET)
            else:
                self.logger.debug(f"File size {total_size} < {tail_bytes}, hashing entire file.")
            digest = hashlib.md5(model_file.read()).hexdigest()

        self.logger.info(f"Hash of model file {model_path} is {digest}")
        return digest

    except FileNotFoundError:
        self.logger.error(f"Model file not found at {model_path}")
        raise
    except Exception as e:
        # Other potential errors (e.g. permissions, other IOErrors)
        self.logger.error(f"Error calculating hash for {model_path}: {e}")
        raise
def download_file_if_not_exists(self, url, output_path):
    """
    Downloads *url* to *output_path* unless the file already exists, showing a
    progress bar while streaming the download.

    Raises:
        RuntimeError: if the server responds with a non-200 status code.
    """
    if os.path.isfile(output_path):
        self.logger.debug(f"File already exists at {output_path}, skipping download")
        return

    self.logger.debug(f"Downloading file from {url} to {output_path} with timeout 300s")
    # Bug fix: the streamed response was never closed, leaking the HTTP
    # connection (and the tqdm bar was left open on write errors). Context
    # managers guarantee both are released on every path.
    with requests.get(url, stream=True, timeout=300) as response:
        if response.status_code != 200:
            raise RuntimeError(f"Failed to download file from {url}, response code: {response.status_code}")

        total_size_in_bytes = int(response.headers.get("content-length", 0))
        with tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True) as progress_bar, open(output_path, "wb") as f:
            for chunk in response.iter_content(chunk_size=8192):
                progress_bar.update(len(chunk))
                f.write(chunk)
def list_supported_model_files(self):
    """
    List the supported model files for audio-separator, grouped by model type.

    The catalogue is assembled from the same `download_checks.json` file UVR uses,
    merged with audio-separator's own bundled `models.json`, and enriched with
    per-stem performance scores from the bundled `models-scores.json` where available.

    Returns a dict keyed by model type ("VR", "MDX", "Demucs", "MDXC"); each value maps
    a friendly model name to a dict with keys:
        filename        -- primary model file, used as the model's identifier
        scores          -- median per-stem scores, e.g. {"vocals": {"SDR": 10.6, ...}}
        stems           -- list of stem names reported for the model
        target_stem     -- the stem the model primarily isolates (or None)
        download_files  -- all files (bare names or URLs) required by the model
    """
    download_checks_path = os.path.join(self.model_file_dir, "download_checks.json")

    self.download_file_if_not_exists("https://raw.githubusercontent.com/TRvlvr/application_data/main/filelists/download_checks.json", download_checks_path)

    # Context manager closes the handle promptly (the original `json.load(open(...))` leaked it).
    with open(download_checks_path, encoding="utf-8") as f:
        model_downloads_list = json.load(f)
    self.logger.debug("UVR model download list loaded")

    # Load the model scores with error handling; scores are optional metadata.
    model_scores = {}
    try:
        with resources.open_text("audio_separator", "models-scores.json") as f:
            model_scores = json.load(f)
        self.logger.debug("Model scores loaded")
    except json.JSONDecodeError as e:
        self.logger.warning(f"Failed to load model scores: {str(e)}")
        self.logger.warning("Continuing without model scores")

    def _model_entry(filename, download_files):
        """Build one catalogue entry, pulling score metadata for `filename` if present."""
        score_data = model_scores.get(filename, {})
        return {
            "filename": filename,
            "scores": score_data.get("median_scores", {}),
            "stems": score_data.get("stems", []),
            "target_stem": score_data.get("target_stem"),
            "download_files": download_files,
        }

    # Only show Demucs v4 models as we've only implemented support for v4
    filtered_demucs_v4 = {key: value for key, value in model_downloads_list["demucs_download_list"].items() if key.startswith("Demucs v4")}

    # Demucs models are identified by their YAML config file; download_files holds
    # every URL/filename the model needs.
    demucs_models = {}
    for name, files in filtered_demucs_v4.items():
        # Find the YAML file in the model files
        yaml_file = next((filename for filename in files.keys() if filename.endswith(".yaml")), None)
        if yaml_file:
            demucs_models[name] = _model_entry(yaml_file, list(files.values()))

    # Load audio-separator's own model list, bundled as package data.
    with resources.open_text("audio_separator", "models.json") as f:
        audio_separator_models_list = json.load(f)
    self.logger.debug("Audio-Separator model list loaded")

    model_files_grouped_by_type = {
        "VR": {
            # VR models are identified by a single filename, which is also the only download.
            name: _model_entry(filename, [filename])
            for name, filename in {**model_downloads_list["vr_download_list"], **audio_separator_models_list["vr_download_list"]}.items()
        },
        "MDX": {
            # MDX models are likewise a single file.
            name: _model_entry(filename, [filename])
            for name, filename in {**model_downloads_list["mdx_download_list"], **model_downloads_list["mdx_download_vip_list"], **audio_separator_models_list["mdx_download_list"]}.items()
        },
        "Demucs": demucs_models,
        "MDXC": {
            # MDXC entries map model filename -> config filename; the first key is the
            # model file, and both model and config files must be downloaded.
            name: _model_entry(next(iter(files.keys())), list(files.keys()) + list(files.values()))
            for name, files in {
                **model_downloads_list["mdx23c_download_list"],
                **model_downloads_list["mdx23c_download_vip_list"],
                **model_downloads_list["roformer_download_list"],
                **audio_separator_models_list["mdx23c_download_list"],
                **audio_separator_models_list["roformer_download_list"],
            }.items()
        },
    }

    return model_files_grouped_by_type
|
| 569 |
+
|
| 570 |
+
def print_uvr_vip_message(self):
    """Warn the user, when the loaded model is a VIP model, that it is intended for paying UVR subscribers."""
    # Nothing to say for non-VIP models.
    if not self.model_is_uvr_vip:
        return
    self.logger.warning(f"The model: '{self.model_friendly_name}' is a VIP model, intended by Anjok07 for access by paying subscribers only.")
    self.logger.warning("If you are not already subscribed, please consider supporting the developer of UVR, Anjok07 by subscribing here: https://patreon.com/uvr")
|
| 577 |
+
|
| 578 |
+
def download_model_files(self, model_filename):
    """
    This method downloads the model files for a given model filename, if they are not already present.

    The filename is looked up in the supported-model catalogue (see list_supported_model_files);
    the first entry whose primary filename or download-file list contains it is used. Files are
    fetched from the public UVR repo (or the VIP repo for VIP models), falling back to the
    audio-separator models repo when a file is missing upstream.

    Returns tuple of (model_filename, model_type, model_friendly_name, model_path, yaml_config_filename)

    Raises:
        ValueError: if the filename is not found in the supported model catalogue.
    """
    model_path = os.path.join(self.model_file_dir, f"{model_filename}")

    supported_model_files_grouped = self.list_supported_model_files()
    public_model_repo_url_prefix = "https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models"
    vip_model_repo_url_prefix = "https://github.com/Anjok0109/ai_magic/releases/download/v5"
    audio_separator_models_repo_url_prefix = "https://github.com/nomadkaraoke/python-audio-separator/releases/download/model-configs"

    yaml_config_filename = None

    self.logger.debug(f"Searching for model_filename {model_filename} in supported_model_files_grouped")

    # Iterate through model types (MDX, Demucs, MDXC)
    for model_type, models in supported_model_files_grouped.items():
        # Iterate through each model in this type
        for model_friendly_name, model_info in models.items():
            # NOTE(review): this flag is overwritten for every candidate inspected; it only ends
            # up describing the matched model because we return as soon as a match is processed.
            self.model_is_uvr_vip = "VIP" in model_friendly_name
            model_repo_url_prefix = vip_model_repo_url_prefix if self.model_is_uvr_vip else public_model_repo_url_prefix

            # Check if this model matches our target filename
            if model_info["filename"] == model_filename or model_filename in model_info["download_files"]:
                self.logger.debug(f"Found matching model: {model_friendly_name}")
                self.model_friendly_name = model_friendly_name
                self.print_uvr_vip_message()

                # Download each required file for this model
                for file_to_download in model_info["download_files"]:
                    # For URLs, extract just the filename portion
                    if file_to_download.startswith("http"):
                        filename = file_to_download.split("/")[-1]
                        download_path = os.path.join(self.model_file_dir, filename)
                        self.download_file_if_not_exists(file_to_download, download_path)
                        continue

                    download_path = os.path.join(self.model_file_dir, file_to_download)

                    # For MDXC models, handle YAML config files specially: they live in a
                    # subdirectory of the UVR repo, and the config name is returned to the caller.
                    if model_type == "MDXC" and file_to_download.endswith(".yaml"):
                        yaml_config_filename = file_to_download
                        try:
                            yaml_url = f"{model_repo_url_prefix}/mdx_model_data/mdx_c_configs/{file_to_download}"
                            self.download_file_if_not_exists(yaml_url, download_path)
                        except RuntimeError:
                            # RuntimeError here means a non-200 response from the UVR repo.
                            self.logger.debug("YAML config not found in UVR repo, trying audio-separator models repo...")
                            yaml_url = f"{audio_separator_models_repo_url_prefix}/{file_to_download}"
                            self.download_file_if_not_exists(yaml_url, download_path)
                        continue

                    # For regular model files, try UVR repo first, then audio-separator repo
                    try:
                        download_url = f"{model_repo_url_prefix}/{file_to_download}"
                        self.download_file_if_not_exists(download_url, download_path)
                    except RuntimeError:
                        self.logger.debug("Model not found in UVR repo, trying audio-separator models repo...")
                        download_url = f"{audio_separator_models_repo_url_prefix}/{file_to_download}"
                        self.download_file_if_not_exists(download_url, download_path)

                # Return on the first match; remaining catalogue entries are never inspected.
                return model_filename, model_type, model_friendly_name, model_path, yaml_config_filename

    raise ValueError(f"Model file {model_filename} not found in supported model files")
|
| 642 |
+
|
| 643 |
+
def load_model_data_from_yaml(self, yaml_config_filename):
    """
    Load model-specific parameters from the model's YAML config file.

    The parameters in the YAML are critical to inferencing, as they need to match
    whatever was used during training.

    Parameters:
    - yaml_config_filename (str): either a full path to the YAML file, or a bare
      filename resolved relative to self.model_file_dir.

    Returns:
    - model_data (dict): parsed YAML contents; "is_roformer" is set to True when the
      resolved path contains "roformer".
    """
    # Accept either a full path or a bare filename relative to the model directory.
    if not os.path.exists(yaml_config_filename):
        model_data_yaml_filepath = os.path.join(self.model_file_dir, yaml_config_filename)
    else:
        model_data_yaml_filepath = yaml_config_filename

    self.logger.debug(f"Loading model data from YAML at path {model_data_yaml_filepath}")

    # Context manager closes the handle (the original `yaml.load(open(...))` leaked it).
    with open(model_data_yaml_filepath, encoding="utf-8") as yaml_file:
        model_data = yaml.load(yaml_file, Loader=yaml.FullLoader)
    self.logger.debug(f"Model data loaded from YAML file: {model_data}")

    # Roformer models are flagged by filename convention so downstream code can branch on it.
    if "roformer" in model_data_yaml_filepath:
        model_data["is_roformer"] = True

    return model_data
|
| 663 |
+
|
| 664 |
+
def load_model_data_using_hash(self, model_path):
    """
    Load model-specific parameters from UVR model data files, keyed by the model file's MD5 hash.

    These parameters are critical to inferencing using a given model, as they need to match
    whatever was used during training. The correct parameters are identified by calculating
    the hash of the model file and looking up the hash in the UVR data files.

    Parameters:
    - model_path (str): path to the downloaded model file to identify.

    Returns:
    - model_data (dict): the parameter dict found for the hash.

    Raises:
        ValueError: if the hash is not present in either the MDX or VR model data files.
    """
    # Model data and configuration sources from UVR
    model_data_url_prefix = "https://raw.githubusercontent.com/TRvlvr/application_data/main"

    vr_model_data_url = f"{model_data_url_prefix}/vr_model_data/model_data_new.json"
    mdx_model_data_url = f"{model_data_url_prefix}/mdx_model_data/model_data_new.json"

    # Calculate hash for the downloaded model
    self.logger.debug("Calculating MD5 hash for model file to identify model parameters from UVR data...")
    model_hash = self.get_model_hash(model_path)
    self.logger.debug(f"Model {model_path} has hash {model_hash}")

    # Fetch the UVR data files if we don't already have them cached locally.
    vr_model_data_path = os.path.join(self.model_file_dir, "vr_model_data.json")
    self.logger.debug(f"VR model data path set to {vr_model_data_path}")
    self.download_file_if_not_exists(vr_model_data_url, vr_model_data_path)

    mdx_model_data_path = os.path.join(self.model_file_dir, "mdx_model_data.json")
    self.logger.debug(f"MDX model data path set to {mdx_model_data_path}")
    self.download_file_if_not_exists(mdx_model_data_url, mdx_model_data_path)

    # Loading model data from UVR; context managers close the files
    # (the original `json.load(open(...))` leaked both handles).
    self.logger.debug("Loading MDX and VR model parameters from UVR model data files...")
    with open(vr_model_data_path, encoding="utf-8") as f:
        vr_model_data_object = json.load(f)
    with open(mdx_model_data_path, encoding="utf-8") as f:
        mdx_model_data_object = json.load(f)

    # Load additional model data bundled with audio-separator
    self.logger.debug("Loading additional model parameters from audio-separator model data file...")
    with resources.open_text("audio_separator", "model-data.json") as f:
        audio_separator_model_data = json.load(f)

    # Merge the model data objects, with audio-separator data taking precedence
    vr_model_data_object = {**vr_model_data_object, **audio_separator_model_data.get("vr_model_data", {})}
    mdx_model_data_object = {**mdx_model_data_object, **audio_separator_model_data.get("mdx_model_data", {})}

    # MDX lookup takes precedence over VR when the hash appears in both.
    if model_hash in mdx_model_data_object:
        model_data = mdx_model_data_object[model_hash]
    elif model_hash in vr_model_data_object:
        model_data = vr_model_data_object[model_hash]
    else:
        raise ValueError(f"Unsupported Model File: parameters for MD5 hash {model_hash} could not be found in UVR model data file for MDX or VR arch.")

    self.logger.debug(f"Model data loaded using hash {model_hash}: {model_data}")

    return model_data
|
| 714 |
+
|
| 715 |
+
def load_model(self, model_filename="model_mel_band_roformer_ep_3005_sdr_11.4360.ckpt"):
    """
    This method instantiates the architecture-specific separation class,
    loading the separation model into memory, downloading it first if necessary.

    Parameters:
    - model_filename (str): filename of the model to load; it is downloaded if not cached.

    Raises:
        ValueError: if the model's architecture type is not supported.
        Exception: for Demucs models on Python < 3.10.
    """
    self.logger.info(f"Loading model {model_filename}...")

    load_model_start_time = time.perf_counter()

    # Setting up the model path
    model_filename, model_type, model_friendly_name, model_path, yaml_config_filename = self.download_model_files(model_filename)
    # NOTE(review): split(".")[0] truncates at the FIRST dot, so filenames containing dots
    # (e.g. "...sdr_11.4360.ckpt") lose everything after the first one — confirm intended.
    model_name = model_filename.split(".")[0]
    self.logger.debug(f"Model downloaded, friendly name: {model_friendly_name}, model_path: {model_path}")

    # A YAML model path is itself the model's config file.
    if model_path.lower().endswith(".yaml"):
        yaml_config_filename = model_path

    # Model parameters come from the YAML config when one exists, otherwise from a hash lookup.
    if yaml_config_filename is not None:
        model_data = self.load_model_data_from_yaml(yaml_config_filename)
    else:
        model_data = self.load_model_data_using_hash(model_path)

    # Configuration shared by every architecture-specific separator class.
    common_params = {
        "logger": self.logger,
        "log_level": self.log_level,
        "torch_device": self.torch_device,
        "torch_device_cpu": self.torch_device_cpu,
        "torch_device_mps": self.torch_device_mps,
        "onnx_execution_provider": self.onnx_execution_provider,
        "model_name": model_name,
        "model_path": model_path,
        "model_data": model_data,
        "output_format": self.output_format,
        "output_bitrate": self.output_bitrate,
        "output_dir": self.output_dir,
        "normalization_threshold": self.normalization_threshold,
        "amplification_threshold": self.amplification_threshold,
        "output_single_stem": self.output_single_stem,
        "invert_using_spec": self.invert_using_spec,
        "sample_rate": self.sample_rate,
        "use_soundfile": self.use_soundfile,
    }

    # Instantiate the appropriate separator class depending on the model type;
    # values are "module_name.ClassName" strings resolved below via importlib.
    separator_classes = {"MDX": "mdx_separator.MDXSeparator", "VR": "vr_separator.VRSeparator", "Demucs": "demucs_separator.DemucsSeparator", "MDXC": "mdxc_separator.MDXCSeparator"}

    if model_type not in self.arch_specific_params or model_type not in separator_classes:
        raise ValueError(f"Model type not supported (yet): {model_type}")

    if model_type == "Demucs" and sys.version_info < (3, 10):
        raise Exception("Demucs models require Python version 3.10 or newer.")

    self.logger.debug(f"Importing module for model type {model_type}: {separator_classes[model_type]}")

    # Import lazily so only the needed architecture's dependencies are loaded.
    module_name, class_name = separator_classes[model_type].split(".")
    module = importlib.import_module(f"audio_separator.separator.architectures.{module_name}")
    separator_class = getattr(module, class_name)

    self.logger.debug(f"Instantiating separator class for model type {model_type}: {separator_class}")
    self.model_instance = separator_class(common_config=common_params, arch_config=self.arch_specific_params[model_type])

    # Log the completion of the model load process
    self.logger.debug("Loading model completed.")
    self.logger.info(f'Load model duration: {time.strftime("%H:%M:%S", time.gmtime(int(time.perf_counter() - load_model_start_time)))}')
|
| 779 |
+
|
| 780 |
+
def separate(self, audio_file_path, custom_output_names=None):
    """
    Separates the audio file(s) into different stems (e.g., vocals, instruments) using the loaded model.

    Accepts a single path, a directory (searched recursively for audio files), or a list of
    paths/directories. Each resolved file is run through the loaded separation model; failures
    on individual files are logged and skipped so one bad file does not abort a batch.

    Parameters:
    - audio_file_path (str or list): The path to the audio file or directory, or a list of paths.
    - custom_output_names (dict, optional): Custom names for the output files. Defaults to None.

    Returns:
    - output_files (list of str): A list containing the paths to the separated audio stem files.

    Raises:
    - ValueError: if no model is loaded or the torch device was not initialized.
    """
    # Formats picked up when scanning a directory; matched case-insensitively
    # (the original's case-sensitive check silently skipped e.g. "SONG.WAV").
    audio_extensions = (".wav", ".flac", ".mp3", ".ogg", ".opus", ".m4a", ".aiff", ".ac3")

    # Check if the model and device are properly initialized
    if not (self.torch_device and self.model_instance):
        raise ValueError("Initialization failed or model not loaded. Please load a model before attempting to separate.")

    # If audio_file_path is a string, convert it to a list for uniform processing
    if isinstance(audio_file_path, str):
        audio_file_path = [audio_file_path]

    output_files = []

    # Resolve every input path into the concrete list of audio files to process.
    files_to_process = []
    for path in audio_file_path:
        if os.path.isdir(path):
            # If the path is a directory, recursively search for all audio files
            for root, _dirs, files in os.walk(path):
                files_to_process.extend(os.path.join(root, file) for file in files if file.lower().endswith(audio_extensions))
        else:
            files_to_process.append(path)

    # Process each file, logging and skipping failures rather than aborting the whole batch.
    for file_path in files_to_process:
        self.logger.info(f"Processing file: {file_path}")
        try:
            output_files.extend(self._separate_file(file_path, custom_output_names))
        except Exception as e:
            self.logger.error(f"Failed to process file {file_path}: {e}")

    return output_files
|
| 832 |
+
|
| 833 |
+
def _separate_file(self, audio_file_path, custom_output_names=None):
    """
    Run the separation for a single audio file.

    Logs timing for the run, wraps the model call in torch autocast when enabled and
    supported by the current device, then frees GPU memory and per-file state and
    re-prints the VIP reminder so it is not lost in the logs.

    Parameters:
    - audio_file_path (str): The path to the audio file.
    - custom_output_names (dict, optional): Custom names for the output files. Defaults to None.

    Returns:
    - output_files (list of str): A list containing the paths to the separated audio stem files.
    """
    self.logger.info(f"Starting separation process for audio_file_path: {audio_file_path}")
    separate_start_time = time.perf_counter()

    # Log normalization and amplification thresholds
    self.logger.debug(f"Normalization threshold set to {self.normalization_threshold}, waveform will be lowered to this max amplitude to avoid clipping.")
    self.logger.debug(f"Amplification threshold set to {self.amplification_threshold}, waveform will be scaled up to this max amplitude if below it.")

    # Use autocast only when the user requested it AND the current device supports it.
    autocast_enabled = self.use_autocast and autocast_mode.is_autocast_available(self.torch_device.type)
    if autocast_enabled:
        self.logger.debug("Autocast available.")
        with autocast_mode.autocast(self.torch_device.type):
            output_files = self.model_instance.separate(audio_file_path, custom_output_names)
    else:
        self.logger.debug("Autocast unavailable.")
        output_files = self.model_instance.separate(audio_file_path, custom_output_names)

    # Clear GPU cache to free up memory
    self.model_instance.clear_gpu_cache()

    # Unset separation parameters to prevent accidentally re-using the wrong source files or output paths
    self.model_instance.clear_file_specific_paths()

    # Remind the user one more time if they used a VIP model, so the message doesn't get lost in the logs
    self.print_uvr_vip_message()

    self.logger.debug("Separation process completed.")
    self.logger.info(f'Separation duration: {time.strftime("%H:%M:%S", time.gmtime(int(time.perf_counter() - separate_start_time)))}')

    return output_files
|
| 876 |
+
|
| 877 |
+
def download_model_and_data(self, model_filename):
    """
    Download a model's files and associated parameter data without loading the model into memory.
    """
    self.logger.info(f"Downloading model {model_filename}...")

    model_filename, model_type, model_friendly_name, model_path, yaml_config_filename = self.download_model_files(model_filename)

    # A YAML model path doubles as its own config file.
    if model_path.lower().endswith(".yaml"):
        yaml_config_filename = model_path

    # Parameters come from the YAML config when one exists, otherwise from a hash lookup.
    model_data = self.load_model_data_from_yaml(yaml_config_filename) if yaml_config_filename is not None else self.load_model_data_using_hash(model_path)

    model_data_dict_size = len(model_data)

    self.logger.info(f"Model downloaded, type: {model_type}, friendly name: {model_friendly_name}, model_path: {model_path}, model_data: {model_data_dict_size} items")
|
| 896 |
+
|
| 897 |
+
def get_simplified_model_list(self, filter_sort_by: Optional[str] = None):
    """
    Returns a simplified, user-friendly list of models with their key metrics.

    Each entry is keyed by model filename and contains the friendly name, model type, a
    display list of stems (SDR scores in parentheses, "*" marking the target stem), and a
    lowercase stem -> SDR lookup dict.

    Parameters:
    - filter_sort_by (str, optional): "name" sorts by friendly name; "filename" sorts by
      filename; any other value is treated as a stem name — the list is filtered to models
      producing that stem and sorted by its SDR (highest first, unscored models last).
      None (default) returns the list unsorted.
      (Fixed: the original docstring documented a nonexistent `sort_by` parameter.)
    """
    model_files = self.list_supported_model_files()
    simplified_list = {}

    for model_type, models in model_files.items():
        for name, data in models.items():
            filename = data["filename"]
            # `or {}` / `or []` treat explicit None values like missing data.
            scores = data.get("scores") or {}
            stems = data.get("stems") or []
            target_stem = data.get("target_stem")

            # Format stems with their SDR scores where available
            stems_with_scores = []
            stem_sdr_dict = {}

            for stem in stems:
                stem_scores = scores.get(stem, {})
                # Add asterisk if this is the target stem
                stem_display = f"{stem}*" if stem == target_stem else stem

                if isinstance(stem_scores, dict) and "SDR" in stem_scores:
                    sdr = round(stem_scores["SDR"], 1)
                    stems_with_scores.append(f"{stem_display} ({sdr})")
                    stem_sdr_dict[stem.lower()] = sdr
                else:
                    # Include stem without SDR score
                    stems_with_scores.append(stem_display)
                    stem_sdr_dict[stem.lower()] = None

            # If no stems listed, mark as Unknown
            if not stems_with_scores:
                stems_with_scores = ["Unknown"]
                stem_sdr_dict["unknown"] = None

            simplified_list[filename] = {"Name": name, "Type": model_type, "Stems": stems_with_scores, "SDR": stem_sdr_dict}

    # Apply the requested ordering/filtering; guard clauses keep the branches flat.
    if filter_sort_by == "name":
        return dict(sorted(simplified_list.items(), key=lambda x: x[1]["Name"]))
    if filter_sort_by == "filename":
        return dict(sorted(simplified_list.items()))
    if filter_sort_by:
        # Treat any other value as a stem name (case-insensitive): keep only models
        # that produce that stem...
        sort_by_lower = filter_sort_by.lower()
        filtered_list = {k: v for k, v in simplified_list.items() if sort_by_lower in v["SDR"]}

        # ...and sort by SDR descending; with reverse=True the (0, -inf) key for
        # unscored models places them last.
        def sort_key(item):
            sdr = item[1]["SDR"][sort_by_lower]
            return (0 if sdr is None else 1, sdr if sdr is not None else float("-inf"))

        return dict(sorted(filtered_list.items(), key=sort_key, reverse=True))

    return simplified_list
|
audio_separator/separator/uvr_lib_v5/__init__.py
ADDED
|
File without changes
|
audio_separator/separator/uvr_lib_v5/demucs/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
audio_separator/separator/uvr_lib_v5/demucs/__main__.py
ADDED
|
@@ -0,0 +1,212 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import json
|
| 8 |
+
import os
|
| 9 |
+
import sys
|
| 10 |
+
import time
|
| 11 |
+
from dataclasses import dataclass, field
|
| 12 |
+
from fractions import Fraction
|
| 13 |
+
|
| 14 |
+
import torch as th
|
| 15 |
+
from torch import distributed, nn
|
| 16 |
+
from torch.nn.parallel.distributed import DistributedDataParallel
|
| 17 |
+
|
| 18 |
+
from .augment import FlipChannels, FlipSign, Remix, Shift
|
| 19 |
+
from .compressed import StemsSet, build_musdb_metadata, get_musdb_tracks
|
| 20 |
+
from .model import Demucs
|
| 21 |
+
from .parser import get_name, get_parser
|
| 22 |
+
from .raw import Rawset
|
| 23 |
+
from .tasnet import ConvTasNet
|
| 24 |
+
from .test import evaluate
|
| 25 |
+
from .train import train_model, validate_model
|
| 26 |
+
from .utils import human_seconds, load_model, save_model, sizeof_fmt
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@dataclass
class SavedState:
    """Training state persisted to the checkpoint file between runs.

    An instance of this is what `th.save`/`th.load` round-trips in `main`.
    """

    # One dict per completed epoch: {"train", "valid", "best", "duration"}.
    metrics: list = field(default_factory=list)
    # model.state_dict() at the end of the last completed epoch.
    # NOTE(review): annotated `dict` but defaults to None — effectively Optional.
    last_state: dict = None
    # CPU copy of the state_dict with the best validation loss so far.
    best_state: dict = None
    # optimizer.state_dict() matching `last_state`.
    optimizer: dict = None
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def main():
    """Entry point: train and/or evaluate a Demucs/ConvTasNet model on MusDB.

    Handles argument parsing, optional multi-GPU (NCCL) setup, checkpoint
    resume, the train/validate loop, and final evaluation + model export.
    """
    parser = get_parser()
    args = parser.parse_args()
    name = get_name(parser, args)
    print(f"Experiment {name}")

    if args.musdb is None and args.rank == 0:
        print("You must provide the path to the MusDB dataset with the --musdb flag. " "To download the MusDB dataset, see https://sigsep.github.io/datasets/musdb.html.", file=sys.stderr)
        sys.exit(1)

    # Prepare all output directories (evals, logs, checkpoints, exported models).
    eval_folder = args.evals / name
    eval_folder.mkdir(exist_ok=True, parents=True)
    args.logs.mkdir(exist_ok=True)
    metrics_path = args.logs / f"{name}.json"
    # NOTE(review): duplicate of the eval_folder.mkdir above — harmless but redundant.
    eval_folder.mkdir(exist_ok=True, parents=True)
    args.checkpoints.mkdir(exist_ok=True, parents=True)
    args.models.mkdir(exist_ok=True, parents=True)

    # Device selection: explicit flag wins, otherwise CUDA if available.
    if args.device is None:
        device = "cpu"
        if th.cuda.is_available():
            device = "cuda"
    else:
        device = args.device

    th.manual_seed(args.seed)
    # Prevents too many threads to be started when running `museval` as it can be quite
    # inefficient on NUMA architectures.
    os.environ["OMP_NUM_THREADS"] = "1"

    if args.world_size > 1:
        # Distributed training requires CUDA (NCCL backend).
        if device != "cuda" and args.rank == 0:
            print("Error: distributed training is only available with cuda device", file=sys.stderr)
            sys.exit(1)
        th.cuda.set_device(args.rank % th.cuda.device_count())
        distributed.init_process_group(backend="nccl", init_method="tcp://" + args.master, rank=args.rank, world_size=args.world_size)

    checkpoint = args.checkpoints / f"{name}.th"
    checkpoint_tmp = args.checkpoints / f"{name}.th.tmp"
    if args.restart and checkpoint.exists():
        checkpoint.unlink()

    if args.test:
        # Evaluation-only run: load a pre-trained model and skip training epochs.
        args.epochs = 1
        args.repeat = 0
        model = load_model(args.models / args.test)
    elif args.tasnet:
        model = ConvTasNet(audio_channels=args.audio_channels, samplerate=args.samplerate, X=args.X)
    else:
        model = Demucs(
            audio_channels=args.audio_channels,
            channels=args.channels,
            context=args.context,
            depth=args.depth,
            glu=args.glu,
            growth=args.growth,
            kernel_size=args.kernel_size,
            lstm_layers=args.lstm_layers,
            rescale=args.rescale,
            rewrite=args.rewrite,
            sources=4,
            stride=args.conv_stride,
            upsample=args.upsample,
            samplerate=args.samplerate,
        )
    model.to(device)
    if args.show:
        # Only print the architecture and parameter footprint, then exit.
        # 4 bytes per parameter (float32).
        print(model)
        size = sizeof_fmt(4 * sum(p.numel() for p in model.parameters()))
        print(f"Model size {size}")
        return

    optimizer = th.optim.Adam(model.parameters(), lr=args.lr)

    # Resume from a checkpoint when one exists; otherwise start fresh.
    try:
        saved = th.load(checkpoint, map_location="cpu")
    except IOError:
        saved = SavedState()
    else:
        model.load_state_dict(saved.last_state)
        optimizer.load_state_dict(saved.optimizer)

    if args.save_model:
        # Export the best model seen so far (rank 0 writes) and exit.
        if args.rank == 0:
            model.to("cpu")
            model.load_state_dict(saved.best_state)
            save_model(model, args.models / f"{name}.th")
        return

    if args.rank == 0:
        # Sentinel file marking a finished experiment; cleared on (re)start.
        done = args.logs / f"{name}.done"
        if done.exists():
            done.unlink()

    # Data augmentation pipeline; Shift alone is always applied.
    if args.augment:
        augment = nn.Sequential(FlipSign(), FlipChannels(), Shift(args.data_stride), Remix(group_size=args.remix_group_size)).to(device)
    else:
        augment = Shift(args.data_stride)

    if args.mse:
        criterion = nn.MSELoss()
    else:
        criterion = nn.L1Loss()

    # Setting number of samples so that all convolution windows are full.
    # Prevents hard to debug mistake with the prediction being shifted compared
    # to the input mixture.
    samples = model.valid_length(args.samples)
    print(f"Number of training samples adjusted to {samples}")

    if args.raw:
        train_set = Rawset(args.raw / "train", samples=samples + args.data_stride, channels=args.audio_channels, streams=[0, 1, 2, 3, 4], stride=args.data_stride)

        valid_set = Rawset(args.raw / "valid", channels=args.audio_channels)
    else:
        # Build MusDB metadata once on rank 0; other ranks wait at the barrier.
        if not args.metadata.is_file() and args.rank == 0:
            build_musdb_metadata(args.metadata, args.musdb, args.workers)
        if args.world_size > 1:
            distributed.barrier()
        metadata = json.load(open(args.metadata))
        duration = Fraction(samples + args.data_stride, args.samplerate)
        stride = Fraction(args.data_stride, args.samplerate)
        train_set = StemsSet(get_musdb_tracks(args.musdb, subsets=["train"], split="train"), metadata, duration=duration, stride=stride, samplerate=args.samplerate, channels=args.audio_channels)
        valid_set = StemsSet(get_musdb_tracks(args.musdb, subsets=["train"], split="valid"), metadata, samplerate=args.samplerate, channels=args.audio_channels)

    # Replay metrics of already-completed epochs (checkpoint resume).
    best_loss = float("inf")
    for epoch, metrics in enumerate(saved.metrics):
        print(f"Epoch {epoch:03d}: " f"train={metrics['train']:.8f} " f"valid={metrics['valid']:.8f} " f"best={metrics['best']:.4f} " f"duration={human_seconds(metrics['duration'])}")
        best_loss = metrics["best"]

    if args.world_size > 1:
        dmodel = DistributedDataParallel(model, device_ids=[th.cuda.current_device()], output_device=th.cuda.current_device())
    else:
        dmodel = model

    # Main training loop; starts after any epochs already in the checkpoint.
    for epoch in range(len(saved.metrics), args.epochs):
        begin = time.time()
        model.train()
        train_loss = train_model(
            epoch, train_set, dmodel, criterion, optimizer, augment, batch_size=args.batch_size, device=device, repeat=args.repeat, seed=args.seed, workers=args.workers, world_size=args.world_size
        )
        model.eval()
        valid_loss = validate_model(epoch, valid_set, model, criterion, device=device, rank=args.rank, split=args.split_valid, world_size=args.world_size)

        duration = time.time() - begin
        if valid_loss < best_loss:
            best_loss = valid_loss
            # Keep the best weights as detached CPU copies so later training
            # on GPU cannot mutate them.
            saved.best_state = {key: value.to("cpu").clone() for key, value in model.state_dict().items()}
        saved.metrics.append({"train": train_loss, "valid": valid_loss, "best": best_loss, "duration": duration})
        if args.rank == 0:
            json.dump(saved.metrics, open(metrics_path, "w"))

        saved.last_state = model.state_dict()
        saved.optimizer = optimizer.state_dict()
        if args.rank == 0 and not args.test:
            # Write to a temp file, then rename: atomic checkpoint update.
            th.save(saved, checkpoint_tmp)
            checkpoint_tmp.rename(checkpoint)

        print(f"Epoch {epoch:03d}: " f"train={train_loss:.8f} valid={valid_loss:.8f} best={best_loss:.4f} " f"duration={human_seconds(duration)}")

    # Final evaluation with the best weights, optionally on CPU.
    del dmodel
    model.load_state_dict(saved.best_state)
    if args.eval_cpu:
        device = "cpu"
        model.to(device)
    model.eval()
    evaluate(model, args.musdb, eval_folder, rank=args.rank, world_size=args.world_size, device=device, save=args.save, split=args.split_valid, shifts=args.shifts, workers=args.eval_workers)
    model.to("cpu")
    save_model(model, args.models / f"{name}.th")
    if args.rank == 0:
        print("done")
        done.write_text("done")
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
# Allow running this module directly as a script.
if __name__ == "__main__":
    main()
|
audio_separator/separator/uvr_lib_v5/demucs/apply.py
ADDED
|
@@ -0,0 +1,294 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
"""
|
| 7 |
+
Code to apply a model to a mix. It will handle chunking with overlaps and
|
| 8 |
+
inteprolation between chunks, as well as the "shift trick".
|
| 9 |
+
"""
|
| 10 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 11 |
+
import random
|
| 12 |
+
import typing as tp
|
| 13 |
+
|
| 14 |
+
import torch as th
|
| 15 |
+
from torch import nn
|
| 16 |
+
from torch.nn import functional as F
|
| 17 |
+
import tqdm
|
| 18 |
+
|
| 19 |
+
from .demucs import Demucs
|
| 20 |
+
from .hdemucs import HDemucs
|
| 21 |
+
from .utils import center_trim, DummyPoolExecutor
|
| 22 |
+
|
| 23 |
+
# Type alias for the model kinds `apply_model` accepts.
Model = tp.Union[Demucs, HDemucs]

# NOTE(review): appears unused in this module — verify before removing.
progress_bar_num = 0
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class BagOfModels(nn.Module):
    def __init__(self, models: tp.List[Model], weights: tp.Optional[tp.List[tp.List[float]]] = None, segment: tp.Optional[float] = None):
        """
        Represents a bag of models with specific weights.
        You should call `apply_model` rather than calling directly the forward here for
        optimal performance.

        Args:
            models (list[nn.Module]): list of Demucs/HDemucs models.
            weights (list[list[float]]): list of weights. If None, assumed to
                be all ones, otherwise it should be a list of N list (N number of models),
                each containing S floats (S number of sources).
            segment (None or float): overrides the `segment` attribute of each model
                (this is performed inplace, be careful if you reuse the models passed).
        """
        super().__init__()
        assert len(models) > 0

        # All models in the bag must agree on sources, sample rate and channels.
        reference = models[0]
        for candidate in models:
            assert candidate.sources == reference.sources
            assert candidate.samplerate == reference.samplerate
            assert candidate.audio_channels == reference.audio_channels
            if segment is not None:
                # In-place override of each model's segment length.
                candidate.segment = segment

        self.audio_channels = reference.audio_channels
        self.samplerate = reference.samplerate
        self.sources = reference.sources
        self.models = nn.ModuleList(models)

        if weights is None:
            # Default: uniform weight of 1.0 per source, per model.
            weights = [[1.0] * len(reference.sources) for _ in models]
        else:
            assert len(weights) == len(models)
            for per_model in weights:
                assert len(per_model) == len(reference.sources)
        self.weights = weights

    def forward(self, x):
        # Direct forward is intentionally unsupported; see `apply_model`.
        raise NotImplementedError("Call `apply_model` on this.")
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class TensorChunk:
    """A zero-copy view over a window of the last dimension of a tensor.

    Nested chunks collapse onto the underlying base tensor, and `padded`
    re-centers the window with zero padding to reach a target length.
    """

    def __init__(self, tensor, offset=0, length=None):
        full = tensor.shape[-1]
        assert offset >= 0
        assert offset < full

        if length is None:
            length = full - offset
        else:
            # Clip the requested window to what the tensor actually holds.
            length = min(length, full - offset)

        if isinstance(tensor, TensorChunk):
            # Collapse a chunk-of-chunk into a single view on the base tensor.
            self.tensor = tensor.tensor
            self.offset = tensor.offset + offset
        else:
            self.tensor = tensor
            self.offset = offset
        self.length = length
        self.device = tensor.device

    @property
    def shape(self):
        # Same shape as the base tensor, except for the windowed last dim.
        dims = list(self.tensor.shape)
        dims[-1] = self.length
        return dims

    def padded(self, target_length):
        """Return the window grown to `target_length`, zero-padded symmetrically."""
        delta = target_length - self.length
        full = self.tensor.shape[-1]
        assert delta >= 0

        # Extend the window by delta, split evenly on both sides.
        begin = self.offset - delta // 2
        stop = begin + target_length

        clipped_begin = max(0, begin)
        clipped_stop = min(full, stop)

        # Whatever falls outside the base tensor becomes explicit zero padding.
        out = F.pad(self.tensor[..., clipped_begin:clipped_stop], (clipped_begin - begin, stop - clipped_stop))
        assert out.shape[-1] == target_length
        return out


def tensor_chunk(tensor_or_chunk):
    """Wrap a raw tensor in a TensorChunk; pass existing chunks through unchanged."""
    if isinstance(tensor_or_chunk, TensorChunk):
        return tensor_or_chunk
    assert isinstance(tensor_or_chunk, th.Tensor)
    return TensorChunk(tensor_or_chunk)
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def apply_model(model, mix, shifts=1, split=True, overlap=0.25, transition_power=1.0, static_shifts=1, set_progress_bar=None, device=None, progress=False, num_workers=0, pool=None):
    """
    Apply model to a given mixture.

    Args:
        shifts (int): if > 0, will shift in time `mix` by a random amount between 0 and 0.5 sec
            and apply the oppositve shift to the output. This is repeated `shifts` time and
            all predictions are averaged. This effectively makes the model time equivariant
            and improves SDR by up to 0.2 points.
        split (bool): if True, the input will be broken down in 8 seconds extracts
            and predictions will be performed individually on each and concatenated.
            Useful for model with large memory footprint like Tasnet.
        progress (bool): if True, show a progress bar (requires split=True)
        device (torch.device, str, or None): if provided, device on which to
            execute the computation, otherwise `mix.device` is assumed.
            When `device` is different from `mix.device`, only local computations will
            be on `device`, while the entire tracks will be stored on `mix.device`.
    """

    # Module-level state shared across the recursive calls below, used only to
    # drive the optional `set_progress_bar` UI callback.
    global fut_length
    global bag_num
    global prog_bar

    if device is None:
        device = mix.device
    else:
        device = th.device(device)
    if pool is None:
        # A real thread pool only helps on CPU; otherwise run chunks serially.
        if num_workers > 0 and device.type == "cpu":
            pool = ThreadPoolExecutor(num_workers)
        else:
            pool = DummyPoolExecutor()

    # Common kwargs forwarded to the recursive apply_model calls.
    kwargs = {
        "shifts": shifts,
        "split": split,
        "overlap": overlap,
        "transition_power": transition_power,
        "progress": progress,
        "device": device,
        "pool": pool,
        "set_progress_bar": set_progress_bar,
        "static_shifts": static_shifts,
    }

    if isinstance(model, BagOfModels):
        # Special treatment for bag of model.
        # We explicitely apply multiple times `apply_model` so that the random shifts
        # are different for each model.

        estimates = 0
        totals = [0] * len(model.sources)
        bag_num = len(model.models)
        fut_length = 0
        prog_bar = 0
        current_model = 0  # (bag_num + 1)
        for sub_model, weight in zip(model.models, model.weights):
            # Move each sub-model to the compute device only while it runs.
            original_model_device = next(iter(sub_model.parameters())).device
            sub_model.to(device)
            # NOTE(review): `fut_length += fut_length` is a no-op while fut_length
            # is 0; it is recomputed in the split branch below — likely leftover.
            fut_length += fut_length
            current_model += 1
            out = apply_model(sub_model, mix, **kwargs)
            sub_model.to(original_model_device)
            # Scale each source by this model's per-source weight and track totals.
            for k, inst_weight in enumerate(weight):
                out[:, k, :, :] *= inst_weight
                totals[k] += inst_weight
            estimates += out
            del out

        # Normalize the weighted sum per source by the accumulated weights.
        for k in range(estimates.shape[1]):
            estimates[:, k, :, :] /= totals[k]
        return estimates

    model.to(device)
    model.eval()
    assert transition_power >= 1, "transition_power < 1 leads to weird behavior."
    batch, channels, length = mix.shape

    if shifts:
        # Shift trick: average predictions over `shifts` random time offsets.
        kwargs["shifts"] = 0
        max_shift = int(0.5 * model.samplerate)
        mix = tensor_chunk(mix)
        padded_mix = mix.padded(length + 2 * max_shift)
        out = 0
        for _ in range(shifts):
            offset = random.randint(0, max_shift)
            shifted = TensorChunk(padded_mix, offset, length + max_shift - offset)
            shifted_out = apply_model(model, shifted, **kwargs)
            # Undo the shift before accumulating.
            out += shifted_out[..., max_shift - offset :]
        out /= shifts
        return out
    elif split:
        # Chunked inference with overlap-add and cross-fade weighting.
        kwargs["split"] = False
        out = th.zeros(batch, len(model.sources), channels, length, device=mix.device)
        sum_weight = th.zeros(length, device=mix.device)
        segment = int(model.samplerate * model.segment)
        stride = int((1 - overlap) * segment)
        offsets = range(0, length, stride)
        # NOTE(review): `scale` is unused — verify before removing.
        scale = float(format(stride / model.samplerate, ".2f"))
        # We start from a triangle shaped weight, with maximal weight in the middle
        # of the segment. Then we normalize and take to the power `transition_power`.
        # Large values of transition power will lead to sharper transitions.
        weight = th.cat([th.arange(1, segment // 2 + 1, device=device), th.arange(segment - segment // 2, 0, -1, device=device)])
        assert len(weight) == segment
        # If the overlap < 50%, this will translate to linear transition when
        # transition_power is 1.
        weight = (weight / weight.max()) ** transition_power
        futures = []
        for offset in offsets:
            chunk = TensorChunk(mix, offset, segment)
            future = pool.submit(apply_model, model, chunk, **kwargs)
            futures.append((future, offset))
            # NOTE(review): dead code — `offset` is the loop variable and is
            # overwritten on the next iteration.
            offset += segment
        if progress:
            futures = tqdm.tqdm(futures)
        for future, offset in futures:
            if set_progress_bar:
                # Progress is tracked globally across all models in a bag.
                fut_length = len(futures) * bag_num * static_shifts
                prog_bar += 1
                set_progress_bar(0.1, (0.8 / fut_length * prog_bar))
            chunk_out = future.result()
            chunk_length = chunk_out.shape[-1]
            # Overlap-add each weighted chunk, accumulating the weights used.
            out[..., offset : offset + segment] += (weight[:chunk_length] * chunk_out).to(mix.device)
            sum_weight[offset : offset + segment] += weight[:chunk_length].to(mix.device)
        assert sum_weight.min() > 0
        out /= sum_weight
        return out
    else:
        # Single pass over the (possibly padded) full mixture.
        if hasattr(model, "valid_length"):
            valid_length = model.valid_length(length)
        else:
            valid_length = length
        mix = tensor_chunk(mix)
        padded_mix = mix.padded(valid_length).to(device)
        with th.no_grad():
            out = model(padded_mix)
        return center_trim(out, length)
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
def demucs_segments(demucs_segment, demucs_model):
    """Apply a user-selected segment length to a Demucs model or bag of models.

    Args:
        demucs_segment (str): "Default" to keep each model's own segment length,
            or a string parseable as an int giving the segment length to use.
        demucs_model: a single Demucs-style model or a `BagOfModels`.

    Returns:
        The same `demucs_model` object, with `segment` updated in place on each
        sub-model (or on the model itself) when a valid integer was provided.
    """
    if demucs_segment == "Default":
        segment = None
    else:
        try:
            segment = int(demucs_segment)
        except ValueError:
            # Non-numeric input falls back to the models' default segment,
            # matching the original "Default" behavior.
            segment = None

    if segment is not None:
        if isinstance(demucs_model, BagOfModels):
            for sub in demucs_model.models:
                sub.segment = segment
        else:
            # BUG FIX: the original referenced an undefined `sub` here, and the
            # resulting NameError was swallowed by a bare `except:`, so the
            # segment was silently never applied to single models.
            demucs_model.segment = segment

    return demucs_model
|
audio_separator/separator/uvr_lib_v5/demucs/demucs.py
ADDED
|
@@ -0,0 +1,453 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import math
|
| 8 |
+
import typing as tp
|
| 9 |
+
|
| 10 |
+
import julius
|
| 11 |
+
import torch
|
| 12 |
+
from torch import nn
|
| 13 |
+
from torch.nn import functional as F
|
| 14 |
+
|
| 15 |
+
from .states import capture_init
|
| 16 |
+
from .utils import center_trim, unfold
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class BLSTM(nn.Module):
    """
    BiLSTM with same hidden units as input dim.
    If `max_steps` is not None, input will be splitting in overlapping
    chunks and the LSTM applied separately on each chunk.
    """

    def __init__(self, dim, layers=1, max_steps=None, skip=False):
        super().__init__()
        # max_steps must be divisible by 4 so the half-stride/quarter trimming
        # below lands on whole samples.
        assert max_steps is None or max_steps % 4 == 0
        self.max_steps = max_steps
        self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim)
        # Project the bidirectional output (2*dim) back to dim.
        self.linear = nn.Linear(2 * dim, dim)
        self.skip = skip

    def forward(self, x):
        # x: (batch, channels, time)
        B, C, T = x.shape
        y = x  # kept for the optional residual connection
        framed = False
        if self.max_steps is not None and T > self.max_steps:
            # Split long sequences into 50%-overlapping frames and run the LSTM
            # on each frame independently (batched together).
            width = self.max_steps
            stride = width // 2
            frames = unfold(x, width, stride)
            # assumes `unfold` returns (B, C, nframes, width) — the permute
            # below relies on frames at dim 2; TODO confirm against utils.unfold
            nframes = frames.shape[2]
            framed = True
            x = frames.permute(0, 2, 1, 3).reshape(-1, C, width)

        # LSTM expects (time, batch, features).
        x = x.permute(2, 0, 1)

        x = self.lstm(x)[0]
        x = self.linear(x)
        x = x.permute(1, 2, 0)
        if framed:
            # Reassemble frames, trimming a quarter-frame (`limit`) from each
            # overlapped edge so neighboring frames butt together cleanly.
            out = []
            frames = x.reshape(B, -1, C, width)
            limit = stride // 2
            for k in range(nframes):
                if k == 0:
                    out.append(frames[:, k, :, :-limit])
                elif k == nframes - 1:
                    out.append(frames[:, k, :, limit:])
                else:
                    out.append(frames[:, k, :, limit:-limit])
            out = torch.cat(out, -1)
            # Drop any padding introduced by the framing.
            out = out[..., :T]
            x = out
        if self.skip:
            # Residual connection around the whole LSTM block.
            x = x + y
        return x
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def rescale_conv(conv, reference):
    """Rescale initial weight scale. It is unclear why it helps but it certainly does.

    Divides weights (and bias, when present) by sqrt(weight_std / reference),
    pulling the layer's initial weight std toward `reference`.
    """
    weight_std = conv.weight.std().detach()
    factor = (weight_std / reference) ** 0.5
    conv.weight.data /= factor
    if conv.bias is not None:
        conv.bias.data /= factor


def rescale_module(module, reference):
    """Apply `rescale_conv` to every (transposed) 1d/2d convolution inside `module`."""
    conv_types = (nn.Conv1d, nn.ConvTranspose1d, nn.Conv2d, nn.ConvTranspose2d)
    for layer in module.modules():
        if isinstance(layer, conv_types):
            rescale_conv(layer, reference)
+
|
| 84 |
+
|
| 85 |
+
class LayerScale(nn.Module):
    """Layer scale from [Touvron et al 2021] (https://arxiv.org/pdf/2103.17239.pdf).
    This rescales diagonaly residual outputs close to 0 initially, then learnt.
    """

    def __init__(self, channels: int, init: float = 0):
        super().__init__()
        # One learnable scale per channel, initialized to `init`.
        initial = torch.zeros(channels, requires_grad=True)
        self.scale = nn.Parameter(initial)
        self.scale.data.fill_(init)

    def forward(self, x):
        # Broadcast the per-channel scale across the trailing (time) dimension.
        return x * self.scale[:, None]
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
class DConv(nn.Module):
    """
    New residual branches in each encoder layer.
    This alternates dilated convolutions, potentially with LSTMs and attention.
    Also before entering each residual branch, dimension is projected on a smaller subspace,
    e.g. of dim `channels // compress`.
    """

    def __init__(self, channels: int, compress: float = 4, depth: int = 2, init: float = 1e-4, norm=True, attn=False, heads=4, ndecay=4, lstm=False, gelu=True, kernel=3, dilate=True):
        """
        Args:
            channels: input/output channels for residual branch.
            compress: amount of channel compression inside the branch.
            depth: number of layers in the residual branch. Each layer has its own
                projection, and potentially LSTM and attention.
            init: initial scale for LayerNorm.
            norm: use GroupNorm.
            attn: use LocalAttention.
            heads: number of heads for the LocalAttention.
            ndecay: number of decay controls in the LocalAttention.
            lstm: use LSTM.
            gelu: Use GELU activation.
            kernel: kernel size for the (dilated) convolutions.
            dilate: if true, use dilation, increasing with the depth.
        """

        super().__init__()
        # Odd kernel so padding keeps the time length unchanged.
        assert kernel % 2 == 1
        self.channels = channels
        self.compress = compress
        self.depth = abs(depth)
        # NOTE(review): this overwrites the `dilate` parameter — dilation is
        # actually controlled by the sign of `depth`, not the `dilate` argument.
        dilate = depth > 0

        norm_fn: tp.Callable[[int], nn.Module]
        norm_fn = lambda d: nn.Identity()  # noqa
        if norm:
            norm_fn = lambda d: nn.GroupNorm(1, d)  # noqa

        # Bottleneck width of each residual branch.
        hidden = int(channels / compress)

        act: tp.Type[nn.Module]
        if gelu:
            act = nn.GELU
        else:
            act = nn.ReLU

        self.layers = nn.ModuleList([])
        for d in range(self.depth):
            # Dilation doubles per layer so the receptive field grows exponentially.
            dilation = 2**d if dilate else 1
            padding = dilation * (kernel // 2)
            # compress -> (optional LSTM/attention) -> expand to 2*channels -> GLU
            # halves back to channels -> LayerScale keeps the branch near-zero at init.
            mods = [
                nn.Conv1d(channels, hidden, kernel, dilation=dilation, padding=padding),
                norm_fn(hidden),
                act(),
                nn.Conv1d(hidden, 2 * channels, 1),
                norm_fn(2 * channels),
                nn.GLU(1),
                LayerScale(channels, init),
            ]
            if attn:
                mods.insert(3, LocalState(hidden, heads=heads, ndecay=ndecay))
            if lstm:
                mods.insert(3, BLSTM(hidden, layers=2, max_steps=200, skip=True))
            layer = nn.Sequential(*mods)
            self.layers.append(layer)

    def forward(self, x):
        # Residual: each branch output is added back onto its input.
        for layer in self.layers:
            x = x + layer(x)
        return x
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
class LocalState(nn.Module):
    """Local state allows to have attention based only on data (no positional embedding),
    but while setting a constraint on the time window (e.g. decaying penalty term).

    Also a failed experiments with trying to provide some frequency based attention.
    """

    def __init__(self, channels: int, heads: int = 4, nfreqs: int = 0, ndecay: int = 4):
        """
        Args:
            channels: input/output channel count; must divide evenly by `heads`.
            heads: number of attention heads.
            nfreqs: number of cosine frequency kernels added to the attention
                scores (0 disables the frequency branch).
            ndecay: number of distance-decay controls added to the attention
                scores (0 disables the decay branch).
        """
        super().__init__()
        assert channels % heads == 0, (channels, heads)
        self.heads = heads
        self.nfreqs = nfreqs
        self.ndecay = ndecay
        # 1x1 convolutions act as the per-timestep content/query/key projections.
        self.content = nn.Conv1d(channels, channels, 1)
        self.query = nn.Conv1d(channels, channels, 1)
        self.key = nn.Conv1d(channels, channels, 1)
        if nfreqs:
            self.query_freqs = nn.Conv1d(channels, heads * nfreqs, 1)
        if ndecay:
            self.query_decay = nn.Conv1d(channels, heads * ndecay, 1)
            # Initialize decay close to zero (there is a sigmoid), for maximum initial window.
            self.query_decay.weight.data *= 0.01
            assert self.query_decay.bias is not None  # stupid type checker
            self.query_decay.bias.data[:] = -2
        # Output projection; widened when the frequency branch appends its signal.
        self.proj = nn.Conv1d(channels + heads * nfreqs, channels, 1)

    def forward(self, x):
        # x is (batch, channels, time) — inferred from the Conv1d projections above.
        B, C, T = x.shape
        heads = self.heads
        indexes = torch.arange(T, device=x.device, dtype=x.dtype)
        # left index are keys, right index are queries
        delta = indexes[:, None] - indexes[None, :]

        queries = self.query(x).view(B, heads, -1, T)
        keys = self.key(x).view(B, heads, -1, T)
        # t are keys, s are queries
        dots = torch.einsum("bhct,bhcs->bhts", keys, queries)
        # Standard scaled dot-product normalization by sqrt(head dim).
        dots /= keys.shape[2] ** 0.5
        if self.nfreqs:
            # Cosine kernels over the key/query time offset, mixed in with
            # query-dependent weights.
            periods = torch.arange(1, self.nfreqs + 1, device=x.device, dtype=x.dtype)
            freq_kernel = torch.cos(2 * math.pi * delta / periods.view(-1, 1, 1))
            freq_q = self.query_freqs(x).view(B, heads, -1, T) / self.nfreqs**0.5
            dots += torch.einsum("fts,bhfs->bhts", freq_kernel, freq_q)
        if self.ndecay:
            # Penalty growing with |delta|; sigmoid keeps the learned decay in (0, 0.5).
            decays = torch.arange(1, self.ndecay + 1, device=x.device, dtype=x.dtype)
            decay_q = self.query_decay(x).view(B, heads, -1, T)
            decay_q = torch.sigmoid(decay_q) / 2
            decay_kernel = -decays.view(-1, 1, 1) * delta.abs() / self.ndecay**0.5
            dots += torch.einsum("fts,bhfs->bhts", decay_kernel, decay_q)

        # Kill self reference.
        dots.masked_fill_(torch.eye(T, device=dots.device, dtype=torch.bool), -100)
        # Softmax over the key axis (dim=2 is `t` in "bhts").
        weights = torch.softmax(dots, dim=2)

        content = self.content(x).view(B, heads, -1, T)
        result = torch.einsum("bhts,bhct->bhcs", weights, content)
        if self.nfreqs:
            # Append the frequency-kernel response so proj can read it back out.
            time_sig = torch.einsum("bhts,fts->bhfs", weights, freq_kernel)
            result = torch.cat([result, time_sig], 2)
        result = result.reshape(B, -1, T)
        # Residual connection around the projected attention output.
        return x + self.proj(result)
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
class Demucs(nn.Module):
    """Time-domain Demucs: a symmetric 1D conv encoder/decoder with skip
    connections, optional DConv residual branches, and optional BLSTM."""

    @capture_init
    def __init__(
        self,
        sources,
        # Channels
        audio_channels=2,
        channels=64,
        growth=2.0,
        # Main structure
        depth=6,
        rewrite=True,
        lstm_layers=0,
        # Convolutions
        kernel_size=8,
        stride=4,
        context=1,
        # Activations
        gelu=True,
        glu=True,
        # Normalization
        norm_starts=4,
        norm_groups=4,
        # DConv residual branch
        dconv_mode=1,
        dconv_depth=2,
        dconv_comp=4,
        dconv_attn=4,
        dconv_lstm=4,
        dconv_init=1e-4,
        # Pre/post processing
        normalize=True,
        resample=True,
        # Weight init
        rescale=0.1,
        # Metadata
        samplerate=44100,
        segment=4 * 10,
    ):
        """
        Args:
            sources (list[str]): list of source names
            audio_channels (int): stereo or mono
            channels (int): first convolution channels
            depth (int): number of encoder/decoder layers
            growth (float): multiply (resp divide) number of channels by that
                for each layer of the encoder (resp decoder)
            depth (int): number of layers in the encoder and in the decoder.
            rewrite (bool): add 1x1 convolution to each layer.
            lstm_layers (int): number of lstm layers, 0 = no lstm. Deactivated
                by default, as this is now replaced by the smaller and faster small LSTMs
                in the DConv branches.
            kernel_size (int): kernel size for convolutions
            stride (int): stride for convolutions
            context (int): kernel size of the convolution in the
                decoder before the transposed convolution. If > 1,
                will provide some context from neighboring time steps.
            gelu: use GELU activation function.
            glu (bool): use glu instead of ReLU for the 1x1 rewrite conv.
            norm_starts: layer at which group norm starts being used.
                decoder layers are numbered in reverse order.
            norm_groups: number of groups for group norm.
            dconv_mode: if 1: dconv in encoder only, 2: decoder only, 3: both.
            dconv_depth: depth of residual DConv branch.
            dconv_comp: compression of DConv branch.
            dconv_attn: adds attention layers in DConv branch starting at this layer.
            dconv_lstm: adds a LSTM layer in DConv branch starting at this layer.
            dconv_init: initial scale for the DConv branch LayerScale.
            normalize (bool): normalizes the input audio on the fly, and scales back
                the output by the same amount.
            resample (bool): upsample x2 the input and downsample /2 the output.
            rescale (int): rescale initial weights of convolutions
                to get their standard deviation closer to `rescale`.
            samplerate (int): stored as meta information for easing
                future evaluations of the model.
            segment (float): duration of the chunks of audio to ideally evaluate the model on.
                This is used by `demucs.apply.apply_model`.
        """

        super().__init__()
        self.audio_channels = audio_channels
        self.sources = sources
        self.kernel_size = kernel_size
        self.context = context
        self.stride = stride
        self.depth = depth
        self.resample = resample
        self.channels = channels
        self.normalize = normalize
        self.samplerate = samplerate
        self.segment = segment
        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()
        self.skip_scales = nn.ModuleList()

        if glu:
            activation = nn.GLU(dim=1)
            # GLU halves the channel count, so the rewrite convs emit twice as many.
            ch_scale = 2
        else:
            activation = nn.ReLU()
            ch_scale = 1
        if gelu:
            act2 = nn.GELU
        else:
            act2 = nn.ReLU

        in_channels = audio_channels
        padding = 0
        for index in range(depth):
            # Identity norm for the shallow layers, GroupNorm from `norm_starts` on.
            norm_fn = lambda d: nn.Identity()  # noqa
            if index >= norm_starts:
                norm_fn = lambda d: nn.GroupNorm(norm_groups, d)  # noqa

            encode = []
            encode += [nn.Conv1d(in_channels, channels, kernel_size, stride), norm_fn(channels), act2()]
            # Attention/LSTM are only enabled in the deeper DConv branches.
            attn = index >= dconv_attn
            lstm = index >= dconv_lstm
            if dconv_mode & 1:
                encode += [DConv(channels, depth=dconv_depth, init=dconv_init, compress=dconv_comp, attn=attn, lstm=lstm)]
            if rewrite:
                encode += [nn.Conv1d(channels, ch_scale * channels, 1), norm_fn(ch_scale * channels), activation]
            self.encoder.append(nn.Sequential(*encode))

            decode = []
            if index > 0:
                out_channels = in_channels
            else:
                # The outermost decoder layer emits one audio_channels-wide
                # signal per source.
                out_channels = len(self.sources) * audio_channels
            if rewrite:
                decode += [nn.Conv1d(channels, ch_scale * channels, 2 * context + 1, padding=context), norm_fn(ch_scale * channels), activation]
            if dconv_mode & 2:
                decode += [DConv(channels, depth=dconv_depth, init=dconv_init, compress=dconv_comp, attn=attn, lstm=lstm)]
            decode += [nn.ConvTranspose1d(channels, out_channels, kernel_size, stride, padding=padding)]
            if index > 0:
                decode += [norm_fn(out_channels), act2()]
            # Insert at the front so that iterating self.decoder in forward()
            # runs from the deepest layer back out to the audio domain.
            self.decoder.insert(0, nn.Sequential(*decode))
            in_channels = channels
            channels = int(growth * channels)

        # After the loop `channels` has overshot by one growth step; the
        # bottleneck width is the last encoder output width.
        channels = in_channels
        if lstm_layers:
            self.lstm = BLSTM(channels, lstm_layers)
        else:
            self.lstm = None

        if rescale:
            rescale_module(self, reference=rescale)

    def valid_length(self, length):
        """
        Return the nearest valid length to use with the model so that
        there is no time steps left over in a convolution, e.g. for all
        layers, size of the input - kernel_size % stride = 0.

        Note that input are automatically padded if necessary to ensure that the output
        has the same length as the input.
        """
        if self.resample:
            length *= 2

        # Walk the encoder forward to get the bottleneck length...
        for _ in range(self.depth):
            length = math.ceil((length - self.kernel_size) / self.stride) + 1
            length = max(1, length)

        # ...then invert the transposed convolutions back out.
        for idx in range(self.depth):
            length = (length - 1) * self.stride + self.kernel_size

        if self.resample:
            length = math.ceil(length / 2)
        return int(length)

    def forward(self, mix):
        x = mix
        length = x.shape[-1]

        if self.normalize:
            # Normalize with mono statistics so both channels get the same scaling.
            mono = mix.mean(dim=1, keepdim=True)
            mean = mono.mean(dim=-1, keepdim=True)
            std = mono.std(dim=-1, keepdim=True)
            x = (x - mean) / (1e-5 + std)
        else:
            mean = 0
            std = 1

        # Pad to the nearest valid length so every conv stride lines up.
        delta = self.valid_length(length) - length
        x = F.pad(x, (delta // 2, delta - delta // 2))

        if self.resample:
            x = julius.resample_frac(x, 1, 2)

        saved = []
        for encode in self.encoder:
            x = encode(x)
            saved.append(x)

        if self.lstm:
            x = self.lstm(x)

        for decode in self.decoder:
            # Pop skips deepest-first to mirror the encoder order.
            skip = saved.pop(-1)
            skip = center_trim(skip, x)
            x = decode(x + skip)

        if self.resample:
            x = julius.resample_frac(x, 2, 1)
        # Undo the input normalization and padding.
        x = x * std + mean
        x = center_trim(x, length)
        # Split the flat output channels into (batch, source, channel, time).
        x = x.view(x.size(0), len(self.sources), self.audio_channels, x.size(-1))
        return x

    def load_state_dict(self, state, strict=True):
        # fix a mismatch with previous generation Demucs models.
        # Older checkpoints stored these parameters at sub-module index 2;
        # remap them to index 3 when the new key is absent.
        for idx in range(self.depth):
            for a in ["encoder", "decoder"]:
                for b in ["bias", "weight"]:
                    new = f"{a}.{idx}.3.{b}"
                    old = f"{a}.{idx}.2.{b}"
                    if old in state and new not in state:
                        state[new] = state.pop(old)
        super().load_state_dict(state, strict=strict)
|
audio_separator/separator/uvr_lib_v5/demucs/filtering.py
ADDED
|
@@ -0,0 +1,451 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
from torch import Tensor
|
| 5 |
+
from torch.utils.data import DataLoader
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def atan2(y, x):
    r"""Element-wise arctangent function of y/x.
    Returns a new tensor with signed angles in radians.
    It is an alternative implementation of torch.atan2

    Args:
        y (Tensor): First input tensor
        x (Tensor): Second input tensor [shape=y.shape]

    Returns:
        Tensor: [shape=y.shape].

    Note:
        Neither input is modified; the previous revision mutated ``x``
        in place, which clobbered the caller's tensor.
    """
    pi = 2 * torch.asin(torch.tensor(1.0))
    # Avoid a 0/0 division at the origin. Use an out-of-place add so the
    # caller's tensor is left untouched (the old `x += ...` mutated it).
    x = x + ((x == 0) & (y == 0)) * 1.0
    out = torch.atan(y / x)
    # torch.atan only returns values in (-pi/2, pi/2); fold in the quadrant
    # corrections for x < 0.
    out += ((y >= 0) & (x < 0)) * pi
    out -= ((y < 0) & (x < 0)) * pi
    # On the y axis the division produced +/-inf; zero those entries and
    # overwrite them with +/-pi/2.
    out *= 1 - ((y > 0) & (x == 0)) * 1.0
    out += ((y > 0) & (x == 0)) * (pi / 2)
    out *= 1 - ((y < 0) & (x == 0)) * 1.0
    out += ((y < 0) & (x == 0)) * (-pi / 2)
    return out
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
# Define basic complex operations on torch.Tensor objects whose last dimension
|
| 33 |
+
# consists in the concatenation of the real and imaginary parts.
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def _norm(x: torch.Tensor) -> torch.Tensor:
|
| 37 |
+
r"""Computes the norm value of a torch Tensor, assuming that it
|
| 38 |
+
comes as real and imaginary part in its last dimension.
|
| 39 |
+
|
| 40 |
+
Args:
|
| 41 |
+
x (Tensor): Input Tensor of shape [shape=(..., 2)]
|
| 42 |
+
|
| 43 |
+
Returns:
|
| 44 |
+
Tensor: shape as x excluding the last dimension.
|
| 45 |
+
"""
|
| 46 |
+
return torch.abs(x[..., 0]) ** 2 + torch.abs(x[..., 1]) ** 2
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def _mul_add(a: torch.Tensor, b: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor:
|
| 50 |
+
"""Element-wise multiplication of two complex Tensors described
|
| 51 |
+
through their real and imaginary parts.
|
| 52 |
+
The result is added to the `out` tensor"""
|
| 53 |
+
|
| 54 |
+
# check `out` and allocate it if needed
|
| 55 |
+
target_shape = torch.Size([max(sa, sb) for (sa, sb) in zip(a.shape, b.shape)])
|
| 56 |
+
if out is None or out.shape != target_shape:
|
| 57 |
+
out = torch.zeros(target_shape, dtype=a.dtype, device=a.device)
|
| 58 |
+
if out is a:
|
| 59 |
+
real_a = a[..., 0]
|
| 60 |
+
out[..., 0] = out[..., 0] + (real_a * b[..., 0] - a[..., 1] * b[..., 1])
|
| 61 |
+
out[..., 1] = out[..., 1] + (real_a * b[..., 1] + a[..., 1] * b[..., 0])
|
| 62 |
+
else:
|
| 63 |
+
out[..., 0] = out[..., 0] + (a[..., 0] * b[..., 0] - a[..., 1] * b[..., 1])
|
| 64 |
+
out[..., 1] = out[..., 1] + (a[..., 0] * b[..., 1] + a[..., 1] * b[..., 0])
|
| 65 |
+
return out
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def _mul(a: torch.Tensor, b: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor:
|
| 69 |
+
"""Element-wise multiplication of two complex Tensors described
|
| 70 |
+
through their real and imaginary parts
|
| 71 |
+
can work in place in case out is a only"""
|
| 72 |
+
target_shape = torch.Size([max(sa, sb) for (sa, sb) in zip(a.shape, b.shape)])
|
| 73 |
+
if out is None or out.shape != target_shape:
|
| 74 |
+
out = torch.zeros(target_shape, dtype=a.dtype, device=a.device)
|
| 75 |
+
if out is a:
|
| 76 |
+
real_a = a[..., 0]
|
| 77 |
+
out[..., 0] = real_a * b[..., 0] - a[..., 1] * b[..., 1]
|
| 78 |
+
out[..., 1] = real_a * b[..., 1] + a[..., 1] * b[..., 0]
|
| 79 |
+
else:
|
| 80 |
+
out[..., 0] = a[..., 0] * b[..., 0] - a[..., 1] * b[..., 1]
|
| 81 |
+
out[..., 1] = a[..., 0] * b[..., 1] + a[..., 1] * b[..., 0]
|
| 82 |
+
return out
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def _inv(z: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Element-wise reciprocal of complex entries stored as (real, imag)
    in the last dimension.

    Works in place when ``out is z``.
    """
    # 1/z = conj(z) / |z|^2. The squared magnitude is computed up front so
    # the in-place case never reads a value we already overwrote.
    squared_mag = torch.abs(z[..., 0]) ** 2 + torch.abs(z[..., 1]) ** 2
    if out is None or out.shape != z.shape:
        out = torch.zeros_like(z)
    out[..., 0] = z[..., 0] / squared_mag
    out[..., 1] = -z[..., 1] / squared_mag
    return out
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def _conj(z, out: Optional[torch.Tensor] = None) -> torch.Tensor:
|
| 98 |
+
"""Element-wise complex conjugate of a Tensor with complex entries
|
| 99 |
+
described through their real and imaginary parts.
|
| 100 |
+
can work in place in case out is z"""
|
| 101 |
+
if out is None or out.shape != z.shape:
|
| 102 |
+
out = torch.zeros_like(z)
|
| 103 |
+
out[..., 0] = z[..., 0]
|
| 104 |
+
out[..., 1] = -z[..., 1]
|
| 105 |
+
return out
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def _invert(M: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor:
    """
    Invert 1x1 or 2x2 complex matrices (real/imag parts in the last dim).

    Singular matrices produce inf/nan entries: regularization is the
    caller's responsibility.

    Args:
        M (Tensor): [shape=(..., nb_channels, nb_channels, 2)]
            matrices to invert: must be square along dimensions -3 and -2
        out (Tensor, optional): pre-allocated output of the same shape.

    Returns:
        invM (Tensor): [shape=M.shape], the inverses of M.
    """
    nb_channels = M.shape[-2]

    if out is None or out.shape != M.shape:
        out = torch.empty_like(M)

    if nb_channels == 1:
        # Scalar case: a 1x1 inverse is just the complex reciprocal.
        return _inv(M, out)
    if nb_channels == 2:
        # Two-channel case: closed-form inverse via the determinant.
        det = _mul(M[..., 0, 0, :], M[..., 1, 1, :])
        det = det - _mul(M[..., 0, 1, :], M[..., 1, 0, :])
        inv_det = _inv(det)

        # adj(M) / det, written entry by entry.
        out[..., 0, 0, :] = _mul(inv_det, M[..., 1, 1, :], out[..., 0, 0, :])
        out[..., 1, 0, :] = _mul(-inv_det, M[..., 1, 0, :], out[..., 1, 0, :])
        out[..., 0, 1, :] = _mul(-inv_det, M[..., 0, 1, :], out[..., 0, 1, :])
        out[..., 1, 1, :] = _mul(inv_det, M[..., 0, 0, :], out[..., 1, 1, :])
        return out
    raise Exception("Only 2 channels are supported for the torch version.")
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
# Now define the signal-processing low-level functions used by the Separator
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def expectation_maximization(y: torch.Tensor, x: torch.Tensor, iterations: int = 2, eps: float = 1e-10, batch_size: int = 200):
    r"""Expectation maximization algorithm, for refining source separation
    estimates.

    This algorithm allows to make source separation results better by
    enforcing multichannel consistency for the estimates. This usually means
    a better perceptual quality in terms of spatial artifacts.

    The implementation follows the details presented in [1]_, taking
    inspiration from the original EM algorithm proposed in [2]_ and its
    weighted refinement proposed in [3]_, [4]_.
    It works by iteratively:

     * Re-estimate source parameters (power spectral densities and spatial
       covariance matrices) through :func:`get_local_gaussian_model`.

     * Separate again the mixture with the new parameters by first computing
       the new modelled mixture covariance matrices with :func:`get_mix_model`,
       prepare the Wiener filters through :func:`wiener_gain` and apply them
       with :func:`apply_filter``.

    References
    ----------
    .. [1] S. Uhlich and M. Porcu and F. Giron and M. Enenkl and T. Kemp and
        N. Takahashi and Y. Mitsufuji, "Improving music source separation based
        on deep neural networks through data augmentation and network
        blending." 2017 IEEE International Conference on Acoustics, Speech
        and Signal Processing (ICASSP). IEEE, 2017.

    .. [2] N.Q. Duong and E. Vincent and R.Gribonval. "Under-determined
        reverberant audio source separation using a full-rank spatial
        covariance model." IEEE Transactions on Audio, Speech, and Language
        Processing 18.7 (2010): 1830-1840.

    .. [3] A. Nugraha and A. Liutkus and E. Vincent. "Multichannel audio source
        separation with deep neural networks." IEEE/ACM Transactions on Audio,
        Speech, and Language Processing 24.9 (2016): 1652-1664.

    .. [4] A. Nugraha and A. Liutkus and E. Vincent. "Multichannel music
        separation with deep neural networks." 2016 24th European Signal
        Processing Conference (EUSIPCO). IEEE, 2016.

    .. [5] A. Liutkus and R. Badeau and G. Richard "Kernel additive models for
        source separation." IEEE Transactions on Signal Processing
        62.16 (2014): 4298-4310.

    Args:
        y (Tensor): [shape=(nb_frames, nb_bins, nb_channels, 2, nb_sources)]
            initial estimates for the sources
        x (Tensor): [shape=(nb_frames, nb_bins, nb_channels, 2)]
            complex STFT of the mixture signal
        iterations (int): [scalar]
            number of iterations for the EM algorithm.
        eps (float or None): [scalar]
            The epsilon value to use for regularization and filters.

    Returns:
        y (Tensor): [shape=(nb_frames, nb_bins, nb_channels, 2, nb_sources)]
            estimated sources after iterations
        v (Tensor): [shape=(nb_frames, nb_bins, nb_sources)]
            estimated power spectral densities
        R (Tensor): [shape=(nb_bins, nb_channels, nb_channels, 2, nb_sources)]
            estimated spatial covariance matrices

    Notes:
        * You need an initial estimate for the sources to apply this
          algorithm. This is precisely what the :func:`wiener` function does.
        * This algorithm *is not* an implementation of the "exact" EM
          proposed in [1]_. In particular, it does compute the posterior
          covariance matrices the same (exact) way. Instead, it uses the
          simplified approximate scheme initially proposed in [5]_ and further
          refined in [3]_, [4]_, that boils down to just take the empirical
          covariance of the recent source estimates, followed by a weighted
          average for the update of the spatial covariance matrix. It has been
          empirically demonstrated that this simplified algorithm is more
          robust for music separation.

    Warning:
        It is *very* important to make sure `x.dtype` is `torch.float64`
        if you want double precision, because this function will **not**
        do such conversion for you from `torch.complex32`, in case you want the
        smaller RAM usage on purpose.

        It is usually always better in terms of quality to have double
        precision, by e.g. calling :func:`expectation_maximization`
        with ``x.to(torch.float64)``.
    """
    # dimensions
    (nb_frames, nb_bins, nb_channels) = x.shape[:-1]
    nb_sources = y.shape[-1]

    # sqrt(eps) * identity, expanded over bins, used to keep Cxx invertible.
    regularization = torch.cat((torch.eye(nb_channels, dtype=x.dtype, device=x.device)[..., None], torch.zeros((nb_channels, nb_channels, 1), dtype=x.dtype, device=x.device)), dim=2)
    regularization = torch.sqrt(torch.as_tensor(eps)) * (regularization[None, None, ...].expand((-1, nb_bins, -1, -1, -1)))

    # allocate the spatial covariance matrices
    R = [torch.zeros((nb_bins, nb_channels, nb_channels, 2), dtype=x.dtype, device=x.device) for j in range(nb_sources)]
    weight: torch.Tensor = torch.zeros((nb_bins,), dtype=x.dtype, device=x.device)

    v: torch.Tensor = torch.zeros((nb_frames, nb_bins, nb_sources), dtype=x.dtype, device=x.device)
    for it in range(iterations):
        # constructing the mixture covariance matrix. Doing it with a loop
        # to avoid storing anytime in RAM the whole 6D tensor

        # update the PSD as the average spectrogram over channels
        v = torch.mean(torch.abs(y[..., 0, :]) ** 2 + torch.abs(y[..., 1, :]) ** 2, dim=-2)

        # update spatial covariance matrices (weighted update)
        for j in range(nb_sources):
            # Reset accumulators; the scalar zeros broadcast against the
            # batched sums accumulated below.
            R[j] = torch.tensor(0.0, device=x.device)
            weight = torch.tensor(eps, device=x.device)
            pos: int = 0
            batch_size = batch_size if batch_size else nb_frames
            while pos < nb_frames:
                # Process frames in batches of `batch_size` to bound memory.
                t = torch.arange(pos, min(nb_frames, pos + batch_size))
                pos = int(t[-1]) + 1

                # NOTE(review): _covariance is defined elsewhere in this
                # module (not visible here); presumably the per-frame
                # empirical covariance of the source estimate — confirm.
                R[j] = R[j] + torch.sum(_covariance(y[t, ..., j]), dim=0)
                weight = weight + torch.sum(v[t, ..., j], dim=0)
            R[j] = R[j] / weight[..., None, None, None]
            weight = torch.zeros_like(weight)

        # cloning y if we track gradient, because we're going to update it
        if y.requires_grad:
            y = y.clone()

        pos = 0
        while pos < nb_frames:
            t = torch.arange(pos, min(nb_frames, pos + batch_size))
            pos = int(t[-1]) + 1

            # Zero out this batch of estimates before re-accumulating them.
            y[t, ...] = torch.tensor(0.0, device=x.device, dtype=x.dtype)

            # compute mix covariance matrix
            Cxx = regularization
            for j in range(nb_sources):
                Cxx = Cxx + (v[t, ..., j, None, None, None] * R[j][None, ...].clone())

            # invert it
            inv_Cxx = _invert(Cxx)

            # separate the sources
            for j in range(nb_sources):

                # create a wiener gain for this source
                gain = torch.zeros_like(inv_Cxx)

                # computes multichannel Wiener gain as v_j R_j inv_Cxx
                indices = torch.cartesian_prod(torch.arange(nb_channels), torch.arange(nb_channels), torch.arange(nb_channels))
                for index in indices:
                    gain[:, :, index[0], index[1], :] = _mul_add(R[j][None, :, index[0], index[2], :].clone(), inv_Cxx[:, :, index[2], index[1], :], gain[:, :, index[0], index[1], :])
                gain = gain * v[t, ..., None, None, None, j]

                # apply it to the mixture
                for i in range(nb_channels):
                    y[t, ..., j] = _mul_add(gain[..., i, :], x[t, ..., i, None, :], y[t, ..., j])

    return y, v, R
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
def wiener(targets_spectrograms: torch.Tensor, mix_stft: torch.Tensor, iterations: int = 1, softmask: bool = False, residual: bool = False, scale_factor: float = 10.0, eps: float = 1e-10):
    """Wiener-based separation for multichannel audio.

    The method uses the (possibly multichannel) spectrograms of the
    sources to separate the (complex) Short Term Fourier Transform of the
    mix. Separation is done in a sequential way by:

    * Getting an initial estimate. This can be done in two ways: either by
      directly using the spectrograms with the mixture phase, or
      by using a softmasking strategy. This initial phase is controlled
      by the `softmask` flag.

    * If required, adding an additional residual target as the mix minus
      all targets.

    * Refining these initial estimates through a call to
      :func:`expectation_maximization` if the number of iterations is nonzero.

    This implementation also allows to specify the epsilon value used for
    regularization. It is based on [1]_, [2]_, [3]_, [4]_.

    References
    ----------
    .. [1] S. Uhlich and M. Porcu and F. Giron and M. Enenkl and T. Kemp and
        N. Takahashi and Y. Mitsufuji, "Improving music source separation based
        on deep neural networks through data augmentation and network
        blending." 2017 IEEE International Conference on Acoustics, Speech
        and Signal Processing (ICASSP). IEEE, 2017.

    .. [2] A. Nugraha and A. Liutkus and E. Vincent. "Multichannel audio source
        separation with deep neural networks." IEEE/ACM Transactions on Audio,
        Speech, and Language Processing 24.9 (2016): 1652-1664.

    .. [3] A. Nugraha and A. Liutkus and E. Vincent. "Multichannel music
        separation with deep neural networks." 2016 24th European Signal
        Processing Conference (EUSIPCO). IEEE, 2016.

    .. [4] A. Liutkus and R. Badeau and G. Richard "Kernel additive models for
        source separation." IEEE Transactions on Signal Processing
        62.16 (2014): 4298-4310.

    Args:
        targets_spectrograms (Tensor): spectrograms of the sources
            [shape=(nb_frames, nb_bins, nb_channels, nb_sources)].
            This is a nonnegative tensor that is
            usually the output of the actual separation method of the user. The
            spectrograms may be mono, but they need to be 4-dimensional in all
            cases.
        mix_stft (Tensor): [shape=(nb_frames, nb_bins, nb_channels, complex=2)]
            STFT of the mixture signal.
        iterations (int): [scalar]
            number of iterations for the EM algorithm
        softmask (bool): Describes how the initial estimates are obtained.
            * if `False`, then the mixture phase will directly be used with the
              spectrogram as initial estimates.
            * if `True`, initial estimates are obtained by multiplying the
              complex mix element-wise with the ratio of each target spectrogram
              with the sum of them all. This strategy is better if the model are
              not really good, and worse otherwise.
        residual (bool): if `True`, an additional target is created, which is
            equal to the mixture minus the other targets, before application of
            expectation maximization
        scale_factor (float): the mixture (and estimates) are divided by
            `max(1, max_norm / scale_factor)` before EM for numerical
            stability, and rescaled afterwards.
        eps (float): Epsilon value to use for computing the separations.
            This is used whenever division with a model energy is
            performed, i.e. when softmasking and when iterating the EM.
            It can be understood as the energy of the additional white noise
            that is taken out when separating.

    Returns:
        Tensor: shape=(nb_frames, nb_bins, nb_channels, complex=2, nb_sources)
            STFT of estimated sources

    Notes:
        * Be careful that you need *magnitude spectrogram estimates* for the
          case `softmask==False`.
        * `softmask=False` is recommended
        * The epsilon value will have a huge impact on performance. If it's
          large, only the parts of the signal with a significant energy will
          be kept in the sources. This epsilon then directly controls the
          energy of the reconstruction error.

    Warning:
        As in :func:`expectation_maximization`, we recommend converting the
        mixture `x` to double precision `torch.float64` *before* calling
        :func:`wiener`.
    """
    if softmask:
        # if we use softmask, we compute the ratio mask for all targets and
        # multiply by the mix stft
        y = mix_stft[..., None] * (targets_spectrograms / (eps + torch.sum(targets_spectrograms, dim=-1, keepdim=True).to(mix_stft.dtype)))[..., None, :]
    else:
        # otherwise, we just multiply the targets spectrograms with mix phase
        # we tacitly assume that we have magnitude estimates.
        angle = atan2(mix_stft[..., 1], mix_stft[..., 0])[..., None]
        nb_sources = targets_spectrograms.shape[-1]
        y = torch.zeros(mix_stft.shape + (nb_sources,), dtype=mix_stft.dtype, device=mix_stft.device)
        # real part = magnitude * cos(phase), imaginary part = magnitude * sin(phase)
        y[..., 0, :] = targets_spectrograms * torch.cos(angle)
        y[..., 1, :] = targets_spectrograms * torch.sin(angle)

    if residual:
        # if required, adding an additional target as the mix minus
        # available targets
        y = torch.cat([y, mix_stft[..., None] - y.sum(dim=-1, keepdim=True)], dim=-1)

    if iterations == 0:
        return y

    # we need to refine the estimates. Scales down the estimates for
    # numerical stability
    max_abs = torch.max(torch.as_tensor(1.0, dtype=mix_stft.dtype, device=mix_stft.device), torch.sqrt(_norm(mix_stft)).max() / scale_factor)

    mix_stft = mix_stft / max_abs
    y = y / max_abs

    # call expectation maximization
    y = expectation_maximization(y, mix_stft, iterations, eps=eps)[0]

    # scale estimates up again
    y = y * max_abs
    return y
| 433 |
+
|
| 434 |
+
def _covariance(y_j):
    """
    Compute the empirical covariance for a source.

    Args:
        y_j (Tensor): complex stft of the source.
            [shape=(nb_frames, nb_bins, nb_channels, 2)].

    Returns:
        Cj (Tensor): [shape=(nb_frames, nb_bins, nb_channels, nb_channels, 2)]
            just y_j * conj(y_j.T): empirical covariance for each TF bin.
    """
    frames, bins_, channels = y_j.shape[:-1]
    cov = torch.zeros((frames, bins_, channels, channels, 2), dtype=y_j.dtype, device=y_j.device)
    # fill each (c1, c2) channel pair with y_j[c1] * conj(y_j[c2])
    for c1 in range(channels):
        for c2 in range(channels):
            cov[:, :, c1, c2, :] = _mul_add(y_j[:, :, c1, :], _conj(y_j[:, :, c2, :]), cov[:, :, c1, c2, :])
    return cov
|
audio_separator/separator/uvr_lib_v5/demucs/hdemucs.py
ADDED
|
@@ -0,0 +1,783 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
"""
|
| 7 |
+
This code contains the spectrogram and Hybrid version of Demucs.
|
| 8 |
+
"""
|
| 9 |
+
from copy import deepcopy
|
| 10 |
+
import math
|
| 11 |
+
import typing as tp
|
| 12 |
+
import torch
|
| 13 |
+
from torch import nn
|
| 14 |
+
from torch.nn import functional as F
|
| 15 |
+
from .filtering import wiener
|
| 16 |
+
from .demucs import DConv, rescale_module
|
| 17 |
+
from .states import capture_init
|
| 18 |
+
from .spec import spectro, ispectro
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = "constant", value: float = 0.0):
    """Thin wrapper around ``F.pad`` that also supports reflect padding on
    inputs shorter than the pad amount: extra zero padding is first inserted
    on the right so that the reflection becomes valid."""
    original = x
    n = x.shape[-1]
    left, right = paddings
    if mode == "reflect" and n <= max(left, right):
        # reflect padding requires pad < input length; zero-pad first.
        needed = max(left, right) - n + 1
        zeros_right = min(right, needed)
        zeros_left = needed - zeros_right
        paddings = (left - zeros_left, right - zeros_right)
        x = F.pad(x, (zeros_left, zeros_right))
    out = F.pad(x, paddings, mode, value)
    assert out.shape[-1] == n + left + right
    # the original samples must sit untouched between the two pads
    assert (out[..., left : left + n] == original).all()
    return out
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class ScaledEmbedding(nn.Module):
    """
    Embedding whose stored weights are divided by `scale` and multiplied back
    on access, which effectively boosts the embedding learning rate by `scale`.
    With `smooth`, the initial weights are replaced by a normalized running sum
    so that neighbouring embeddings start out close to one another.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, scale: float = 10.0, smooth=False):
        super().__init__()
        self.scale = scale
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        data = self.embedding.weight.data
        if smooth:
            smoothed = torch.cumsum(data, dim=0)
            # summing n gaussians grows the std as sqrt(n); normalize that out.
            norm = torch.arange(1, num_embeddings + 1).to(smoothed).sqrt()[:, None]
            data[:] = smoothed / norm
        data /= scale

    @property
    def weight(self):
        # effective (rescaled) weights seen by the rest of the model
        return self.embedding.weight * self.scale

    def forward(self, x):
        return self.embedding(x) * self.scale
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class HEncLayer(nn.Module):
    """Encoder layer, shared by the time and the frequency branches.

    When ``freq`` is True the convolution is 2D with a kernel/stride acting on
    the frequency axis only; otherwise it is a plain 1D time convolution.
    """

    def __init__(self, chin, chout, kernel_size=8, stride=4, norm_groups=1, empty=False, freq=True, dconv=True, norm=True, context=0, dconv_kw={}, pad=True, rewrite=True):
        """
        Args:
            chin: number of input channels.
            chout: number of output channels.
            norm_groups: number of groups for group norm.
            empty: used to make a layer with just the first conv. this is used
                before merging the time and freq. branches.
            freq: this is acting on frequencies.
            dconv: insert DConv residual branches.
            norm: use GroupNorm.
            context: context size for the 1x1 conv.
            dconv_kw: list of kwargs for the DConv class.
            pad: pad the input. Padding is done so that the output size is
                always the input size / stride.
            rewrite: add 1x1 conv at the end of the layer.
        """
        super().__init__()
        if norm:
            norm_fn = lambda d: nn.GroupNorm(norm_groups, d)  # noqa
        else:
            norm_fn = lambda d: nn.Identity()  # noqa
        pad = kernel_size // 4 if pad else 0
        self.freq = freq
        self.kernel_size = kernel_size
        self.stride = stride
        self.empty = empty
        self.norm = norm
        self.pad = pad
        if freq:
            # act on the frequency axis only: kernel/stride/pad are 1 along time
            klass = nn.Conv2d
            self.conv = klass(chin, chout, [kernel_size, 1], [stride, 1], [pad, 0])
        else:
            klass = nn.Conv1d
            self.conv = klass(chin, chout, kernel_size, stride, pad)
        if self.empty:
            # "empty" layers stop here: just the conv, used right before merging
            # the time and frequency branches.
            return
        self.norm1 = norm_fn(chout)
        self.rewrite = None
        if rewrite:
            self.rewrite = klass(chout, 2 * chout, 1 + 2 * context, 1, context)
            self.norm2 = norm_fn(2 * chout)
        self.dconv = DConv(chout, **dconv_kw) if dconv else None

    def forward(self, x, inject=None):
        """
        `inject` is used to inject the result from the time branch into the
        frequency branch, when both have the same stride.
        """
        if not self.freq:
            # time branch: collapse a 4D input and pad the length up to a
            # multiple of the stride.
            if x.dim() == 4:
                batch, _, _, steps = x.shape
                x = x.view(batch, -1, steps)
            remainder = x.shape[-1] % self.stride
            if remainder != 0:
                x = F.pad(x, (0, self.stride - remainder))
        y = self.conv(x)
        if self.empty:
            return y
        if inject is not None:
            assert inject.shape[-1] == y.shape[-1], (inject.shape, y.shape)
            if inject.dim() == 3 and y.dim() == 4:
                inject = inject[:, :, None]
            y = y + inject
        y = F.gelu(self.norm1(y))
        if self.dconv:
            if self.freq:
                B, C, Fr, T = y.shape
                # fold frequencies into the batch so DConv sees 1D sequences
                y = y.permute(0, 2, 1, 3).reshape(-1, C, T)
            y = self.dconv(y)
            if self.freq:
                y = y.view(B, Fr, C, T).permute(0, 2, 1, 3)
        if self.rewrite:
            return F.glu(self.norm2(self.rewrite(y)), dim=1)
        return y
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
class MultiWrap(nn.Module):
    """
    Takes one layer and replicate it N times. each replica will act
    on a frequency band. All is done so that if the N replica have the same weights,
    then this is exactly equivalent to applying the original module on all frequencies.

    This is a bit over-engineered to avoid edge artifacts when splitting
    the frequency bands, but it is possible the naive implementation would work as well...
    """

    def __init__(self, layer, split_ratios):
        """
        Args:
            layer: module to clone, must be either HEncLayer or HDecLayer.
            split_ratios: list of float indicating which ratio to keep for each band.
        """
        super().__init__()
        self.split_ratios = split_ratios
        self.layers = nn.ModuleList()
        # True when wrapping an encoder layer (uses `conv`), False for a
        # decoder layer (uses `conv_tr`).
        self.conv = isinstance(layer, HEncLayer)
        # the wrapped layer must be an un-normalized, padded frequency layer
        assert not layer.norm
        assert layer.freq
        assert layer.pad
        if not self.conv:
            assert not layer.context_freq
        # one replica per band: len(split_ratios) boundaries -> N+1 bands
        for k in range(len(split_ratios) + 1):
            lay = deepcopy(layer)
            if self.conv:
                # padding is applied manually in forward() per band, so the
                # replica's own conv padding is disabled.
                lay.conv.padding = (0, 0)
            else:
                lay.pad = False
            # fresh init for each replica so the bands do not start identical
            for m in lay.modules():
                if hasattr(m, "reset_parameters"):
                    m.reset_parameters()
            self.layers.append(lay)

    def forward(self, x, skip=None, length=None):
        # x: (B, C, Fr, T); each replica consumes one frequency band.
        B, C, Fr, T = x.shape

        ratios = list(self.split_ratios) + [1]
        start = 0
        outs = []
        for ratio, layer in zip(ratios, self.layers):
            if self.conv:
                # Encoder path: slice a band whose width lands exactly on a
                # whole number of conv frames, pad the outer edges, and make
                # consecutive bands overlap by (kernel_size - stride).
                pad = layer.kernel_size // 4
                if ratio == 1:
                    limit = Fr
                    frames = -1
                else:
                    limit = int(round(Fr * ratio))
                    le = limit - start
                    if start == 0:
                        # the first band gets `pad` virtual bins on its left edge
                        le += pad
                    frames = round((le - layer.kernel_size) / layer.stride + 1)
                    # snap the band end so it covers exactly `frames` conv windows
                    limit = start + (frames - 1) * layer.stride + layer.kernel_size
                    if start == 0:
                        limit -= pad
                assert limit - start > 0, (limit, start)
                assert limit <= Fr, (limit, Fr)
                y = x[:, :, start:limit, :]
                if start == 0:
                    y = F.pad(y, (0, 0, pad, 0))
                if ratio == 1:
                    y = F.pad(y, (0, 0, 0, pad))
                outs.append(layer(y))
                # next band starts one stride before this band's end (overlap)
                start = limit - layer.kernel_size + layer.stride
            else:
                # Decoder path: run the replica with `last=True` so it returns
                # un-activated output, then stitch overlapping stripes together.
                if ratio == 1:
                    limit = Fr
                else:
                    limit = int(round(Fr * ratio))
                # temporarily force last=True; restored below
                last = layer.last
                layer.last = True

                y = x[:, :, start:limit]
                s = skip[:, :, start:limit]
                out, _ = layer(y, s, None)
                if outs:
                    # merge the overlap with the previous band; the transposed-conv
                    # bias would otherwise be counted twice, so subtract it once.
                    outs[-1][:, :, -layer.stride :] += out[:, :, : layer.stride] - layer.conv_tr.bias.view(1, -1, 1, 1)
                out = out[:, :, layer.stride :]
                if ratio == 1:
                    out = out[:, :, : -layer.stride // 2, :]
                if start == 0:
                    out = out[:, :, layer.stride // 2 :, :]
                outs.append(out)
                layer.last = last
                start = limit
        out = torch.cat(outs, dim=2)
        # `last` here is the value from the final loop iteration (decoder path
        # only; the conv branch never reads it thanks to the `not self.conv` guard).
        if not self.conv and not last:
            out = F.gelu(out)
        if self.conv:
            return out
        else:
            return out, None
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
class HDecLayer(nn.Module):
    """Decoder layer. Mirror image of `HEncLayer`: a 1x1 rewrite conv and an
    optional DConv branch followed by a transposed convolution. See
    `HEncLayer` for the meaning of the shared arguments."""

    def __init__(
        self, chin, chout, last=False, kernel_size=8, stride=4, norm_groups=1, empty=False, freq=True, dconv=True, norm=True, context=1, dconv_kw={}, pad=True, context_freq=True, rewrite=True
    ):
        """
        Same as HEncLayer but for decoder. See `HEncLayer` for documentation.
        """
        super().__init__()
        if norm:
            norm_fn = lambda d: nn.GroupNorm(norm_groups, d)  # noqa
        else:
            norm_fn = lambda d: nn.Identity()  # noqa
        pad = kernel_size // 4 if pad else 0
        self.pad = pad
        self.last = last
        self.freq = freq
        self.chin = chin
        self.empty = empty
        self.stride = stride
        self.kernel_size = kernel_size
        self.norm = norm
        self.context_freq = context_freq
        if freq:
            klass = nn.Conv2d
            klass_tr = nn.ConvTranspose2d
            self.conv_tr = klass_tr(chin, chout, [kernel_size, 1], [stride, 1])
        else:
            klass = nn.Conv1d
            klass_tr = nn.ConvTranspose1d
            self.conv_tr = klass_tr(chin, chout, kernel_size, stride)
        self.norm2 = norm_fn(chout)
        if self.empty:
            return
        self.rewrite = None
        if rewrite:
            if context_freq:
                self.rewrite = klass(chin, 2 * chin, 1 + 2 * context, 1, context)
            else:
                # context only along the time axis
                self.rewrite = klass(chin, 2 * chin, [1, 1 + 2 * context], 1, [0, context])
            self.norm1 = norm_fn(2 * chin)
        self.dconv = DConv(chin, **dconv_kw) if dconv else None

    def forward(self, x, skip, length):
        if self.freq and x.dim() == 3:
            # restore the (B, C, Fr, T) layout expected by the 2D convs
            batch, _, steps = x.shape
            x = x.view(batch, self.chin, -1, steps)

        if self.empty:
            y = x
            assert skip is None
        else:
            x = x + skip
            if self.rewrite:
                y = F.glu(self.norm1(self.rewrite(x)), dim=1)
            else:
                y = x
            if self.dconv:
                if self.freq:
                    B, C, Fr, T = y.shape
                    # fold frequencies into the batch so DConv sees 1D sequences
                    y = y.permute(0, 2, 1, 3).reshape(-1, C, T)
                y = self.dconv(y)
                if self.freq:
                    y = y.view(B, Fr, C, T).permute(0, 2, 1, 3)
        z = self.norm2(self.conv_tr(y))
        if self.freq:
            if self.pad:
                # drop the frequency bins produced by the encoder-side padding
                z = z[..., self.pad : -self.pad, :]
        else:
            z = z[..., self.pad : self.pad + length]
            assert z.shape[-1] == length, (z.shape[-1], length)
        if not self.last:
            z = F.gelu(z)
        return z, y
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
class HDemucs(nn.Module):
|
| 334 |
+
"""
|
| 335 |
+
Spectrogram and hybrid Demucs model.
|
| 336 |
+
The spectrogram model has the same structure as Demucs, except the first few layers are over the
|
| 337 |
+
frequency axis, until there is only 1 frequency, and then it moves to time convolutions.
|
| 338 |
+
Frequency layers can still access information across time steps thanks to the DConv residual.
|
| 339 |
+
|
| 340 |
+
Hybrid model have a parallel time branch. At some layer, the time branch has the same stride
|
| 341 |
+
as the frequency branch and then the two are combined. The opposite happens in the decoder.
|
| 342 |
+
|
| 343 |
+
Models can either use naive iSTFT from masking, Wiener filtering ([Ulhih et al. 2017]),
|
| 344 |
+
or complex as channels (CaC) [Choi et al. 2020]. Wiener filtering is based on
|
| 345 |
+
Open Unmix implementation [Stoter et al. 2019].
|
| 346 |
+
|
| 347 |
+
The loss is always on the temporal domain, by backpropagating through the above
|
| 348 |
+
output methods and iSTFT. This allows to define hybrid models nicely. However, this breaks
|
| 349 |
+
a bit Wiener filtering, as doing more iteration at test time will change the spectrogram
|
| 350 |
+
contribution, without changing the one from the waveform, which will lead to worse performance.
|
| 351 |
+
I tried using the residual option in OpenUnmix Wiener implementation, but it didn't improve.
|
| 352 |
+
CaC on the other hand provides similar performance for hybrid, and works naturally with
|
| 353 |
+
hybrid models.
|
| 354 |
+
|
| 355 |
+
This model also uses frequency embeddings are used to improve efficiency on convolutions
|
| 356 |
+
over the freq. axis, following [Isik et al. 2020] (https://arxiv.org/pdf/2008.04470.pdf).
|
| 357 |
+
|
| 358 |
+
Unlike classic Demucs, there is no resampling here, and normalization is always applied.
|
| 359 |
+
"""
|
| 360 |
+
|
| 361 |
+
@capture_init
|
| 362 |
+
def __init__(
|
| 363 |
+
self,
|
| 364 |
+
sources,
|
| 365 |
+
# Channels
|
| 366 |
+
audio_channels=2,
|
| 367 |
+
channels=48,
|
| 368 |
+
channels_time=None,
|
| 369 |
+
growth=2,
|
| 370 |
+
# STFT
|
| 371 |
+
nfft=4096,
|
| 372 |
+
wiener_iters=0,
|
| 373 |
+
end_iters=0,
|
| 374 |
+
wiener_residual=False,
|
| 375 |
+
cac=True,
|
| 376 |
+
# Main structure
|
| 377 |
+
depth=6,
|
| 378 |
+
rewrite=True,
|
| 379 |
+
hybrid=True,
|
| 380 |
+
hybrid_old=False,
|
| 381 |
+
# Frequency branch
|
| 382 |
+
multi_freqs=None,
|
| 383 |
+
multi_freqs_depth=2,
|
| 384 |
+
freq_emb=0.2,
|
| 385 |
+
emb_scale=10,
|
| 386 |
+
emb_smooth=True,
|
| 387 |
+
# Convolutions
|
| 388 |
+
kernel_size=8,
|
| 389 |
+
time_stride=2,
|
| 390 |
+
stride=4,
|
| 391 |
+
context=1,
|
| 392 |
+
context_enc=0,
|
| 393 |
+
# Normalization
|
| 394 |
+
norm_starts=4,
|
| 395 |
+
norm_groups=4,
|
| 396 |
+
# DConv residual branch
|
| 397 |
+
dconv_mode=1,
|
| 398 |
+
dconv_depth=2,
|
| 399 |
+
dconv_comp=4,
|
| 400 |
+
dconv_attn=4,
|
| 401 |
+
dconv_lstm=4,
|
| 402 |
+
dconv_init=1e-4,
|
| 403 |
+
# Weight init
|
| 404 |
+
rescale=0.1,
|
| 405 |
+
# Metadata
|
| 406 |
+
samplerate=44100,
|
| 407 |
+
segment=4 * 10,
|
| 408 |
+
):
|
| 409 |
+
"""
|
| 410 |
+
Args:
|
| 411 |
+
sources (list[str]): list of source names.
|
| 412 |
+
audio_channels (int): input/output audio channels.
|
| 413 |
+
channels (int): initial number of hidden channels.
|
| 414 |
+
channels_time: if not None, use a different `channels` value for the time branch.
|
| 415 |
+
growth: increase the number of hidden channels by this factor at each layer.
|
| 416 |
+
nfft: number of fft bins. Note that changing this require careful computation of
|
| 417 |
+
various shape parameters and will not work out of the box for hybrid models.
|
| 418 |
+
wiener_iters: when using Wiener filtering, number of iterations at test time.
|
| 419 |
+
end_iters: same but at train time. For a hybrid model, must be equal to `wiener_iters`.
|
| 420 |
+
wiener_residual: add residual source before wiener filtering.
|
| 421 |
+
cac: uses complex as channels, i.e. complex numbers are 2 channels each
|
| 422 |
+
in input and output. no further processing is done before ISTFT.
|
| 423 |
+
depth (int): number of layers in the encoder and in the decoder.
|
| 424 |
+
rewrite (bool): add 1x1 convolution to each layer.
|
| 425 |
+
hybrid (bool): make a hybrid time/frequency domain, otherwise frequency only.
|
| 426 |
+
hybrid_old: some models trained for MDX had a padding bug. This replicates
|
| 427 |
+
this bug to avoid retraining them.
|
| 428 |
+
multi_freqs: list of frequency ratios for splitting frequency bands with `MultiWrap`.
|
| 429 |
+
multi_freqs_depth: how many layers to wrap with `MultiWrap`. Only the outermost
|
| 430 |
+
layers will be wrapped.
|
| 431 |
+
freq_emb: add frequency embedding after the first frequency layer if > 0,
|
| 432 |
+
the actual value controls the weight of the embedding.
|
| 433 |
+
emb_scale: equivalent to scaling the embedding learning rate
|
| 434 |
+
emb_smooth: initialize the embedding with a smooth one (with respect to frequencies).
|
| 435 |
+
kernel_size: kernel_size for encoder and decoder layers.
|
| 436 |
+
stride: stride for encoder and decoder layers.
|
| 437 |
+
time_stride: stride for the final time layer, after the merge.
|
| 438 |
+
context: context for 1x1 conv in the decoder.
|
| 439 |
+
context_enc: context for 1x1 conv in the encoder.
|
| 440 |
+
norm_starts: layer at which group norm starts being used.
|
| 441 |
+
decoder layers are numbered in reverse order.
|
| 442 |
+
norm_groups: number of groups for group norm.
|
| 443 |
+
dconv_mode: if 1: dconv in encoder only, 2: decoder only, 3: both.
|
| 444 |
+
dconv_depth: depth of residual DConv branch.
|
| 445 |
+
dconv_comp: compression of DConv branch.
|
| 446 |
+
dconv_attn: adds attention layers in DConv branch starting at this layer.
|
| 447 |
+
dconv_lstm: adds a LSTM layer in DConv branch starting at this layer.
|
| 448 |
+
dconv_init: initial scale for the DConv branch LayerScale.
|
| 449 |
+
rescale: weight recaling trick
|
| 450 |
+
|
| 451 |
+
"""
|
| 452 |
+
super().__init__()
|
| 453 |
+
|
| 454 |
+
self.cac = cac
|
| 455 |
+
self.wiener_residual = wiener_residual
|
| 456 |
+
self.audio_channels = audio_channels
|
| 457 |
+
self.sources = sources
|
| 458 |
+
self.kernel_size = kernel_size
|
| 459 |
+
self.context = context
|
| 460 |
+
self.stride = stride
|
| 461 |
+
self.depth = depth
|
| 462 |
+
self.channels = channels
|
| 463 |
+
self.samplerate = samplerate
|
| 464 |
+
self.segment = segment
|
| 465 |
+
|
| 466 |
+
self.nfft = nfft
|
| 467 |
+
self.hop_length = nfft // 4
|
| 468 |
+
self.wiener_iters = wiener_iters
|
| 469 |
+
self.end_iters = end_iters
|
| 470 |
+
self.freq_emb = None
|
| 471 |
+
self.hybrid = hybrid
|
| 472 |
+
self.hybrid_old = hybrid_old
|
| 473 |
+
if hybrid_old:
|
| 474 |
+
assert hybrid, "hybrid_old must come with hybrid=True"
|
| 475 |
+
if hybrid:
|
| 476 |
+
assert wiener_iters == end_iters
|
| 477 |
+
|
| 478 |
+
self.encoder = nn.ModuleList()
|
| 479 |
+
self.decoder = nn.ModuleList()
|
| 480 |
+
|
| 481 |
+
if hybrid:
|
| 482 |
+
self.tencoder = nn.ModuleList()
|
| 483 |
+
self.tdecoder = nn.ModuleList()
|
| 484 |
+
|
| 485 |
+
chin = audio_channels
|
| 486 |
+
chin_z = chin # number of channels for the freq branch
|
| 487 |
+
if self.cac:
|
| 488 |
+
chin_z *= 2
|
| 489 |
+
chout = channels_time or channels
|
| 490 |
+
chout_z = channels
|
| 491 |
+
freqs = nfft // 2
|
| 492 |
+
|
| 493 |
+
for index in range(depth):
|
| 494 |
+
lstm = index >= dconv_lstm
|
| 495 |
+
attn = index >= dconv_attn
|
| 496 |
+
norm = index >= norm_starts
|
| 497 |
+
freq = freqs > 1
|
| 498 |
+
stri = stride
|
| 499 |
+
ker = kernel_size
|
| 500 |
+
if not freq:
|
| 501 |
+
assert freqs == 1
|
| 502 |
+
ker = time_stride * 2
|
| 503 |
+
stri = time_stride
|
| 504 |
+
|
| 505 |
+
pad = True
|
| 506 |
+
last_freq = False
|
| 507 |
+
if freq and freqs <= kernel_size:
|
| 508 |
+
ker = freqs
|
| 509 |
+
pad = False
|
| 510 |
+
last_freq = True
|
| 511 |
+
|
| 512 |
+
kw = {
|
| 513 |
+
"kernel_size": ker,
|
| 514 |
+
"stride": stri,
|
| 515 |
+
"freq": freq,
|
| 516 |
+
"pad": pad,
|
| 517 |
+
"norm": norm,
|
| 518 |
+
"rewrite": rewrite,
|
| 519 |
+
"norm_groups": norm_groups,
|
| 520 |
+
"dconv_kw": {"lstm": lstm, "attn": attn, "depth": dconv_depth, "compress": dconv_comp, "init": dconv_init, "gelu": True},
|
| 521 |
+
}
|
| 522 |
+
kwt = dict(kw)
|
| 523 |
+
kwt["freq"] = 0
|
| 524 |
+
kwt["kernel_size"] = kernel_size
|
| 525 |
+
kwt["stride"] = stride
|
| 526 |
+
kwt["pad"] = True
|
| 527 |
+
kw_dec = dict(kw)
|
| 528 |
+
multi = False
|
| 529 |
+
if multi_freqs and index < multi_freqs_depth:
|
| 530 |
+
multi = True
|
| 531 |
+
kw_dec["context_freq"] = False
|
| 532 |
+
|
| 533 |
+
if last_freq:
|
| 534 |
+
chout_z = max(chout, chout_z)
|
| 535 |
+
chout = chout_z
|
| 536 |
+
|
| 537 |
+
enc = HEncLayer(chin_z, chout_z, dconv=dconv_mode & 1, context=context_enc, **kw)
|
| 538 |
+
if hybrid and freq:
|
| 539 |
+
tenc = HEncLayer(chin, chout, dconv=dconv_mode & 1, context=context_enc, empty=last_freq, **kwt)
|
| 540 |
+
self.tencoder.append(tenc)
|
| 541 |
+
|
| 542 |
+
if multi:
|
| 543 |
+
enc = MultiWrap(enc, multi_freqs)
|
| 544 |
+
self.encoder.append(enc)
|
| 545 |
+
if index == 0:
|
| 546 |
+
chin = self.audio_channels * len(self.sources)
|
| 547 |
+
chin_z = chin
|
| 548 |
+
if self.cac:
|
| 549 |
+
chin_z *= 2
|
| 550 |
+
dec = HDecLayer(chout_z, chin_z, dconv=dconv_mode & 2, last=index == 0, context=context, **kw_dec)
|
| 551 |
+
if multi:
|
| 552 |
+
dec = MultiWrap(dec, multi_freqs)
|
| 553 |
+
if hybrid and freq:
|
| 554 |
+
tdec = HDecLayer(chout, chin, dconv=dconv_mode & 2, empty=last_freq, last=index == 0, context=context, **kwt)
|
| 555 |
+
self.tdecoder.insert(0, tdec)
|
| 556 |
+
self.decoder.insert(0, dec)
|
| 557 |
+
|
| 558 |
+
chin = chout
|
| 559 |
+
chin_z = chout_z
|
| 560 |
+
chout = int(growth * chout)
|
| 561 |
+
chout_z = int(growth * chout_z)
|
| 562 |
+
if freq:
|
| 563 |
+
if freqs <= kernel_size:
|
| 564 |
+
freqs = 1
|
| 565 |
+
else:
|
| 566 |
+
freqs //= stride
|
| 567 |
+
if index == 0 and freq_emb:
|
| 568 |
+
self.freq_emb = ScaledEmbedding(freqs, chin_z, smooth=emb_smooth, scale=emb_scale)
|
| 569 |
+
self.freq_emb_scale = freq_emb
|
| 570 |
+
|
| 571 |
+
if rescale:
|
| 572 |
+
rescale_module(self, reference=rescale)
|
| 573 |
+
|
| 574 |
+
def _spec(self, x):
    """Compute the STFT of ``x``, trimmed so its frames align with the time branch."""
    hop = self.hop_length
    nfft = self.nfft
    x0 = x  # noqa

    if self.hybrid:
        # We re-pad the signal in order to keep the property that the size of
        # the output is exactly the size of the input divided by the stride
        # (here hop_length), when divisible. This is achieved by padding by
        # 1/4th of the kernel size (here nfft), which is not supported by
        # torch.stft. Having all convolution operations follow this convention
        # allows to easily align the time and frequency branches later on.
        assert hop == nfft // 4
        n_frames = int(math.ceil(x.shape[-1] / hop))
        side = hop // 2 * 3
        # Older checkpoints ("hybrid_old") were trained with zero padding
        # instead of reflect padding.
        pad_kw = {} if self.hybrid_old else {"mode": "reflect"}
        x = pad1d(x, (side, side + n_frames * hop - x.shape[-1]), **pad_kw)

    # Drop the Nyquist bin so the frequency axis has a power-of-two size.
    z = spectro(x, nfft, hop)[..., :-1, :]
    if self.hybrid:
        # Remove the 2 extra frames introduced on each side by the padding.
        assert z.shape[-1] == n_frames + 4, (z.shape, x.shape, n_frames)
        z = z[..., 2 : 2 + n_frames]
    return z
|
| 600 |
+
|
| 601 |
+
def _ispec(self, z, length=None, scale=0):
    """Inverse STFT, undoing the padding and frame trimming applied in `_spec`."""
    hop = self.hop_length // (4**scale)
    # Restore the Nyquist bin dropped by `_spec`.
    z = F.pad(z, (0, 0, 0, 1))
    if not self.hybrid:
        return ispectro(z, hop, length)

    # Re-insert the two frames trimmed on each side by `_spec`.
    z = F.pad(z, (2, 2))
    side = hop // 2 * 3
    frames_len = hop * int(math.ceil(length / hop))
    if self.hybrid_old:
        # Older checkpoints did not account for the side padding.
        x = ispectro(z, hop, length=frames_len)
        return x[..., :length]
    x = ispectro(z, hop, length=frames_len + 2 * side)
    return x[..., side : side + length]
|
| 619 |
+
|
| 620 |
+
def _magnitude(self, z):
|
| 621 |
+
# return the magnitude of the spectrogram, except when cac is True,
|
| 622 |
+
# in which case we just move the complex dimension to the channel one.
|
| 623 |
+
if self.cac:
|
| 624 |
+
B, C, Fr, T = z.shape
|
| 625 |
+
m = torch.view_as_real(z).permute(0, 1, 4, 2, 3)
|
| 626 |
+
m = m.reshape(B, C * 2, Fr, T)
|
| 627 |
+
else:
|
| 628 |
+
m = z.abs()
|
| 629 |
+
return m
|
| 630 |
+
|
| 631 |
+
def _mask(self, z, m):
|
| 632 |
+
# Apply masking given the mixture spectrogram `z` and the estimated mask `m`.
|
| 633 |
+
# If `cac` is True, `m` is actually a full spectrogram and `z` is ignored.
|
| 634 |
+
niters = self.wiener_iters
|
| 635 |
+
if self.cac:
|
| 636 |
+
B, S, C, Fr, T = m.shape
|
| 637 |
+
out = m.view(B, S, -1, 2, Fr, T).permute(0, 1, 2, 4, 5, 3)
|
| 638 |
+
out = torch.view_as_complex(out.contiguous())
|
| 639 |
+
return out
|
| 640 |
+
if self.training:
|
| 641 |
+
niters = self.end_iters
|
| 642 |
+
if niters < 0:
|
| 643 |
+
z = z[:, None]
|
| 644 |
+
return z / (1e-8 + z.abs()) * m
|
| 645 |
+
else:
|
| 646 |
+
return self._wiener(m, z, niters)
|
| 647 |
+
|
| 648 |
+
def _wiener(self, mag_out, mix_stft, niters):
    """Apply Open-Unmix style Wiener filtering (EM iterations) to the
    estimated source magnitudes.

    Args:
        mag_out: estimated magnitudes, shape (B, S, C, Fq, T).
        mix_stft: complex mixture spectrogram, shape (B, C, Fq, T).
        niters: number of EM iterations to run.

    Returns:
        Complex source spectrograms of shape (B, S, C, Fq, T), cast back to
        the dtype of ``mix_stft``.
    """
    init = mix_stft.dtype
    # Process windows of at most 300 frames to bound memory usage.
    wiener_win_len = 300
    residual = self.wiener_residual

    B, S, C, Fq, T = mag_out.shape
    # `wiener` expects (T, Fq, C, S) / (T, Fq, C, 2) layouts.
    mag_out = mag_out.permute(0, 4, 3, 2, 1)
    mix_stft = torch.view_as_real(mix_stft.permute(0, 3, 2, 1))

    outs = []
    for sample in range(B):
        out = []
        # (removed a dead `pos = 0` assignment; `pos` is bound by the loop)
        for pos in range(0, T, wiener_win_len):
            frame = slice(pos, pos + wiener_win_len)
            z_out = wiener(mag_out[sample, frame], mix_stft[sample, frame], niters, residual=residual)
            out.append(z_out.transpose(-1, -2))
        outs.append(torch.cat(out, dim=0))
    out = torch.view_as_complex(torch.stack(outs, 0))
    out = out.permute(0, 4, 3, 2, 1).contiguous()
    if residual:
        # `wiener(..., residual=True)` appends an extra residual source; drop it.
        out = out[:, :-1]
    assert list(out.shape) == [B, S, C, Fq, T]
    return out.to(init)
|
| 673 |
+
|
| 674 |
+
def forward(self, mix):
    """Separate the mixture waveform `mix` (B, C, T) into stems (B, S, C, T).

    Runs the frequency (spectrogram) U-Net and, for hybrid models, a parallel
    time-domain U-Net whose output is summed with the iSTFT of the frequency
    branch at the end.
    """
    x = mix
    length = x.shape[-1]

    z = self._spec(mix)
    # `mag` is either the magnitude, or (CaC) real/imag stacked as channels.
    mag = self._magnitude(z).to(mix.device)
    x = mag

    B, C, Fq, T = x.shape

    # unlike previous Demucs, we always normalize because it is easier.
    mean = x.mean(dim=(1, 2, 3), keepdim=True)
    std = x.std(dim=(1, 2, 3), keepdim=True)
    x = (x - mean) / (1e-5 + std)
    # x will be the freq. branch input.

    if self.hybrid:
        # Prepare the time branch input (normalized independently).
        xt = mix
        meant = xt.mean(dim=(1, 2), keepdim=True)
        stdt = xt.std(dim=(1, 2), keepdim=True)
        xt = (xt - meant) / (1e-5 + stdt)

    # okay, this is a giant mess I know...
    saved = []  # skip connections, freq.
    saved_t = []  # skip connections, time.
    lengths = []  # saved lengths to properly remove padding, freq branch.
    lengths_t = []  # saved lengths for time branch.
    for idx, encode in enumerate(self.encoder):
        lengths.append(x.shape[-1])
        inject = None
        if self.hybrid and idx < len(self.tencoder):
            # we have not yet merged branches.
            lengths_t.append(xt.shape[-1])
            tenc = self.tencoder[idx]
            xt = tenc(xt)
            if not tenc.empty:
                # save for skip connection
                saved_t.append(xt)
            else:
                # tenc contains just the first conv., so that now time and freq.
                # branches have the same shape and can be merged.
                inject = xt
        x = encode(x, inject)
        if idx == 0 and self.freq_emb is not None:
            # add frequency embedding to allow for non equivariant convolutions
            # over the frequency axis.
            frs = torch.arange(x.shape[-2], device=x.device)
            emb = self.freq_emb(frs).t()[None, :, :, None].expand_as(x)
            x = x + self.freq_emb_scale * emb

        saved.append(x)

    x = torch.zeros_like(x)
    if self.hybrid:
        xt = torch.zeros_like(x)
    # initialize everything to zero (signal will go through u-net skips).

    for idx, decode in enumerate(self.decoder):
        skip = saved.pop(-1)
        x, pre = decode(x, skip, lengths.pop(-1))
        # `pre` contains the output just before final transposed convolution,
        # which is used when the freq. and time branch separate.

        if self.hybrid:
            offset = self.depth - len(self.tdecoder)
        if self.hybrid and idx >= offset:
            tdec = self.tdecoder[idx - offset]
            length_t = lengths_t.pop(-1)
            if tdec.empty:
                # Branch point: collapse the singleton frequency axis and feed
                # the pre-activation into the time decoder.
                assert pre.shape[2] == 1, pre.shape
                pre = pre[:, :, 0]
                xt, _ = tdec(pre, None, length_t)
            else:
                skip = saved_t.pop(-1)
                xt, _ = tdec(xt, skip, length_t)

    # Let's make sure we used all stored skip connections.
    assert len(saved) == 0
    assert len(lengths_t) == 0
    assert len(saved_t) == 0

    S = len(self.sources)
    x = x.view(B, S, -1, Fq, T)
    # Undo the input normalization.
    x = x * std[:, None] + mean[:, None]

    # to cpu as non-cuda GPUs don't support complex numbers
    # demucs issue #435 ##432
    # NOTE: in this case z already is on cpu
    # TODO: remove this when mps supports complex numbers

    device_type = x.device.type
    device_load = f"{device_type}:{x.device.index}" if not device_type == "mps" else device_type
    x_is_other_gpu = not device_type in ["cuda", "cpu"]

    if x_is_other_gpu:
        x = x.cpu()

    zout = self._mask(z, x)
    x = self._ispec(zout, length)

    # back to other device
    if x_is_other_gpu:
        x = x.to(device_load)

    if self.hybrid:
        # Undo the time-branch normalization and sum the two branches.
        xt = xt.view(B, S, -1, length)
        xt = xt * stdt[:, None] + meant[:, None]
        x = xt + x
    return x
|
audio_separator/separator/uvr_lib_v5/demucs/htdemucs.py
ADDED
|
@@ -0,0 +1,620 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta, Inc. and its affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
# First author is Simon Rouard.
|
| 7 |
+
"""
|
| 8 |
+
This code contains the spectrogram and Hybrid version of Demucs.
|
| 9 |
+
"""
|
| 10 |
+
import math
|
| 11 |
+
|
| 12 |
+
from .filtering import wiener
|
| 13 |
+
import torch
|
| 14 |
+
from torch import nn
|
| 15 |
+
from torch.nn import functional as F
|
| 16 |
+
from fractions import Fraction
|
| 17 |
+
from einops import rearrange
|
| 18 |
+
|
| 19 |
+
from .transformer import CrossTransformerEncoder
|
| 20 |
+
|
| 21 |
+
from .demucs import rescale_module
|
| 22 |
+
from .states import capture_init
|
| 23 |
+
from .spec import spectro, ispectro
|
| 24 |
+
from .hdemucs import pad1d, ScaledEmbedding, HEncLayer, MultiWrap, HDecLayer
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class HTDemucs(nn.Module):
|
| 28 |
+
"""
|
| 29 |
+
Spectrogram and hybrid Demucs model.
|
| 30 |
+
The spectrogram model has the same structure as Demucs, except the first few layers are over the
|
| 31 |
+
frequency axis, until there is only 1 frequency, and then it moves to time convolutions.
|
| 32 |
+
Frequency layers can still access information across time steps thanks to the DConv residual.
|
| 33 |
+
|
| 34 |
+
Hybrid model have a parallel time branch. At some layer, the time branch has the same stride
|
| 35 |
+
as the frequency branch and then the two are combined. The opposite happens in the decoder.
|
| 36 |
+
|
| 37 |
+
Models can either use naive iSTFT from masking, Wiener filtering ([Ulhih et al. 2017]),
|
| 38 |
+
or complex as channels (CaC) [Choi et al. 2020]. Wiener filtering is based on
|
| 39 |
+
Open Unmix implementation [Stoter et al. 2019].
|
| 40 |
+
|
| 41 |
+
The loss is always on the temporal domain, by backpropagating through the above
|
| 42 |
+
output methods and iSTFT. This allows to define hybrid models nicely. However, this breaks
|
| 43 |
+
a bit Wiener filtering, as doing more iteration at test time will change the spectrogram
|
| 44 |
+
contribution, without changing the one from the waveform, which will lead to worse performance.
|
| 45 |
+
I tried using the residual option in OpenUnmix Wiener implementation, but it didn't improve.
|
| 46 |
+
CaC on the other hand provides similar performance for hybrid, and works naturally with
|
| 47 |
+
hybrid models.
|
| 48 |
+
|
| 49 |
+
This model also uses frequency embeddings are used to improve efficiency on convolutions
|
| 50 |
+
over the freq. axis, following [Isik et al. 2020] (https://arxiv.org/pdf/2008.04470.pdf).
|
| 51 |
+
|
| 52 |
+
Unlike classic Demucs, there is no resampling here, and normalization is always applied.
|
| 53 |
+
"""
|
| 54 |
+
|
| 55 |
+
@capture_init
def __init__(
    self,
    sources,
    # Channels
    audio_channels=2,
    channels=48,
    channels_time=None,
    growth=2,
    # STFT
    nfft=4096,
    wiener_iters=0,
    end_iters=0,
    wiener_residual=False,
    cac=True,
    # Main structure
    depth=4,
    rewrite=True,
    # Frequency branch
    multi_freqs=None,
    multi_freqs_depth=3,
    freq_emb=0.2,
    emb_scale=10,
    emb_smooth=True,
    # Convolutions
    kernel_size=8,
    time_stride=2,
    stride=4,
    context=1,
    context_enc=0,
    # Normalization
    norm_starts=4,
    norm_groups=4,
    # DConv residual branch
    dconv_mode=1,
    dconv_depth=2,
    dconv_comp=8,
    dconv_init=1e-3,
    # Before the Transformer
    bottom_channels=0,
    # Transformer
    t_layers=5,
    t_emb="sin",
    t_hidden_scale=4.0,
    t_heads=8,
    t_dropout=0.0,
    t_max_positions=10000,
    t_norm_in=True,
    t_norm_in_group=False,
    t_group_norm=False,
    t_norm_first=True,
    t_norm_out=True,
    t_max_period=10000.0,
    t_weight_decay=0.0,
    t_lr=None,
    t_layer_scale=True,
    t_gelu=True,
    t_weight_pos_embed=1.0,
    t_sin_random_shift=0,
    t_cape_mean_normalize=True,
    t_cape_augment=True,
    # NOTE(review): mutable default argument; it appears to be passed through
    # unmodified here, but confirm CrossTransformerEncoder never mutates it.
    t_cape_glob_loc_scale=[5000.0, 1.0, 1.4],
    t_sparse_self_attn=False,
    t_sparse_cross_attn=False,
    t_mask_type="diag",
    t_mask_random_seed=42,
    t_sparse_attn_window=500,
    t_global_window=100,
    t_sparsity=0.95,
    t_auto_sparsity=False,
    # ------ Particular parameters
    t_cross_first=False,
    # Weight init
    rescale=0.1,
    # Metadata
    samplerate=44100,
    segment=10,
    use_train_segment=True,
):
    """
    Args:
        sources (list[str]): list of source names.
        audio_channels (int): input/output audio channels.
        channels (int): initial number of hidden channels.
        channels_time: if not None, use a different `channels` value for the time branch.
        growth: increase the number of hidden channels by this factor at each layer.
        nfft: number of fft bins. Note that changing this require careful computation of
            various shape parameters and will not work out of the box for hybrid models.
        wiener_iters: when using Wiener filtering, number of iterations at test time.
        end_iters: same but at train time. For a hybrid model, must be equal to `wiener_iters`.
        wiener_residual: add residual source before wiener filtering.
        cac: uses complex as channels, i.e. complex numbers are 2 channels each
            in input and output. no further processing is done before ISTFT.
        depth (int): number of layers in the encoder and in the decoder.
        rewrite (bool): add 1x1 convolution to each layer.
        multi_freqs: list of frequency ratios for splitting frequency bands with `MultiWrap`.
        multi_freqs_depth: how many layers to wrap with `MultiWrap`. Only the outermost
            layers will be wrapped.
        freq_emb: add frequency embedding after the first frequency layer if > 0,
            the actual value controls the weight of the embedding.
        emb_scale: equivalent to scaling the embedding learning rate
        emb_smooth: initialize the embedding with a smooth one (with respect to frequencies).
        kernel_size: kernel_size for encoder and decoder layers.
        stride: stride for encoder and decoder layers.
        time_stride: stride for the final time layer, after the merge.
        context: context for 1x1 conv in the decoder.
        context_enc: context for 1x1 conv in the encoder.
        norm_starts: layer at which group norm starts being used.
            decoder layers are numbered in reverse order.
        norm_groups: number of groups for group norm.
        dconv_mode: if 1: dconv in encoder only, 2: decoder only, 3: both.
        dconv_depth: depth of residual DConv branch.
        dconv_comp: compression of DConv branch.
        dconv_attn: adds attention layers in DConv branch starting at this layer
            (documented upstream; not exposed as a parameter here).
        dconv_lstm: adds a LSTM layer in DConv branch starting at this layer
            (documented upstream; not exposed as a parameter here).
        dconv_init: initial scale for the DConv branch LayerScale.
        bottom_channels: if >0 it adds a linear layer (1x1 Conv) before and after the
            transformer in order to change the number of channels
        t_layers: number of layers in each branch (waveform and spec) of the transformer
        t_emb: "sin", "cape" or "scaled"
        t_hidden_scale: the hidden scale of the Feedforward parts of the transformer
            for instance if C = 384 (the number of channels in the transformer) and
            t_hidden_scale = 4.0 then the intermediate layer of the FFN has dimension
            384 * 4 = 1536
        t_heads: number of heads for the transformer
        t_dropout: dropout in the transformer
        t_max_positions: max_positions for the "scaled" positional embedding, only
            useful if t_emb="scaled"
        t_norm_in: (bool) norm before adding positional embedding and getting into the
            transformer layers
        t_norm_in_group: (bool) if True while t_norm_in=True, the norm is on all the
            timesteps (GroupNorm with group=1)
        t_group_norm: (bool) if True, the norms of the Encoder Layers are on all the
            timesteps (GroupNorm with group=1)
        t_norm_first: (bool) if True the norm is before the attention and before the FFN
        t_norm_out: (bool) if True, there is a GroupNorm (group=1) at the end of each layer
        t_max_period: (float) denominator in the sinusoidal embedding expression
        t_weight_decay: (float) weight decay for the transformer
        t_lr: (float) specific learning rate for the transformer
        t_layer_scale: (bool) Layer Scale for the transformer
        t_gelu: (bool) activations of the transformer are GeLU if True, ReLU else
        t_weight_pos_embed: (float) weighting of the positional embedding
        t_cape_mean_normalize: (bool) if t_emb="cape", normalisation of positional embeddings
            see: https://arxiv.org/abs/2106.03143
        t_cape_augment: (bool) if t_emb="cape", must be True during training and False
            during the inference, see: https://arxiv.org/abs/2106.03143
        t_cape_glob_loc_scale: (list of 3 floats) if t_emb="cape", CAPE parameters
            see: https://arxiv.org/abs/2106.03143
        t_sparse_self_attn: (bool) if True, the self attentions are sparse
        t_sparse_cross_attn: (bool) if True, the cross-attentions are sparse (don't use it
            unless you designed really specific masks)
        t_mask_type: (str) can be "diag", "jmask", "random", "global" or any combination
            with '_' between: i.e. "diag_jmask_random" (note that this is permutation
            invariant i.e. "diag_jmask_random" is equivalent to "jmask_random_diag")
        t_mask_random_seed: (int) if "random" is in t_mask_type, controls the seed
            that generated the random part of the mask
        t_sparse_attn_window: (int) if "diag" is in t_mask_type, for a query (i), and
            a key (j), the mask is True if |i-j|<=t_sparse_attn_window
        t_global_window: (int) if "global" is in t_mask_type, mask[:t_global_window, :]
            and mask[:, :t_global_window] will be True
        t_sparsity: (float) if "random" is in t_mask_type, t_sparsity is the sparsity
            level of the random part of the mask.
        t_cross_first: (bool) if True cross attention is the first layer of the
            transformer (False seems to be better)
        rescale: weight rescaling trick
        use_train_segment: (bool) if True, the actual size that is used during the
            training is used during inference.
    """
    super().__init__()
    self.cac = cac
    self.wiener_residual = wiener_residual
    self.audio_channels = audio_channels
    self.sources = sources
    self.kernel_size = kernel_size
    self.context = context
    self.stride = stride
    self.depth = depth
    self.bottom_channels = bottom_channels
    self.channels = channels
    self.samplerate = samplerate
    self.segment = segment
    self.use_train_segment = use_train_segment
    self.nfft = nfft
    # hop of nfft//4 is assumed by the padding logic in `_spec`/`_ispec`.
    self.hop_length = nfft // 4
    self.wiener_iters = wiener_iters
    self.end_iters = end_iters
    self.freq_emb = None
    assert wiener_iters == end_iters

    # Frequency-branch encoder/decoder stacks.
    self.encoder = nn.ModuleList()
    self.decoder = nn.ModuleList()

    # Time-branch encoder/decoder stacks.
    self.tencoder = nn.ModuleList()
    self.tdecoder = nn.ModuleList()

    chin = audio_channels
    chin_z = chin  # number of channels for the freq branch
    if self.cac:
        chin_z *= 2
    chout = channels_time or channels
    chout_z = channels
    freqs = nfft // 2

    for index in range(depth):
        norm = index >= norm_starts
        freq = freqs > 1
        stri = stride
        ker = kernel_size
        if not freq:
            # Past the point where the frequency axis has collapsed to 1,
            # the layer operates purely over time.
            assert freqs == 1
            ker = time_stride * 2
            stri = time_stride

        pad = True
        last_freq = False
        if freq and freqs <= kernel_size:
            # Last frequency layer: kernel covers all remaining bins.
            ker = freqs
            pad = False
            last_freq = True

        kw = {
            "kernel_size": ker,
            "stride": stri,
            "freq": freq,
            "pad": pad,
            "norm": norm,
            "rewrite": rewrite,
            "norm_groups": norm_groups,
            "dconv_kw": {"depth": dconv_depth, "compress": dconv_comp, "init": dconv_init, "gelu": True},
        }
        # kwargs for the time-branch layers (no frequency dimension).
        kwt = dict(kw)
        kwt["freq"] = 0
        kwt["kernel_size"] = kernel_size
        kwt["stride"] = stride
        kwt["pad"] = True
        kw_dec = dict(kw)
        multi = False
        if multi_freqs and index < multi_freqs_depth:
            multi = True
            kw_dec["context_freq"] = False

        if last_freq:
            chout_z = max(chout, chout_z)
            chout = chout_z

        enc = HEncLayer(chin_z, chout_z, dconv=dconv_mode & 1, context=context_enc, **kw)
        if freq:
            tenc = HEncLayer(chin, chout, dconv=dconv_mode & 1, context=context_enc, empty=last_freq, **kwt)
            self.tencoder.append(tenc)

        if multi:
            enc = MultiWrap(enc, multi_freqs)
        self.encoder.append(enc)
        if index == 0:
            # The innermost decoder outputs one spectrogram per source.
            chin = self.audio_channels * len(self.sources)
            chin_z = chin
            if self.cac:
                chin_z *= 2
        dec = HDecLayer(chout_z, chin_z, dconv=dconv_mode & 2, last=index == 0, context=context, **kw_dec)
        if multi:
            dec = MultiWrap(dec, multi_freqs)
        if freq:
            tdec = HDecLayer(chout, chin, dconv=dconv_mode & 2, empty=last_freq, last=index == 0, context=context, **kwt)
            self.tdecoder.insert(0, tdec)
        self.decoder.insert(0, dec)

        chin = chout
        chin_z = chout_z
        chout = int(growth * chout)
        chout_z = int(growth * chout_z)
        if freq:
            if freqs <= kernel_size:
                freqs = 1
            else:
                freqs //= stride
        if index == 0 and freq_emb:
            self.freq_emb = ScaledEmbedding(freqs, chin_z, smooth=emb_smooth, scale=emb_scale)
            self.freq_emb_scale = freq_emb

    if rescale:
        rescale_module(self, reference=rescale)

    transformer_channels = channels * growth ** (depth - 1)
    if bottom_channels:
        # Optional 1x1 convs changing the channel count around the transformer.
        self.channel_upsampler = nn.Conv1d(transformer_channels, bottom_channels, 1)
        self.channel_downsampler = nn.Conv1d(bottom_channels, transformer_channels, 1)
        self.channel_upsampler_t = nn.Conv1d(transformer_channels, bottom_channels, 1)
        self.channel_downsampler_t = nn.Conv1d(bottom_channels, transformer_channels, 1)

        transformer_channels = bottom_channels

    if t_layers > 0:
        self.crosstransformer = CrossTransformerEncoder(
            dim=transformer_channels,
            emb=t_emb,
            hidden_scale=t_hidden_scale,
            num_heads=t_heads,
            num_layers=t_layers,
            cross_first=t_cross_first,
            dropout=t_dropout,
            max_positions=t_max_positions,
            norm_in=t_norm_in,
            norm_in_group=t_norm_in_group,
            group_norm=t_group_norm,
            norm_first=t_norm_first,
            norm_out=t_norm_out,
            max_period=t_max_period,
            weight_decay=t_weight_decay,
            lr=t_lr,
            layer_scale=t_layer_scale,
            gelu=t_gelu,
            sin_random_shift=t_sin_random_shift,
            weight_pos_embed=t_weight_pos_embed,
            cape_mean_normalize=t_cape_mean_normalize,
            cape_augment=t_cape_augment,
            cape_glob_loc_scale=t_cape_glob_loc_scale,
            sparse_self_attn=t_sparse_self_attn,
            sparse_cross_attn=t_sparse_cross_attn,
            mask_type=t_mask_type,
            mask_random_seed=t_mask_random_seed,
            sparse_attn_window=t_sparse_attn_window,
            global_window=t_global_window,
            sparsity=t_sparsity,
            auto_sparsity=t_auto_sparsity,
        )
    else:
        self.crosstransformer = None
|
| 382 |
+
|
| 383 |
+
def _spec(self, x):
    """STFT of ``x`` with extra reflect padding so the frame count is exactly
    ceil(len / hop) and stays aligned with the time branch."""
    hop = self.hop_length
    nfft = self.nfft
    x0 = x  # noqa

    # We re-pad the signal in order to keep the property that the output size
    # is exactly the input size divided by the stride (here hop_length), when
    # divisible. This is achieved by padding by 1/4th of the kernel size
    # (here nfft), which is not supported by torch.stft. Having all
    # convolution operations follow this convention allows to easily align
    # the time and frequency branches later on.
    assert hop == nfft // 4
    n_frames = int(math.ceil(x.shape[-1] / hop))
    side = hop // 2 * 3
    x = pad1d(x, (side, side + n_frames * hop - x.shape[-1]), mode="reflect")

    # Drop the Nyquist bin, then trim the 2 extra frames on each side.
    spec = spectro(x, nfft, hop)[..., :-1, :]
    assert spec.shape[-1] == n_frames + 4, (spec.shape, x.shape, n_frames)
    return spec[..., 2 : 2 + n_frames]
|
| 404 |
+
|
| 405 |
+
def _ispec(self, z, length=None, scale=0):
    """Inverse STFT undoing the Nyquist-bin drop, frame trimming and side
    padding applied in `_spec`."""
    hop = self.hop_length // (4**scale)
    # Restore the Nyquist bin, then the two frames trimmed on each side.
    z = F.pad(F.pad(z, (0, 0, 0, 1)), (2, 2))
    side = hop // 2 * 3
    total = hop * int(math.ceil(length / hop)) + 2 * side
    wav = ispectro(z, hop, length=total)
    return wav[..., side : side + length]
|
| 414 |
+
|
| 415 |
+
def _magnitude(self, z):
|
| 416 |
+
# return the magnitude of the spectrogram, except when cac is True,
|
| 417 |
+
# in which case we just move the complex dimension to the channel one.
|
| 418 |
+
if self.cac:
|
| 419 |
+
B, C, Fr, T = z.shape
|
| 420 |
+
m = torch.view_as_real(z).permute(0, 1, 4, 2, 3)
|
| 421 |
+
m = m.reshape(B, C * 2, Fr, T)
|
| 422 |
+
else:
|
| 423 |
+
m = z.abs()
|
| 424 |
+
return m
|
| 425 |
+
|
| 426 |
+
def _mask(self, z, m):
|
| 427 |
+
# Apply masking given the mixture spectrogram `z` and the estimated mask `m`.
|
| 428 |
+
# If `cac` is True, `m` is actually a full spectrogram and `z` is ignored.
|
| 429 |
+
niters = self.wiener_iters
|
| 430 |
+
if self.cac:
|
| 431 |
+
B, S, C, Fr, T = m.shape
|
| 432 |
+
out = m.view(B, S, -1, 2, Fr, T).permute(0, 1, 2, 4, 5, 3)
|
| 433 |
+
out = torch.view_as_complex(out.contiguous())
|
| 434 |
+
return out
|
| 435 |
+
if self.training:
|
| 436 |
+
niters = self.end_iters
|
| 437 |
+
if niters < 0:
|
| 438 |
+
z = z[:, None]
|
| 439 |
+
return z / (1e-8 + z.abs()) * m
|
| 440 |
+
else:
|
| 441 |
+
return self._wiener(m, z, niters)
|
| 442 |
+
|
| 443 |
+
def _wiener(self, mag_out, mix_stft, niters):
    """Run OpenUnmix-style Wiener filtering of `mix_stft` guided by the
    estimated source magnitudes `mag_out`, in windows of 300 frames.

    Shapes: `mag_out` is (B, S, C, Fq, T); returns complex (B, S, C, Fq, T).
    """
    # apply wiener filtering from OpenUnmix.
    init = mix_stft.dtype
    wiener_win_len = 300
    residual = self.wiener_residual

    B, S, C, Fq, T = mag_out.shape
    # `wiener` expects (frames, freqs, channels, sources) layouts.
    mag_out = mag_out.permute(0, 4, 3, 2, 1)
    mix_stft = torch.view_as_real(mix_stft.permute(0, 3, 2, 1))

    outs = []
    for sample in range(B):
        pos = 0  # NOTE(review): dead store — the for-loop below rebinds `pos`.
        out = []
        # Process the track in fixed-size windows to bound memory.
        for pos in range(0, T, wiener_win_len):
            frame = slice(pos, pos + wiener_win_len)
            z_out = wiener(mag_out[sample, frame], mix_stft[sample, frame], niters, residual=residual)
            out.append(z_out.transpose(-1, -2))
        outs.append(torch.cat(out, dim=0))
    out = torch.view_as_complex(torch.stack(outs, 0))
    # Back to (B, S, C, Fq, T).
    out = out.permute(0, 4, 3, 2, 1).contiguous()
    if residual:
        # Drop the extra residual source appended by `wiener`.
        out = out[:, :-1]
    assert list(out.shape) == [B, S, C, Fq, T]
    return out.to(init)
|
| 468 |
+
|
| 469 |
+
def valid_length(self, length: int):
    """
    Return a length that is appropriate for evaluation.
    In our case, always return the training length, unless
    it is smaller than the given length, in which case this
    raises an error.
    """
    if not self.use_train_segment:
        return length
    train_len = int(self.segment * self.samplerate)
    if train_len < length:
        raise ValueError(f"Given length {length} is longer than training length {train_len}")
    return train_len
|
| 482 |
+
|
| 483 |
+
def forward(self, mix):
    """Separate `mix` (batch, channels, time) into per-source waveforms.

    Runs the hybrid architecture: a spectrogram (freq.) branch and a raw
    waveform (time) branch are encoded in parallel, merged, optionally run
    through the cross-domain transformer, decoded, and summed at the end.
    Returns a tensor of shape (batch, sources, channels, time).
    """
    length = mix.shape[-1]
    length_pre_pad = None
    if self.use_train_segment:
        if self.training:
            # Remember the current segment length (in seconds) as a Fraction.
            self.segment = Fraction(mix.shape[-1], self.samplerate)
        else:
            # At eval time, pad short inputs up to the training length.
            training_length = int(self.segment * self.samplerate)
            if mix.shape[-1] < training_length:
                length_pre_pad = mix.shape[-1]
                mix = F.pad(mix, (0, training_length - length_pre_pad))
    z = self._spec(mix)
    mag = self._magnitude(z).to(mix.device)
    x = mag

    B, C, Fq, T = x.shape

    # unlike previous Demucs, we always normalize because it is easier.
    mean = x.mean(dim=(1, 2, 3), keepdim=True)
    std = x.std(dim=(1, 2, 3), keepdim=True)
    x = (x - mean) / (1e-5 + std)
    # x will be the freq. branch input.

    # Prepare the time branch input.
    xt = mix
    meant = xt.mean(dim=(1, 2), keepdim=True)
    stdt = xt.std(dim=(1, 2), keepdim=True)
    xt = (xt - meant) / (1e-5 + stdt)

    # okay, this is a giant mess I know...
    saved = []  # skip connections, freq.
    saved_t = []  # skip connections, time.
    lengths = []  # saved lengths to properly remove padding, freq branch.
    lengths_t = []  # saved lengths for time branch.
    for idx, encode in enumerate(self.encoder):
        lengths.append(x.shape[-1])
        inject = None
        if idx < len(self.tencoder):
            # we have not yet merged branches.
            lengths_t.append(xt.shape[-1])
            tenc = self.tencoder[idx]
            xt = tenc(xt)
            if not tenc.empty:
                # save for skip connection
                saved_t.append(xt)
            else:
                # tenc contains just the first conv., so that now time and freq.
                # branches have the same shape and can be merged.
                inject = xt
        x = encode(x, inject)
        if idx == 0 and self.freq_emb is not None:
            # add frequency embedding to allow for non equivariant convolutions
            # over the frequency axis.
            frs = torch.arange(x.shape[-2], device=x.device)
            emb = self.freq_emb(frs).t()[None, :, :, None].expand_as(x)
            x = x + self.freq_emb_scale * emb

        saved.append(x)
    if self.crosstransformer:
        if self.bottom_channels:
            # Temporarily widen the channel dim for the transformer.
            b, c, f, t = x.shape
            x = rearrange(x, "b c f t-> b c (f t)")
            x = self.channel_upsampler(x)
            x = rearrange(x, "b c (f t)-> b c f t", f=f)
            xt = self.channel_upsampler_t(xt)

        x, xt = self.crosstransformer(x, xt)

        if self.bottom_channels:
            # Project back down to the decoder's expected channel count.
            x = rearrange(x, "b c f t-> b c (f t)")
            x = self.channel_downsampler(x)
            x = rearrange(x, "b c (f t)-> b c f t", f=f)
            xt = self.channel_downsampler_t(xt)

    for idx, decode in enumerate(self.decoder):
        skip = saved.pop(-1)
        x, pre = decode(x, skip, lengths.pop(-1))
        # `pre` contains the output just before final transposed convolution,
        # which is used when the freq. and time branch separate.

        offset = self.depth - len(self.tdecoder)
        if idx >= offset:
            tdec = self.tdecoder[idx - offset]
            length_t = lengths_t.pop(-1)
            if tdec.empty:
                # The branches re-split here: feed the shared features to the
                # first (input-less) time decoder layer.
                assert pre.shape[2] == 1, pre.shape
                pre = pre[:, :, 0]
                xt, _ = tdec(pre, None, length_t)
            else:
                skip = saved_t.pop(-1)
                xt, _ = tdec(xt, skip, length_t)

    # Let's make sure we used all stored skip connections.
    assert len(saved) == 0
    assert len(lengths_t) == 0
    assert len(saved_t) == 0

    S = len(self.sources)
    x = x.view(B, S, -1, Fq, T)
    # Undo the spectrogram normalization applied above.
    x = x * std[:, None] + mean[:, None]

    # to cpu as non-cuda GPUs don't support complex numbers
    # demucs issue #435 ##432
    # NOTE: in this case z already is on cpu
    # TODO: remove this when mps supports complex numbers

    device_type = x.device.type
    device_load = f"{device_type}:{x.device.index}" if not device_type == "mps" else device_type
    x_is_other_gpu = not device_type in ["cuda", "cpu"]

    if x_is_other_gpu:
        x = x.cpu()

    zout = self._mask(z, x)
    if self.use_train_segment:
        if self.training:
            x = self._ispec(zout, length)
        else:
            # `training_length` was computed above in the eval path.
            x = self._ispec(zout, training_length)
    else:
        x = self._ispec(zout, length)

    # back to other device
    if x_is_other_gpu:
        x = x.to(device_load)

    if self.use_train_segment:
        if self.training:
            xt = xt.view(B, S, -1, length)
        else:
            xt = xt.view(B, S, -1, training_length)
    else:
        xt = xt.view(B, S, -1, length)
    # Undo the waveform normalization applied above.
    xt = xt * stdt[:, None] + meant[:, None]
    # Final output is the sum of the two branches.
    x = xt + x
    if length_pre_pad:
        # Trim the eval-time padding added at the top.
        x = x[..., :length_pre_pad]
    return x
|
audio_separator/separator/uvr_lib_v5/demucs/model.py
ADDED
|
@@ -0,0 +1,204 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import math
|
| 8 |
+
|
| 9 |
+
import torch as th
|
| 10 |
+
from torch import nn
|
| 11 |
+
|
| 12 |
+
from .utils import capture_init, center_trim
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class BLSTM(nn.Module):
    """Bidirectional LSTM over the time axis of a (batch, channels, time)
    tensor, projected back to the input width by a linear layer."""

    def __init__(self, dim, layers=1):
        super().__init__()
        self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim)
        self.linear = nn.Linear(2 * dim, dim)

    def forward(self, x):
        # (batch, channels, time) -> (time, batch, channels) for the LSTM.
        seq = x.permute(2, 0, 1)
        seq, _ = self.lstm(seq)
        seq = self.linear(seq)
        # Back to (batch, channels, time).
        return seq.permute(1, 2, 0)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def rescale_conv(conv, reference):
    """Rescale `conv`'s weight (and bias) in place so the weight's standard
    deviation moves toward `reference`."""
    ratio = (conv.weight.std().detach() / reference) ** 0.5
    conv.weight.data /= ratio
    if conv.bias is not None:
        conv.bias.data /= ratio
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def rescale_module(module, reference):
    """Apply `rescale_conv` to every Conv1d / ConvTranspose1d inside `module`."""
    convs = (sub for sub in module.modules() if isinstance(sub, (nn.Conv1d, nn.ConvTranspose1d)))
    for conv in convs:
        rescale_conv(conv, reference)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def upsample(x, stride):
    """
    Linear upsampling, the output will be `stride` times longer.
    """
    batch, channels, time = x.size()
    # Fractional offsets 0, 1/stride, ..., (stride-1)/stride.
    weight = th.arange(stride, device=x.device, dtype=th.float) / stride
    expanded = x.view(batch, channels, time, 1)
    # Linear interpolation between consecutive samples.
    interp = expanded[..., :-1, :] * (1 - weight) + expanded[..., 1:, :] * weight
    return interp.reshape(batch, channels, -1)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def downsample(x, stride):
    """
    Downsample x by decimation: keep every `stride`-th time step.
    """
    return x[:, :, 0::stride]
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class Demucs(nn.Module):
    """Original (v1) Demucs: a 1D convolutional U-Net over the raw waveform
    with an optional BLSTM bottleneck, separating `sources` stems."""

    @capture_init
    def __init__(
        self, sources=4, audio_channels=2, channels=64, depth=6, rewrite=True, glu=True, upsample=False, rescale=0.1, kernel_size=8, stride=4, growth=2.0, lstm_layers=2, context=3, samplerate=44100
    ):
        """
        Args:
            sources (int): number of sources to separate
            audio_channels (int): stereo or mono
            channels (int): first convolution channels
            depth (int): number of encoder/decoder layers
            rewrite (bool): add 1x1 convolution to each encoder layer
                and a convolution to each decoder layer.
                For the decoder layer, `context` gives the kernel size.
            glu (bool): use glu instead of ReLU
            upsample (bool): use linear upsampling with convolutions
                Wave-U-Net style, instead of transposed convolutions
            rescale (int): rescale initial weights of convolutions
                to get their standard deviation closer to `rescale`
            kernel_size (int): kernel size for convolutions
            stride (int): stride for convolutions
            growth (float): multiply (resp divide) number of channels by that
                for each layer of the encoder (resp decoder)
            lstm_layers (int): number of lstm layers, 0 = no lstm
            context (int): kernel size of the convolution in the
                decoder before the transposed convolution. If > 1,
                will provide some context from neighboring time
                steps.
        """

        super().__init__()
        self.audio_channels = audio_channels
        self.sources = sources
        self.kernel_size = kernel_size
        self.context = context
        self.stride = stride
        self.depth = depth
        self.upsample = upsample
        self.channels = channels
        self.samplerate = samplerate

        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()

        self.final = None
        if upsample:
            # Wave-U-Net style: a final 1x1 conv maps (features + input mix)
            # to the per-source output; strided convs are disabled.
            self.final = nn.Conv1d(channels + audio_channels, sources * audio_channels, 1)
            stride = 1

        if glu:
            activation = nn.GLU(dim=1)
            # GLU halves the channel count, so the conv before it doubles it.
            ch_scale = 2
        else:
            activation = nn.ReLU()
            ch_scale = 1
        in_channels = audio_channels
        for index in range(depth):
            encode = []
            encode += [nn.Conv1d(in_channels, channels, kernel_size, stride), nn.ReLU()]
            if rewrite:
                encode += [nn.Conv1d(channels, ch_scale * channels, 1), activation]
            self.encoder.append(nn.Sequential(*encode))

            decode = []
            if index > 0:
                out_channels = in_channels
            else:
                # Innermost-built decoder is actually the LAST to run (it is
                # inserted at position 0 below) and produces the final output.
                if upsample:
                    out_channels = channels
                else:
                    out_channels = sources * audio_channels
            if rewrite:
                decode += [nn.Conv1d(channels, ch_scale * channels, context), activation]
            if upsample:
                decode += [nn.Conv1d(channels, out_channels, kernel_size, stride=1)]
            else:
                decode += [nn.ConvTranspose1d(channels, out_channels, kernel_size, stride)]
            if index > 0:
                decode.append(nn.ReLU())
            self.decoder.insert(0, nn.Sequential(*decode))
            in_channels = channels
            channels = int(growth * channels)

        # `channels` now holds the innermost feature width.
        channels = in_channels

        if lstm_layers:
            self.lstm = BLSTM(channels, lstm_layers)
        else:
            self.lstm = None

        if rescale:
            rescale_module(self, reference=rescale)

    def valid_length(self, length):
        """
        Return the nearest valid length to use with the model so that
        there is no time steps left over in a convolutions, e.g. for all
        layers, size of the input - kernel_size % stride = 0.

        If the mixture has a valid length, the estimated sources
        will have exactly the same length when context = 1. If context > 1,
        the two signals can be center trimmed to match.

        For training, extracts should have a valid length.For evaluation
        on full tracks we recommend passing `pad = True` to :method:`forward`.
        """
        # Forward pass through the encoder lengths...
        for _ in range(self.depth):
            if self.upsample:
                length = math.ceil(length / self.stride) + self.kernel_size - 1
            else:
                length = math.ceil((length - self.kernel_size) / self.stride) + 1
            length = max(1, length)
            length += self.context - 1
        # ...then back up through the decoder.
        for _ in range(self.depth):
            if self.upsample:
                length = length * self.stride + self.kernel_size - 1
            else:
                length = (length - 1) * self.stride + self.kernel_size

        return int(length)

    def forward(self, mix):
        """Separate `mix` into (batch, sources, audio_channels, time)."""
        x = mix
        # The raw mix is saved too: it is used by the final 1x1 conv
        # in upsample mode.
        saved = [x]
        for encode in self.encoder:
            x = encode(x)
            saved.append(x)
            if self.upsample:
                x = downsample(x, self.stride)
        if self.lstm:
            x = self.lstm(x)
        for decode in self.decoder:
            if self.upsample:
                x = upsample(x, stride=self.stride)
            # U-Net skip connection, trimmed to the current length.
            skip = center_trim(saved.pop(-1), x)
            x = x + skip
            x = decode(x)
        if self.final:
            skip = center_trim(saved.pop(-1), x)
            x = th.cat([x, skip], dim=1)
            x = self.final(x)

        x = x.view(x.size(0), self.sources, self.audio_channels, x.size(-1))
        return x
|
audio_separator/separator/uvr_lib_v5/demucs/model_v2.py
ADDED
|
@@ -0,0 +1,222 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import math
|
| 8 |
+
|
| 9 |
+
import julius
|
| 10 |
+
from torch import nn
|
| 11 |
+
from .tasnet_v2 import ConvTasNet
|
| 12 |
+
|
| 13 |
+
from .utils import capture_init, center_trim
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class BLSTM(nn.Module):
    """Bidirectional LSTM applied along time on a (batch, channels, time)
    tensor; a linear layer maps the 2*dim LSTM output back to dim."""

    def __init__(self, dim, layers=1):
        super().__init__()
        self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim)
        self.linear = nn.Linear(2 * dim, dim)

    def forward(self, x):
        # Put time first, as the (non batch_first) LSTM expects.
        out = x.permute(2, 0, 1)
        out, _ = self.lstm(out)
        out = self.linear(out)
        return out.permute(1, 2, 0)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def rescale_conv(conv, reference):
    """In-place rescaling of `conv`'s weight/bias toward std `reference`."""
    factor = (conv.weight.std().detach() / reference) ** 0.5
    conv.weight.data /= factor
    if conv.bias is not None:
        conv.bias.data /= factor
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def rescale_module(module, reference):
    """Rescale all 1d (transposed) convolutions found inside `module`."""
    for sub in module.modules():
        if not isinstance(sub, (nn.Conv1d, nn.ConvTranspose1d)):
            continue
        rescale_conv(sub, reference)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def auto_load_demucs_model_v2(sources, demucs_model_name):
    """Instantiate an (unloaded) v2 model whose architecture matches the
    conventions encoded in `demucs_model_name` ("48" -> 48 channels,
    "unittest" -> 4 channels, "tasnet" -> ConvTasNet)."""
    if "48" in demucs_model_name:
        width = 48
    elif "unittest" in demucs_model_name:
        width = 4
    else:
        width = 64

    if "tasnet" in demucs_model_name:
        return ConvTasNet(sources, X=10)
    return Demucs(sources, channels=width)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class Demucs(nn.Module):
    """Demucs v2: waveform U-Net with optional 2x resampling around the
    network, optional input normalization, and a BLSTM bottleneck."""

    @capture_init
    def __init__(
        self,
        sources,
        audio_channels=2,
        channels=64,
        depth=6,
        rewrite=True,
        glu=True,
        rescale=0.1,
        resample=True,
        kernel_size=8,
        stride=4,
        growth=2.0,
        lstm_layers=2,
        context=3,
        normalize=False,
        samplerate=44100,
        segment_length=4 * 10 * 44100,
    ):
        """
        Args:
            sources (list[str]): list of source names
            audio_channels (int): stereo or mono
            channels (int): first convolution channels
            depth (int): number of encoder/decoder layers
            rewrite (bool): add 1x1 convolution to each encoder layer
                and a convolution to each decoder layer.
                For the decoder layer, `context` gives the kernel size.
            glu (bool): use glu instead of ReLU
            resample_input (bool): upsample x2 the input and downsample /2 the output.
            rescale (int): rescale initial weights of convolutions
                to get their standard deviation closer to `rescale`
            kernel_size (int): kernel size for convolutions
            stride (int): stride for convolutions
            growth (float): multiply (resp divide) number of channels by that
                for each layer of the encoder (resp decoder)
            lstm_layers (int): number of lstm layers, 0 = no lstm
            context (int): kernel size of the convolution in the
                decoder before the transposed convolution. If > 1,
                will provide some context from neighboring time
                steps.
            samplerate (int): stored as meta information for easing
                future evaluations of the model.
            segment_length (int): stored as meta information for easing
                future evaluations of the model. Length of the segments on which
                the model was trained.
        """

        super().__init__()
        self.audio_channels = audio_channels
        self.sources = sources
        self.kernel_size = kernel_size
        self.context = context
        self.stride = stride
        self.depth = depth
        self.resample = resample
        self.channels = channels
        self.normalize = normalize
        self.samplerate = samplerate
        self.segment_length = segment_length

        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()

        if glu:
            activation = nn.GLU(dim=1)
            # GLU halves channels, so the preceding conv doubles them.
            ch_scale = 2
        else:
            activation = nn.ReLU()
            ch_scale = 1
        in_channels = audio_channels
        for index in range(depth):
            encode = []
            encode += [nn.Conv1d(in_channels, channels, kernel_size, stride), nn.ReLU()]
            if rewrite:
                encode += [nn.Conv1d(channels, ch_scale * channels, 1), activation]
            self.encoder.append(nn.Sequential(*encode))

            decode = []
            if index > 0:
                out_channels = in_channels
            else:
                # The decoder built at index 0 runs LAST (inserted at
                # position 0) and emits all sources.
                out_channels = len(self.sources) * audio_channels
            if rewrite:
                decode += [nn.Conv1d(channels, ch_scale * channels, context), activation]
            decode += [nn.ConvTranspose1d(channels, out_channels, kernel_size, stride)]
            if index > 0:
                decode.append(nn.ReLU())
            self.decoder.insert(0, nn.Sequential(*decode))
            in_channels = channels
            channels = int(growth * channels)

        channels = in_channels

        if lstm_layers:
            self.lstm = BLSTM(channels, lstm_layers)
        else:
            self.lstm = None

        if rescale:
            rescale_module(self, reference=rescale)

    def valid_length(self, length):
        """
        Return the nearest valid length to use with the model so that
        there is no time steps left over in a convolutions, e.g. for all
        layers, size of the input - kernel_size % stride = 0.

        If the mixture has a valid length, the estimated sources
        will have exactly the same length when context = 1. If context > 1,
        the two signals can be center trimmed to match.

        For training, extracts should have a valid length.For evaluation
        on full tracks we recommend passing `pad = True` to :method:`forward`.
        """
        # Account for the x2 resampling applied before the encoder.
        if self.resample:
            length *= 2
        for _ in range(self.depth):
            length = math.ceil((length - self.kernel_size) / self.stride) + 1
            length = max(1, length)
            length += self.context - 1
        for _ in range(self.depth):
            length = (length - 1) * self.stride + self.kernel_size

        if self.resample:
            length = math.ceil(length / 2)
        return int(length)

    def forward(self, mix):
        """Separate `mix` into (batch, len(sources), audio_channels, time)."""
        x = mix

        if self.normalize:
            # Normalize using mono statistics so both channels are scaled alike.
            mono = mix.mean(dim=1, keepdim=True)
            mean = mono.mean(dim=-1, keepdim=True)
            std = mono.std(dim=-1, keepdim=True)
        else:
            mean = 0
            std = 1

        x = (x - mean) / (1e-5 + std)

        if self.resample:
            x = julius.resample_frac(x, 1, 2)

        saved = []
        for encode in self.encoder:
            x = encode(x)
            saved.append(x)
        if self.lstm:
            x = self.lstm(x)
        for decode in self.decoder:
            # U-Net skip connection, trimmed to the current length.
            skip = center_trim(saved.pop(-1), x)
            x = x + skip
            x = decode(x)

        if self.resample:
            x = julius.resample_frac(x, 2, 1)
        # Undo the input normalization.
        x = x * std + mean
        x = x.view(x.size(0), len(self.sources), self.audio_channels, x.size(-1))
        return x
|
audio_separator/separator/uvr_lib_v5/demucs/pretrained.py
ADDED
|
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
"""Loading pretrained models.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import logging
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
import typing as tp
|
| 12 |
+
|
| 13 |
+
# from dora.log import fatal
|
| 14 |
+
|
| 15 |
+
import logging
|
| 16 |
+
|
| 17 |
+
from diffq import DiffQuantizer
|
| 18 |
+
import torch.hub
|
| 19 |
+
|
| 20 |
+
from .model import Demucs
|
| 21 |
+
from .tasnet_v2 import ConvTasNet
|
| 22 |
+
from .utils import set_state
|
| 23 |
+
|
| 24 |
+
from .hdemucs import HDemucs
|
| 25 |
+
from .repo import RemoteRepo, LocalRepo, ModelOnlyRepo, BagOnlyRepo, AnyModelRepo, ModelLoadingError # noqa
|
| 26 |
+
|
| 27 |
+
logger = logging.getLogger(__name__)
|
| 28 |
+
ROOT_URL = "https://dl.fbaipublicfiles.com/demucs/mdx_final/"
|
| 29 |
+
REMOTE_ROOT = Path(__file__).parent / "remote"
|
| 30 |
+
|
| 31 |
+
SOURCES = ["drums", "bass", "other", "vocals"]
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def demucs_unittest():
    """Tiny 4-channel HDemucs used for tests.

    NOTE(review): a second `demucs_unittest(pretrained=True)` is defined
    later in this module and shadows this one at import time — confirm
    which definition is intended.
    """
    model = HDemucs(channels=4, sources=SOURCES)
    return model
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def add_model_flags(parser):
    """Register model-selection CLI flags on `parser`: mutually exclusive
    `--sig`/`--name`, plus an independent `--repo` path."""
    exclusive = parser.add_mutually_exclusive_group(required=False)
    exclusive.add_argument("-s", "--sig", help="Locally trained XP signature.")
    exclusive.add_argument("-n", "--name", default="mdx_extra_q", help="Pretrained model name or signature. Default is mdx_extra_q.")
    parser.add_argument("--repo", type=Path, help="Folder containing all pre-trained models for use with -n.")
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def _parse_remote_files(remote_file_list) -> tp.Dict[str, str]:
    """Parse a `files.txt`-style listing into {signature: download URL}.

    Lines starting with `#` are comments; a `root:` line sets the URL
    prefix appended after `ROOT_URL`; every other non-empty line is a
    model file named `<sig>-...`.

    Args:
        remote_file_list: a `Path`-like with `.read_text()`.

    Returns:
        Mapping from model signature to its full download URL.
    """
    root: str = ""
    models: tp.Dict[str, str] = {}
    for line in remote_file_list.read_text().split("\n"):
        line = line.strip()
        # Skip comments AND blank lines. Without the blank-line check a
        # trailing newline creates a bogus "" signature entry, and a second
        # blank line trips the duplicate-signature assert below.
        if not line or line.startswith("#"):
            continue
        if line.startswith("root:"):
            root = line.split(":", 1)[1].strip()
        else:
            sig = line.split("-", 1)[0]
            assert sig not in models
            models[sig] = ROOT_URL + root + line
    return models
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def get_model(name: str, repo: tp.Optional[Path] = None):
    """`name` must be a bag of models name or a pretrained signature
    from the remote AWS model repo or the specified local repo if `repo` is not None.

    Returns the model in eval mode.

    Raises:
        ValueError: if `repo` is given but is not an existing directory.
    """
    if name == "demucs_unittest":
        return demucs_unittest()
    model_repo: ModelOnlyRepo
    if repo is None:
        models = _parse_remote_files(REMOTE_ROOT / "files.txt")
        model_repo = RemoteRepo(models)
        bag_repo = BagOnlyRepo(REMOTE_ROOT, model_repo)
    else:
        if not repo.is_dir():
            # Bug fix: the original called `fatal(...)`, but its import
            # (`from dora.log import fatal`) is commented out above, so this
            # path raised NameError. Raise a clear error instead.
            raise ValueError(f"{repo} must exist and be a directory.")
        model_repo = LocalRepo(repo)
        bag_repo = BagOnlyRepo(repo, model_repo)
    any_repo = AnyModelRepo(model_repo, bag_repo)
    model = any_repo.get_model(name)
    model.eval()
    return model
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def get_model_from_args(args):
    """
    Load local model package or pre-trained model.

    Thin wrapper around `get_model` using the `--name` / `--repo` flags
    registered by `add_model_flags`.
    """
    return get_model(name=args.name, repo=args.repo)
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
logger = logging.getLogger(__name__)
|
| 92 |
+
ROOT = "https://dl.fbaipublicfiles.com/demucs/v3.0/"
|
| 93 |
+
|
| 94 |
+
PRETRAINED_MODELS = {
|
| 95 |
+
"demucs": "e07c671f",
|
| 96 |
+
"demucs48_hq": "28a1282c",
|
| 97 |
+
"demucs_extra": "3646af93",
|
| 98 |
+
"demucs_quantized": "07afea75",
|
| 99 |
+
"tasnet": "beb46fac",
|
| 100 |
+
"tasnet_extra": "df3777b2",
|
| 101 |
+
"demucs_unittest": "09ebc15f",
|
| 102 |
+
}
|
| 103 |
+
|
| 104 |
+
SOURCES = ["drums", "bass", "other", "vocals"]
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def get_url(name):
    """Build the v3.0 download URL for the pretrained model `name`."""
    signature = PRETRAINED_MODELS[name]
    return f"{ROOT}{name}-{signature[:8]}.th"
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def is_pretrained(name):
    """Return True if `name` is one of the known v3.0 pretrained models."""
    return name in PRETRAINED_MODELS
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def load_pretrained(name):
    """Instantiate and load one of the known pretrained models by `name`.

    Raises:
        ValueError: if `name` is not a known pretrained model.
    """
    # Dispatch table instead of an if/elif chain; each entry builds and
    # loads the corresponding pretrained model lazily.
    loaders = {
        "demucs": lambda: demucs(pretrained=True),
        "demucs48_hq": lambda: demucs(pretrained=True, hq=True, channels=48),
        "demucs_extra": lambda: demucs(pretrained=True, extra=True),
        "demucs_quantized": lambda: demucs(pretrained=True, quantized=True),
        "demucs_unittest": lambda: demucs_unittest(pretrained=True),
        "tasnet": lambda: tasnet(pretrained=True),
        "tasnet_extra": lambda: tasnet(pretrained=True, extra=True),
    }
    loader = loaders.get(name)
    if loader is None:
        raise ValueError(f"Invalid pretrained name {name}")
    return loader()
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def _load_state(name, model, quantizer=None):
    # Download (or fetch from the torch.hub cache) the remote state for
    # *name* and load it into *model*; the hash embedded in the URL filename
    # is verified by torch.hub (check_hash=True).
    url = get_url(name)
    state = torch.hub.load_state_dict_from_url(url, map_location="cpu", check_hash=True)
    set_state(model, quantizer, state)
    if quantizer:
        # NOTE(review): detach() presumably removes the DiffQ hooks once the
        # quantized state has been restored — confirm against diffq docs.
        quantizer.detach()
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def demucs_unittest(pretrained=True):
    # Tiny 4-channel Demucs used for unit tests; optionally loads its
    # dedicated remote weights.
    model = Demucs(channels=4, sources=SOURCES)
    if pretrained:
        _load_state("demucs_unittest", model)
    return model
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def demucs(pretrained=True, extra=False, quantized=False, hq=False, channels=64):
    """Build a Demucs model, optionally loading pre-trained weights.

    Args:
        pretrained: load remote weights matching the chosen variant.
        extra: variant trained on extra data (requires pretrained).
        quantized: DiffQ-quantized variant (requires pretrained).
        hq: high-quality variant (requires pretrained).
        channels: model width; appended to the pre-trained model name when
            different from the default 64 (e.g. "demucs48_hq").

    Raises:
        ValueError: if a variant flag is set without pretrained, or if more
            than one of extra/quantized/hq is set.
    """
    if not pretrained and (extra or quantized or hq):
        # Fix: the message previously omitted `hq`, although it is part of
        # the guard condition above.
        raise ValueError("if extra, quantized or hq is True, pretrained must be True.")
    model = Demucs(sources=SOURCES, channels=channels)
    if pretrained:
        name = "demucs"
        if channels != 64:
            name += str(channels)
        quantizer = None
        if sum([extra, quantized, hq]) > 1:
            raise ValueError("Only one of extra, quantized, hq, can be True.")
        if quantized:
            # DiffQ needs the quantizer instance to restore the packed state.
            quantizer = DiffQuantizer(model, group_size=8, min_size=1)
            name += "_quantized"
        if extra:
            name += "_extra"
        if hq:
            name += "_hq"
        _load_state(name, model, quantizer)
    return model
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def tasnet(pretrained=True, extra=False):
    """Build a ConvTasNet model, optionally loading pre-trained weights."""
    if extra and not pretrained:
        raise ValueError("if extra is True, pretrained must be True.")
    model = ConvTasNet(X=10, sources=SOURCES)
    if pretrained:
        weights_name = "tasnet_extra" if extra else "tasnet"
        _load_state(weights_name, model)
    return model
|
audio_separator/separator/uvr_lib_v5/demucs/repo.py
ADDED
|
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
"""Represents a model repository, including pre-trained models and bags of models.
|
| 7 |
+
A repo can either be the main remote repository stored in AWS, or a local repository
|
| 8 |
+
with your own models.
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
from hashlib import sha256
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
import typing as tp
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
import yaml
|
| 17 |
+
|
| 18 |
+
from .apply import BagOfModels, Model
|
| 19 |
+
from .states import load_model
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
AnyModel = tp.Union[Model, BagOfModels]
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class ModelLoadingError(RuntimeError):
    """Raised when a model (or bag of models) cannot be found or verified."""

    pass
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def check_checksum(path: Path, checksum: str):
    """Verify that the sha256 of *path* starts with *checksum*.

    Raises ModelLoadingError on mismatch; only the first len(checksum)
    hex digits are compared.
    """
    digest = sha256()
    with open(path, "rb") as fh:
        # Hash in 1 MiB chunks to keep memory bounded on large model files.
        for chunk in iter(lambda: fh.read(2**20), b""):
            digest.update(chunk)
    actual_checksum = digest.hexdigest()[: len(checksum)]
    if actual_checksum != checksum:
        raise ModelLoadingError(f"Invalid checksum for file {path}, " f"expected {checksum} but got {actual_checksum}")
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class ModelOnlyRepo:
    """Base class for all model only repos."""

    def has_model(self, sig: str) -> bool:
        # Whether a single model with signature *sig* is available here.
        raise NotImplementedError()

    def get_model(self, sig: str) -> Model:
        # Load and return the model with signature *sig*.
        raise NotImplementedError()
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class RemoteRepo(ModelOnlyRepo):
    """Repo backed by a mapping of model signature -> download URL."""

    def __init__(self, models: tp.Dict[str, str]):
        # models: signature -> URL of the serialized model package.
        self._models = models

    def has_model(self, sig: str) -> bool:
        return sig in self._models

    def get_model(self, sig: str) -> Model:
        try:
            url = self._models[sig]
        except KeyError:
            raise ModelLoadingError(f"Could not find a pre-trained model with signature {sig}.")
        # torch.hub caches the download and verifies the hash embedded in
        # the URL's filename (check_hash=True).
        pkg = torch.hub.load_state_dict_from_url(url, map_location="cpu", check_hash=True)
        return load_model(pkg)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class LocalRepo(ModelOnlyRepo):
    """Repo scanning a local directory for serialized models (``*.th``).

    Files named ``SIG-CHECKSUM.th`` get their sha256 prefix verified on
    load; files named ``SIG.th`` are loaded without verification.
    """

    def __init__(self, root: Path):
        self.root = root
        self.scan()

    def scan(self):
        """(Re)build the signature -> file index from the root directory."""
        self._models = {}
        self._checksums = {}
        for file in self.root.iterdir():
            if file.suffix == ".th":
                if "-" in file.stem:
                    xp_sig, checksum = file.stem.split("-")
                    self._checksums[xp_sig] = checksum
                else:
                    xp_sig = file.stem
                if xp_sig in self._models:
                    # Fix: removed a leftover debug print that was emitted
                    # just before raising.
                    raise ModelLoadingError(f"Duplicate pre-trained model exist for signature {xp_sig}. " "Please delete all but one.")
                self._models[xp_sig] = file

    def has_model(self, sig: str) -> bool:
        return sig in self._models

    def get_model(self, sig: str) -> Model:
        try:
            file = self._models[sig]
        except KeyError:
            raise ModelLoadingError(f"Could not find pre-trained model with signature {sig}.")
        if sig in self._checksums:
            check_checksum(file, self._checksums[sig])
        return load_model(file)
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
class BagOnlyRepo:
    """Handles only YAML files containing bag of models, leaving the actual
    model loading to some Repo.
    """

    def __init__(self, root: Path, model_repo: ModelOnlyRepo):
        self.root = root
        self.model_repo = model_repo
        self.scan()

    def scan(self):
        """(Re)build the name -> YAML file index from the root directory."""
        self._bags = {}
        for file in self.root.iterdir():
            if file.suffix == ".yaml":
                self._bags[file.stem] = file

    def has_model(self, name: str) -> bool:
        return name in self._bags

    def get_model(self, name: str) -> BagOfModels:
        try:
            yaml_file = self._bags[name]
        except KeyError:
            raise ModelLoadingError(f"{name} is neither a single pre-trained model or " "a bag of models.")
        # Fix: open the YAML file via a context manager so the handle is
        # closed deterministically (was `yaml.safe_load(open(yaml_file))`).
        with open(yaml_file) as fh:
            bag = yaml.safe_load(fh)
        signatures = bag["models"]
        models = [self.model_repo.get_model(sig) for sig in signatures]
        # Optional per-source weights and segment length for the bag.
        weights = bag.get("weights")
        segment = bag.get("segment")
        return BagOfModels(models, weights, segment)
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
class AnyModelRepo:
    """Facade resolving a name either as a single model or as a bag."""

    def __init__(self, model_repo: ModelOnlyRepo, bag_repo: BagOnlyRepo):
        self.model_repo = model_repo
        self.bag_repo = bag_repo

    def has_model(self, name_or_sig: str) -> bool:
        return self.model_repo.has_model(name_or_sig) or self.bag_repo.has_model(name_or_sig)

    def get_model(self, name_or_sig: str) -> AnyModel:
        # Single models take precedence; anything else is treated as a bag.
        if self.model_repo.has_model(name_or_sig):
            return self.model_repo.get_model(name_or_sig)
        return self.bag_repo.get_model(name_or_sig)
|
audio_separator/separator/uvr_lib_v5/demucs/spec.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
"""Conveniance wrapper to perform STFT and iSTFT"""
|
| 7 |
+
|
| 8 |
+
import torch as th
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def spectro(x, n_fft=512, hop_length=None, pad=0):
    """Forward STFT over the last axis, preserving all leading dimensions."""
    *batch_dims, length = x.shape
    x = x.reshape(-1, length)

    # torch.stft only runs on CUDA/CPU; fall back to CPU for any other
    # backend (e.g. MPS).
    if x.device.type not in ("cuda", "cpu"):
        x = x.cpu()
    z = th.stft(
        x,
        n_fft * (1 + pad),
        hop_length or n_fft // 4,
        window=th.hann_window(n_fft).to(x),
        win_length=n_fft,
        normalized=True,
        center=True,
        return_complex=True,
        pad_mode="reflect",
    )
    _, freqs, frames = z.shape
    return z.view(*batch_dims, freqs, frames)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def ispectro(z, hop_length=None, length=None, pad=0):
    """Inverse of spectro: iSTFT over the last two axes (freqs, frames)."""
    *batch_dims, freqs, frames = z.shape
    # Recover the FFT size from the number of (onesided) frequency bins.
    n_fft = 2 * freqs - 2
    z = z.view(-1, freqs, frames)
    win_length = n_fft // (1 + pad)

    # torch.istft only runs on CUDA/CPU; fall back to CPU otherwise.
    if z.device.type not in ("cuda", "cpu"):
        z = z.cpu()
    x = th.istft(
        z,
        n_fft,
        hop_length,
        window=th.hann_window(win_length).to(z.real),
        win_length=win_length,
        normalized=True,
        length=length,
        center=True,
    )
    _, out_length = x.shape
    return x.view(*batch_dims, out_length)
|
audio_separator/separator/uvr_lib_v5/demucs/states.py
ADDED
|
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
"""
|
| 7 |
+
Utilities to save and load models.
|
| 8 |
+
"""
|
| 9 |
+
from contextlib import contextmanager
|
| 10 |
+
|
| 11 |
+
import functools
|
| 12 |
+
import hashlib
|
| 13 |
+
import inspect
|
| 14 |
+
import io
|
| 15 |
+
from pathlib import Path
|
| 16 |
+
import warnings
|
| 17 |
+
|
| 18 |
+
from diffq import DiffQuantizer, UniformQuantizer, restore_quantized_state
|
| 19 |
+
import torch
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def get_quantizer(model, args, optimizer=None):
    """Return the quantizer given the XP quantization args."""
    # DiffQ when `args.diffq` is set, uniform QAT when `args.qat` is set,
    # otherwise no quantizer (None).
    quantizer = None
    if args.diffq:
        quantizer = DiffQuantizer(model, min_size=args.min_size, group_size=args.group_size)
        if optimizer is not None:
            # NOTE(review): presumably registers DiffQ's extra trainable
            # parameters with the optimizer — confirm against diffq docs.
            quantizer.setup_optimizer(optimizer)
    elif args.qat:
        # `args.qat` doubles as the bit width for uniform quantization.
        quantizer = UniformQuantizer(model, bits=args.qat, min_size=args.min_size)
    return quantizer
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def load_model(path_or_package, strict=False):
    """Load a model from the given serialized model, either given as a dict (already loaded)
    or a path to a file on disk."""
    if isinstance(path_or_package, dict):
        package = path_or_package
    elif isinstance(path_or_package, (str, Path)):
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            path = path_or_package
            # SECURITY: weights_only=False unpickles arbitrary objects; only
            # ever call this on trusted model files.
            package = torch.load(path, "cpu", weights_only=False)
    else:
        raise ValueError(f"Invalid type for {path_or_package}.")

    # A package bundles the model class, its constructor arguments and the
    # trained state dict.
    klass = package["klass"]
    args = package["args"]
    kwargs = package["kwargs"]

    if strict:
        model = klass(*args, **kwargs)
    else:
        # Tolerate packages saved with constructor parameters the current
        # class no longer accepts: drop them with a warning.
        sig = inspect.signature(klass)
        for key in list(kwargs):
            if key not in sig.parameters:
                warnings.warn("Dropping inexistant parameter " + key)
                del kwargs[key]
        model = klass(*args, **kwargs)

    state = package["state"]

    set_state(model, state)
    return model
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def get_state(model, quantizer, half=False):
    """Extract a serializable state from *model*.

    With a quantizer, the quantized representation is returned, tagged with
    the ``__quantized`` key. Otherwise a plain CPU state dict is returned,
    optionally cast to half precision to halve its size (shouldn't impact
    performance).
    """
    if quantizer is not None:
        state = quantizer.get_quantized_state()
        state["__quantized"] = True
        return state
    dtype = torch.half if half else None
    return {name: value.data.to(device="cpu", dtype=dtype) for name, value in model.state_dict().items()}
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def set_state(model, state, quantizer=None):
    """Set the state on a given model."""
    # Plain (non-quantized) states go straight through load_state_dict.
    if not state.get("__quantized"):
        model.load_state_dict(state)
        return state
    # Quantized states need the quantizer (or diffq's standalone restore).
    if quantizer is not None:
        quantizer.restore_quantized_state(model, state["quantized"])
    else:
        restore_quantized_state(model, state)
    return state
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def save_with_checksum(content, path):
    """Save the given value on disk, along with a sha256 hash.
    Should be used with the output of either `serialize_model` or `get_state`."""
    buffer = io.BytesIO()
    torch.save(content, buffer)
    payload = buffer.getvalue()
    signature = hashlib.sha256(payload).hexdigest()[:8]

    # Embed the hash in the filename: NAME-SIGNATURE.EXT
    target = path.parent / f"{path.stem}-{signature}{path.suffix}"
    target.write_bytes(payload)
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def copy_state(state):
    # Deep-copy a state dict onto the CPU so later in-place updates to the
    # model do not mutate the saved copy.
    copied = {}
    for key, value in state.items():
        copied[key] = value.cpu().clone()
    return copied
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
@contextmanager
def swap_state(model, state):
    """
    Context manager that swaps the state of a model, e.g:

        # model is in old state
        with swap_state(model, new_state):
            # model in new state
        # model back to old state
    """
    old_state = copy_state(model.state_dict())
    # strict=False lets the temporary state be partial.
    model.load_state_dict(state, strict=False)
    try:
        yield
    finally:
        # Restore the original state even if the body raised.
        model.load_state_dict(old_state)
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def capture_init(init):
    """Decorator for ``__init__`` that records the call arguments on the
    instance as ``_init_args_kwargs`` before running the wrapped init."""

    @functools.wraps(init)
    def wrapped_init(self, *args, **kwargs):
        self._init_args_kwargs = (args, kwargs)
        init(self, *args, **kwargs)

    return wrapped_init
|
audio_separator/separator/uvr_lib_v5/demucs/tasnet.py
ADDED
|
@@ -0,0 +1,401 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
#
|
| 7 |
+
# Created on 2018/12
|
| 8 |
+
# Author: Kaituo XU
|
| 9 |
+
# Modified on 2019/11 by Alexandre Defossez, added support for multiple output channels
|
| 10 |
+
# Here is the original license:
|
| 11 |
+
# The MIT License (MIT)
|
| 12 |
+
#
|
| 13 |
+
# Copyright (c) 2018 Kaituo XU
|
| 14 |
+
#
|
| 15 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 16 |
+
# of this software and associated documentation files (the "Software"), to deal
|
| 17 |
+
# in the Software without restriction, including without limitation the rights
|
| 18 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 19 |
+
# copies of the Software, and to permit persons to whom the Software is
|
| 20 |
+
# furnished to do so, subject to the following conditions:
|
| 21 |
+
#
|
| 22 |
+
# The above copyright notice and this permission notice shall be included in all
|
| 23 |
+
# copies or substantial portions of the Software.
|
| 24 |
+
#
|
| 25 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 26 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 27 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 28 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 29 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 30 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 31 |
+
# SOFTWARE.
|
| 32 |
+
|
| 33 |
+
import math
|
| 34 |
+
|
| 35 |
+
import torch
|
| 36 |
+
import torch.nn as nn
|
| 37 |
+
import torch.nn.functional as F
|
| 38 |
+
|
| 39 |
+
from .utils import capture_init
|
| 40 |
+
|
| 41 |
+
EPS = 1e-8
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def overlap_and_add(signal, frame_step):
    """Reconstruct a signal from overlapping frames by summing the overlaps.

    The last two dimensions of *signal* are (frames, frame_length); output
    length is frame_step * (frames - 1) + frame_length.
    """
    batch_shape = signal.size()[:-2]
    frames, frame_length = signal.size()[-2:]

    # Work at the granularity of gcd(frame_length, frame_step) sub-frames so
    # every frame boundary falls exactly on a sub-frame boundary.
    subframe_length = math.gcd(frame_length, frame_step)
    subframe_step = frame_step // subframe_length
    subframes_per_frame = frame_length // subframe_length
    output_size = frame_step * (frames - 1) + frame_length
    output_subframes = output_size // subframe_length

    subframe_signal = signal.view(*batch_shape, -1, subframe_length)

    # For every frame, the indices of the output sub-frames it lands on.
    index = torch.arange(0, output_subframes, device=signal.device).unfold(0, subframes_per_frame, subframe_step)
    index = index.long().contiguous().view(-1)

    result = signal.new_zeros(*batch_shape, output_subframes, subframe_length)
    result.index_add_(-2, index, subframe_signal)
    return result.view(*batch_shape, -1)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
class ConvTasNet(nn.Module):
    # Conv-TasNet: encoder -> TCN mask estimator -> decoder, operating
    # directly on waveforms.
    @capture_init
    def __init__(self, N=256, L=20, B=256, H=512, P=3, X=8, R=4, C=4, audio_channels=1, samplerate=44100, norm_type="gLN", causal=False, mask_nonlinear="relu"):
        """
        Args:
            N: Number of filters in autoencoder
            L: Length of the filters (in samples)
            B: Number of channels in bottleneck 1 × 1-conv block
            H: Number of channels in convolutional blocks
            P: Kernel size in convolutional blocks
            X: Number of convolutional blocks in each repeat
            R: Number of repeats
            C: Number of speakers
            norm_type: BN, gLN, cLN
            causal: causal or non-causal
            mask_nonlinear: use which non-linear function to generate mask
        """
        super(ConvTasNet, self).__init__()
        # Hyper-parameter
        self.N, self.L, self.B, self.H, self.P, self.X, self.R, self.C = N, L, B, H, P, X, R, C
        self.norm_type = norm_type
        self.causal = causal
        self.mask_nonlinear = mask_nonlinear
        self.audio_channels = audio_channels
        self.samplerate = samplerate
        # Components
        self.encoder = Encoder(L, N, audio_channels)
        self.separator = TemporalConvNet(N, B, H, P, X, R, C, norm_type, causal, mask_nonlinear)
        self.decoder = Decoder(N, L, audio_channels)
        # init: Xavier initialization for all weight matrices/conv kernels.
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_normal_(p)

    def valid_length(self, length):
        # Any input length is accepted; forward pads the output back to it.
        return length

    def forward(self, mixture):
        """
        Args:
            mixture: [M, T], M is batch size, T is #samples
        Returns:
            est_source: [M, C, T]
        """
        mixture_w = self.encoder(mixture)
        est_mask = self.separator(mixture_w)
        est_source = self.decoder(mixture_w, est_mask)

        # T changed after conv1d in encoder, fix it here
        T_origin = mixture.size(-1)
        T_conv = est_source.size(-1)
        est_source = F.pad(est_source, (0, T_origin - T_conv))
        return est_source
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
class Encoder(nn.Module):
    """Estimation of the nonnegative mixture weight by a 1-D conv layer."""

    def __init__(self, L, N, audio_channels):
        super(Encoder, self).__init__()
        self.L, self.N = L, N
        # Strided conv with 50% overlap (stride = L // 2): a learned
        # filterbank over windows of L samples.
        self.conv1d_U = nn.Conv1d(audio_channels, N, kernel_size=L, stride=L // 2, bias=False)

    def forward(self, mixture):
        """Map [M, audio_channels, T] waveforms to nonnegative weights
        [M, N, K] where K = (T - L) / (L / 2) + 1 = 2T / L - 1."""
        return F.relu(self.conv1d_U(mixture))
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
class Decoder(nn.Module):
    # Linear basis-signal decoder: masked encoder weights -> waveforms via
    # overlap-and-add of learned basis frames.
    def __init__(self, N, L, audio_channels):
        super(Decoder, self).__init__()
        # Hyper-parameter
        self.N, self.L = N, L
        self.audio_channels = audio_channels
        # Components: each N-dim weight vector maps back to a frame of
        # audio_channels * L samples.
        self.basis_signals = nn.Linear(N, audio_channels * L, bias=False)

    def forward(self, mixture_w, est_mask):
        """
        Args:
            mixture_w: [M, N, K]
            est_mask: [M, C, N, K]
        Returns:
            est_source: [M, C, T]
        """
        # D = W * M
        source_w = torch.unsqueeze(mixture_w, 1) * est_mask  # [M, C, N, K]
        source_w = torch.transpose(source_w, 2, 3)  # [M, C, K, N]
        # S = DV
        est_source = self.basis_signals(source_w)  # [M, C, K, ac * L]
        m, c, k, _ = est_source.size()
        # Split the combined (audio_channels * L) axis and move channels
        # ahead of frames before overlap-and-add.
        est_source = est_source.view(m, c, k, self.audio_channels, -1).transpose(2, 3).contiguous()
        est_source = overlap_and_add(est_source, self.L // 2)  # M x C x ac x T
        return est_source
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
class TemporalConvNet(nn.Module):
    # Temporal Convolutional Network producing one mask per source from the
    # encoder's mixture weights.
    def __init__(self, N, B, H, P, X, R, C, norm_type="gLN", causal=False, mask_nonlinear="relu"):
        """
        Args:
            N: Number of filters in autoencoder
            B: Number of channels in bottleneck 1 × 1-conv block
            H: Number of channels in convolutional blocks
            P: Kernel size in convolutional blocks
            X: Number of convolutional blocks in each repeat
            R: Number of repeats
            C: Number of speakers
            norm_type: BN, gLN, cLN
            causal: causal or non-causal
            mask_nonlinear: use which non-linear function to generate mask
        """
        super(TemporalConvNet, self).__init__()
        # Hyper-parameter
        self.C = C
        self.mask_nonlinear = mask_nonlinear
        # Components
        # [M, N, K] -> [M, N, K]
        layer_norm = ChannelwiseLayerNorm(N)
        # [M, N, K] -> [M, B, K]
        bottleneck_conv1x1 = nn.Conv1d(N, B, 1, bias=False)
        # [M, B, K] -> [M, B, K]
        repeats = []
        for r in range(R):
            blocks = []
            for x in range(X):
                # Dilation doubles per block (1, 2, 4, ...): exponentially
                # growing receptive field within each repeat.
                dilation = 2**x
                # Causal padding is one-sided (trimmed later by Chomp1d);
                # non-causal padding is symmetric.
                padding = (P - 1) * dilation if causal else (P - 1) * dilation // 2
                blocks += [TemporalBlock(B, H, P, stride=1, padding=padding, dilation=dilation, norm_type=norm_type, causal=causal)]
            repeats += [nn.Sequential(*blocks)]
        temporal_conv_net = nn.Sequential(*repeats)
        # [M, B, K] -> [M, C*N, K]
        mask_conv1x1 = nn.Conv1d(B, C * N, 1, bias=False)
        # Put together
        self.network = nn.Sequential(layer_norm, bottleneck_conv1x1, temporal_conv_net, mask_conv1x1)

    def forward(self, mixture_w):
        """
        Keep this API same with TasNet
        Args:
            mixture_w: [M, N, K], M is batch size
        returns:
            est_mask: [M, C, N, K]
        """
        M, N, K = mixture_w.size()
        score = self.network(mixture_w)  # [M, N, K] -> [M, C*N, K]
        score = score.view(M, self.C, N, K)  # [M, C*N, K] -> [M, C, N, K]
        if self.mask_nonlinear == "softmax":
            est_mask = F.softmax(score, dim=1)
        elif self.mask_nonlinear == "relu":
            est_mask = F.relu(score)
        else:
            raise ValueError("Unsupported mask non-linear function")
        return est_mask
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
class TemporalBlock(nn.Module):
    # Residual block: 1x1 conv expand -> PReLU -> norm -> depthwise-separable
    # conv back to the input width.
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, norm_type="gLN", causal=False):
        super(TemporalBlock, self).__init__()
        # [M, B, K] -> [M, H, K]
        conv1x1 = nn.Conv1d(in_channels, out_channels, 1, bias=False)
        prelu = nn.PReLU()
        norm = chose_norm(norm_type, out_channels)
        # [M, H, K] -> [M, B, K]
        dsconv = DepthwiseSeparableConv(out_channels, in_channels, kernel_size, stride, padding, dilation, norm_type, causal)
        # Put together
        self.net = nn.Sequential(conv1x1, prelu, norm, dsconv)

    def forward(self, x):
        """
        Args:
            x: [M, B, K]
        Returns:
            [M, B, K]
        """
        residual = x
        out = self.net(x)
        # TODO: when P = 3 here works fine, but when P = 2 maybe need to pad?
        return out + residual  # look like w/o F.relu is better than w/ F.relu
        # return F.relu(out + residual)
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
class DepthwiseSeparableConv(nn.Module):
    # Depthwise (per-channel, dilated) conv followed by a pointwise 1x1 conv;
    # in causal mode the one-sided padding excess is chomped off.
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, norm_type="gLN", causal=False):
        super(DepthwiseSeparableConv, self).__init__()
        # Use `groups` option to implement depthwise convolution
        # [M, H, K] -> [M, H, K]
        depthwise_conv = nn.Conv1d(in_channels, in_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, groups=in_channels, bias=False)
        if causal:
            # Drop the trailing `padding` steps so output length matches input.
            chomp = Chomp1d(padding)
        prelu = nn.PReLU()
        norm = chose_norm(norm_type, in_channels)
        # [M, H, K] -> [M, B, K]
        pointwise_conv = nn.Conv1d(in_channels, out_channels, 1, bias=False)
        # Put together
        if causal:
            self.net = nn.Sequential(depthwise_conv, chomp, prelu, norm, pointwise_conv)
        else:
            self.net = nn.Sequential(depthwise_conv, prelu, norm, pointwise_conv)

    def forward(self, x):
        """
        Args:
            x: [M, H, K]
        Returns:
            result: [M, B, K]
        """
        return self.net(x)
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
class Chomp1d(nn.Module):
    """To ensure the output length is the same as the input."""

    def __init__(self, chomp_size):
        super(Chomp1d, self).__init__()
        # Number of trailing time steps to drop (the causal padding excess).
        self.chomp_size = chomp_size

    def forward(self, x):
        """Trim the last ``chomp_size`` steps: [M, H, Kpad] -> [M, H, K]."""
        trimmed = x[:, :, : -self.chomp_size]
        return trimmed.contiguous()
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
def chose_norm(norm_type, channel_size):
    """The input of normlization will be (M, C, K), where M is batch size,
    C is channel size and K is sequence length.
    """
    if norm_type == "gLN":
        return GlobalLayerNorm(channel_size)
    elif norm_type == "cLN":
        return ChannelwiseLayerNorm(channel_size)
    elif norm_type == "id":
        # No-op normalization.
        return nn.Identity()
    else:  # norm_type == "BN":
        # Given input (M, C, K), nn.BatchNorm1d(C) will accumulate statics
        # along M and K, so this BN usage is right.
        return nn.BatchNorm1d(channel_size)
|
| 315 |
+
|
| 316 |
+
|
| 317 |
+
# TODO: Use nn.LayerNorm to impl cLN to speed up
class ChannelwiseLayerNorm(nn.Module):
    """Channel-wise Layer Normalization (cLN).

    Each time step is normalized independently across the channel
    dimension, then a learned per-channel affine (gamma, beta) is applied.
    """

    def __init__(self, channel_size):
        super(ChannelwiseLayerNorm, self).__init__()
        self.gamma = nn.Parameter(torch.Tensor(1, channel_size, 1))  # [1, N, 1]
        self.beta = nn.Parameter(torch.Tensor(1, channel_size, 1))  # [1, N, 1]
        self.reset_parameters()

    def reset_parameters(self):
        # Identity transform at initialization.
        self.gamma.data.fill_(1)
        self.beta.data.zero_()

    def forward(self, y):
        """
        Args:
            y: [M, N, K], M is batch size, N is channel size, K is length
        Returns:
            cLN_y: [M, N, K]
        """
        mu = y.mean(dim=1, keepdim=True)  # [M, 1, K]
        sigma_sq = y.var(dim=1, keepdim=True, unbiased=False)  # [M, 1, K]
        normalized = (y - mu) / (sigma_sq + EPS) ** 0.5
        return self.gamma * normalized + self.beta
|
| 342 |
+
|
| 343 |
+
|
| 344 |
+
class GlobalLayerNorm(nn.Module):
    """Global Layer Normalization (gLN): statistics over channels AND time."""

    def __init__(self, channel_size):
        super(GlobalLayerNorm, self).__init__()
        self.gamma = nn.Parameter(torch.Tensor(1, channel_size, 1))  # [1, N, 1]
        self.beta = nn.Parameter(torch.Tensor(1, channel_size, 1))  # [1, N, 1]
        self.reset_parameters()

    def reset_parameters(self):
        # Start as the identity transform.
        self.gamma.data.fill_(1)
        self.beta.data.zero_()

    def forward(self, y):
        """
        Args:
            y: [M, N, K], M is batch size, N is channel size, K is length
        Returns:
            gLN_y: [M, N, K]
        """
        # TODO: in torch 1.0, torch.mean() support dim list
        # Mean/variance over both channel and time dims, kept as [M, 1, 1].
        mu = y.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True)
        var = ((y - mu) ** 2).mean(dim=1, keepdim=True).mean(dim=2, keepdim=True)
        return self.gamma * (y - mu) / (var + EPS) ** 0.5 + self.beta
|
| 369 |
+
|
| 370 |
+
|
| 371 |
+
if __name__ == "__main__":
    # Tiny smoke test with toy hyper-parameters.
    torch.manual_seed(123)
    M, N, L, T = 2, 3, 4, 12
    K = 2 * T // L - 1
    B, H, P, X, R, C, norm_type, causal = 2, 3, 3, 3, 2, 2, "gLN", False
    mixture = torch.randint(3, (M, T))

    # Encoder: integer weights keep the outputs easy to inspect by eye.
    enc = Encoder(L, N)
    enc.conv1d_U.weight.data = torch.randint(2, enc.conv1d_U.weight.size())
    mixture_w = enc(mixture)
    print("mixture", mixture)
    print("U", enc.conv1d_U.weight)
    print("mixture_w", mixture_w)
    print("mixture_w size", mixture_w.size())

    # Separator (TemporalConvNet).
    sep = TemporalConvNet(N, B, H, P, X, R, C, norm_type=norm_type, causal=causal)
    est_mask = sep(mixture_w)
    print("est_mask", est_mask)

    # Decoder.
    dec = Decoder(N, L)
    est_mask = torch.randint(2, (B, K, C, N))
    est_source = dec(mixture_w, est_mask)
    print("est_source", est_source)

    # Full Conv-TasNet.
    net = ConvTasNet(N, L, B, H, P, X, R, C, norm_type=norm_type)
    est_source = net(mixture)
    print("est_source", est_source)
    print("est_source size", est_source.size())
|
audio_separator/separator/uvr_lib_v5/demucs/tasnet_v2.py
ADDED
|
@@ -0,0 +1,404 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
#
|
| 7 |
+
# Created on 2018/12
|
| 8 |
+
# Author: Kaituo XU
|
| 9 |
+
# Modified on 2019/11 by Alexandre Defossez, added support for multiple output channels
|
| 10 |
+
# Here is the original license:
|
| 11 |
+
# The MIT License (MIT)
|
| 12 |
+
#
|
| 13 |
+
# Copyright (c) 2018 Kaituo XU
|
| 14 |
+
#
|
| 15 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 16 |
+
# of this software and associated documentation files (the "Software"), to deal
|
| 17 |
+
# in the Software without restriction, including without limitation the rights
|
| 18 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 19 |
+
# copies of the Software, and to permit persons to whom the Software is
|
| 20 |
+
# furnished to do so, subject to the following conditions:
|
| 21 |
+
#
|
| 22 |
+
# The above copyright notice and this permission notice shall be included in all
|
| 23 |
+
# copies or substantial portions of the Software.
|
| 24 |
+
#
|
| 25 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 26 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 27 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 28 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 29 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 30 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 31 |
+
# SOFTWARE.
|
| 32 |
+
|
| 33 |
+
import math
|
| 34 |
+
|
| 35 |
+
import torch
|
| 36 |
+
import torch.nn as nn
|
| 37 |
+
import torch.nn.functional as F
|
| 38 |
+
|
| 39 |
+
from .utils import capture_init
|
| 40 |
+
|
| 41 |
+
EPS = 1e-8
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def overlap_and_add(signal, frame_step):
    """Reconstruct a signal from overlapping frames.

    Torch equivalent of ``tf.signal.overlap_and_add``: frame i is added
    into the output starting at sample ``i * frame_step``.

    Args:
        signal: [..., frames, frame_length] tensor of frames.
        frame_step: hop size between consecutive frames, in samples.
    Returns:
        [..., output_size] tensor with
        output_size = (frames - 1) * frame_step + frame_length.
    """
    outer_dimensions = signal.size()[:-2]
    frames, frame_length = signal.size()[-2:]

    # Work at the granularity of the GCD so every frame and every hop is a
    # whole number of subframes and the index_add below lines up exactly.
    subframe_length = math.gcd(frame_length, frame_step)
    subframe_step = frame_step // subframe_length
    subframes_per_frame = frame_length // subframe_length
    output_size = frame_step * (frames - 1) + frame_length
    output_subframes = output_size // subframe_length

    subframe_signal = signal.view(*outer_dimensions, -1, subframe_length)

    # Destination subframe index for each input subframe.
    frame = torch.arange(0, output_subframes, device=signal.device).unfold(0, subframes_per_frame, subframe_step)
    frame = frame.long()  # signal may live on GPU or CPU
    frame = frame.contiguous().view(-1)

    result = signal.new_zeros(*outer_dimensions, output_subframes, subframe_length)
    result.index_add_(-2, frame, subframe_signal)
    return result.view(*outer_dimensions, -1)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
class ConvTasNet(nn.Module):
    @capture_init
    def __init__(self, sources, N=256, L=20, B=256, H=512, P=3, X=8, R=4, audio_channels=2, norm_type="gLN", causal=False, mask_nonlinear="relu", samplerate=44100, segment_length=44100 * 2 * 4):
        """
        Args:
            sources: list of sources
            N: Number of filters in autoencoder
            L: Length of the filters (in samples)
            B: Number of channels in bottleneck 1 × 1-conv block
            H: Number of channels in convolutional blocks
            P: Kernel size in convolutional blocks
            X: Number of convolutional blocks in each repeat
            R: Number of repeats
            norm_type: BN, gLN, cLN
            causal: causal or non-causal
            mask_nonlinear: non-linearity used to generate the masks
        """
        super(ConvTasNet, self).__init__()
        # Hyper-parameters.
        self.sources = sources
        self.C = len(sources)  # one mask per source
        self.N, self.L, self.B, self.H, self.P, self.X, self.R = N, L, B, H, P, X, R
        self.norm_type = norm_type
        self.causal = causal
        self.mask_nonlinear = mask_nonlinear
        self.audio_channels = audio_channels
        self.samplerate = samplerate
        self.segment_length = segment_length
        # Encoder -> mask estimator -> decoder pipeline.
        self.encoder = Encoder(L, N, audio_channels)
        self.separator = TemporalConvNet(N, B, H, P, X, R, self.C, norm_type, causal, mask_nonlinear)
        self.decoder = Decoder(N, L, audio_channels)
        # Xavier init on every weight matrix (1-D params are skipped).
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_normal_(p)

    def valid_length(self, length):
        # Any length works: forward() pads the estimate back to the input size.
        return length

    def forward(self, mixture):
        """
        Args:
            mixture: [M, T], M is batch size, T is #samples
        Returns:
            est_source: [M, C, T]
        """
        mixture_w = self.encoder(mixture)
        est_mask = self.separator(mixture_w)
        est_source = self.decoder(mixture_w, est_mask)

        # The encoder's strided conv shortens T; pad the estimate back to
        # the original number of samples.
        missing = mixture.size(-1) - est_source.size(-1)
        return F.pad(est_source, (0, missing))
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
class Encoder(nn.Module):
    """Estimate the nonnegative mixture weights with a 1-D conv layer."""

    def __init__(self, L, N, audio_channels):
        super(Encoder, self).__init__()
        self.L, self.N = L, N
        # Stride of L // 2 gives 50% overlap between analysis windows.
        self.conv1d_U = nn.Conv1d(audio_channels, N, kernel_size=L, stride=L // 2, bias=False)

    def forward(self, mixture):
        """
        Args:
            mixture: [M, T], M is batch size, T is #samples
        Returns:
            mixture_w: [M, N, K], where K = (T-L)/(L/2)+1 = 2T/L-1
        """
        # ReLU keeps the learned representation nonnegative.
        return F.relu(self.conv1d_U(mixture))
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
class Decoder(nn.Module):
    """Turn masked mixture weights back into time-domain sources."""

    def __init__(self, N, L, audio_channels):
        super(Decoder, self).__init__()
        self.N, self.L = N, L
        self.audio_channels = audio_channels
        # Learned synthesis basis: a length-L waveform per (channel, filter).
        self.basis_signals = nn.Linear(N, audio_channels * L, bias=False)

    def forward(self, mixture_w, est_mask):
        """
        Args:
            mixture_w: [M, N, K]
            est_mask: [M, C, N, K]
        Returns:
            est_source: [M, C, ac, T]
        """
        # D = W * M: apply each source's mask to the shared mixture weights.
        source_w = mixture_w.unsqueeze(1) * est_mask  # [M, C, N, K]
        source_w = source_w.transpose(2, 3)  # [M, C, K, N]
        # S = DV: project every frame onto the synthesis basis.
        est_source = self.basis_signals(source_w)  # [M, C, K, ac * L]
        m, c, k, _ = est_source.size()
        est_source = est_source.view(m, c, k, self.audio_channels, -1).transpose(2, 3).contiguous()
        # Overlap-add consecutive frames (50% overlap) -> M x C x ac x T.
        return overlap_and_add(est_source, self.L // 2)
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
class TemporalConvNet(nn.Module):
    def __init__(self, N, B, H, P, X, R, C, norm_type="gLN", causal=False, mask_nonlinear="relu"):
        """
        Args:
            N: Number of filters in autoencoder
            B: Number of channels in bottleneck 1 × 1-conv block
            H: Number of channels in convolutional blocks
            P: Kernel size in convolutional blocks
            X: Number of convolutional blocks in each repeat
            R: Number of repeats
            C: Number of speakers
            norm_type: BN, gLN, cLN
            causal: causal or non-causal
            mask_nonlinear: non-linearity used to generate the masks
        """
        super(TemporalConvNet, self).__init__()
        self.C = C
        self.mask_nonlinear = mask_nonlinear
        # [M, N, K] -> [M, N, K]: normalize, then squeeze to the bottleneck.
        layer_norm = ChannelwiseLayerNorm(N)
        # [M, N, K] -> [M, B, K]
        bottleneck_conv1x1 = nn.Conv1d(N, B, 1, bias=False)
        # R repeats of X dilated blocks; dilation doubles within each repeat.
        repeats = []
        for _ in range(R):
            blocks = [
                TemporalBlock(
                    B,
                    H,
                    P,
                    stride=1,
                    # Causal blocks pad the full receptive field on one side;
                    # non-causal ones split it evenly.
                    padding=(P - 1) * 2**x if causal else (P - 1) * 2**x // 2,
                    dilation=2**x,
                    norm_type=norm_type,
                    causal=causal,
                )
                for x in range(X)
            ]
            repeats.append(nn.Sequential(*blocks))
        temporal_conv_net = nn.Sequential(*repeats)
        # [M, B, K] -> [M, C*N, K]: one mask per source.
        mask_conv1x1 = nn.Conv1d(B, C * N, 1, bias=False)
        self.network = nn.Sequential(layer_norm, bottleneck_conv1x1, temporal_conv_net, mask_conv1x1)

    def forward(self, mixture_w):
        """
        Keep this API same with TasNet
        Args:
            mixture_w: [M, N, K], M is batch size
        Returns:
            est_mask: [M, C, N, K]
        """
        M, N, K = mixture_w.size()
        score = self.network(mixture_w).view(M, self.C, N, K)  # [M, C*N, K] -> [M, C, N, K]
        if self.mask_nonlinear == "softmax":
            return F.softmax(score, dim=1)
        if self.mask_nonlinear == "relu":
            return F.relu(score)
        raise ValueError("Unsupported mask non-linear function")
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
class TemporalBlock(nn.Module):
    """Residual block: 1x1 conv -> PReLU -> norm -> depthwise-separable conv."""

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, norm_type="gLN", causal=False):
        super(TemporalBlock, self).__init__()
        # [M, B, K] -> [M, H, K]: expand to the hidden width.
        conv1x1 = nn.Conv1d(in_channels, out_channels, 1, bias=False)
        prelu = nn.PReLU()
        norm = chose_norm(norm_type, out_channels)
        # [M, H, K] -> [M, B, K]: back down to the bottleneck width.
        dsconv = DepthwiseSeparableConv(out_channels, in_channels, kernel_size, stride, padding, dilation, norm_type, causal)
        self.net = nn.Sequential(conv1x1, prelu, norm, dsconv)

    def forward(self, x):
        """
        Args:
            x: [M, B, K]
        Returns:
            [M, B, K]
        """
        # TODO: when P = 3 here works fine, but when P = 2 maybe need to pad?
        # Plain residual sum; per the original authors, a final ReLU hurt.
        return self.net(x) + x
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
class DepthwiseSeparableConv(nn.Module):
    """Depthwise conv (groups=channels) followed by a pointwise 1x1 conv."""

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, norm_type="gLN", causal=False):
        super(DepthwiseSeparableConv, self).__init__()
        # [M, H, K] -> [M, H, K]: `groups=in_channels` makes it depthwise.
        depthwise_conv = nn.Conv1d(in_channels, in_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, groups=in_channels, bias=False)
        prelu = nn.PReLU()
        norm = chose_norm(norm_type, in_channels)
        # [M, H, K] -> [M, B, K]
        pointwise_conv = nn.Conv1d(in_channels, out_channels, 1, bias=False)
        if causal:
            # Chomp1d drops the extra right-hand samples that the symmetric
            # padding produced, keeping the conv causal.
            self.net = nn.Sequential(depthwise_conv, Chomp1d(padding), prelu, norm, pointwise_conv)
        else:
            self.net = nn.Sequential(depthwise_conv, prelu, norm, pointwise_conv)

    def forward(self, x):
        """
        Args:
            x: [M, H, K]
        Returns:
            result: [M, B, K]
        """
        return self.net(x)
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
class Chomp1d(nn.Module):
    """Trim trailing padding so the output length matches the input length."""

    def __init__(self, chomp_size):
        super(Chomp1d, self).__init__()
        # Number of trailing time steps removed by forward().
        self.chomp_size = chomp_size

    def forward(self, x):
        """
        Args:
            x: [M, H, Kpad]
        Returns:
            [M, H, K] with the last ``chomp_size`` steps dropped.
        """
        trimmed = x[:, :, : -self.chomp_size]
        return trimmed.contiguous()
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
def chose_norm(norm_type, channel_size):
    """Build the normalization layer named by ``norm_type``.

    The normalized input is (M, C, K): batch, channel, sequence length.
    Anything other than "gLN"/"cLN"/"id" falls back to BatchNorm1d ("BN").
    """
    if norm_type == "gLN":
        return GlobalLayerNorm(channel_size)
    if norm_type == "cLN":
        return ChannelwiseLayerNorm(channel_size)
    if norm_type == "id":
        return nn.Identity()
    # For (M, C, K) inputs, nn.BatchNorm1d(C) accumulates statistics over
    # both M and K, which is the intended behavior here.
    return nn.BatchNorm1d(channel_size)
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
# TODO: Use nn.LayerNorm to impl cLN to speed up
class ChannelwiseLayerNorm(nn.Module):
    """Channel-wise Layer Normalization (cLN).

    Each time step is normalized independently across the channel
    dimension, then a learned per-channel affine (gamma, beta) is applied.
    """

    def __init__(self, channel_size):
        super(ChannelwiseLayerNorm, self).__init__()
        self.gamma = nn.Parameter(torch.Tensor(1, channel_size, 1))  # [1, N, 1]
        self.beta = nn.Parameter(torch.Tensor(1, channel_size, 1))  # [1, N, 1]
        self.reset_parameters()

    def reset_parameters(self):
        # Identity transform at initialization.
        self.gamma.data.fill_(1)
        self.beta.data.zero_()

    def forward(self, y):
        """
        Args:
            y: [M, N, K], M is batch size, N is channel size, K is length
        Returns:
            cLN_y: [M, N, K]
        """
        mu = y.mean(dim=1, keepdim=True)  # [M, 1, K]
        sigma_sq = y.var(dim=1, keepdim=True, unbiased=False)  # [M, 1, K]
        normalized = (y - mu) / (sigma_sq + EPS) ** 0.5
        return self.gamma * normalized + self.beta
|
| 345 |
+
|
| 346 |
+
|
| 347 |
+
class GlobalLayerNorm(nn.Module):
    """Global Layer Normalization (gLN): statistics over channels AND time."""

    def __init__(self, channel_size):
        super(GlobalLayerNorm, self).__init__()
        self.gamma = nn.Parameter(torch.Tensor(1, channel_size, 1))  # [1, N, 1]
        self.beta = nn.Parameter(torch.Tensor(1, channel_size, 1))  # [1, N, 1]
        self.reset_parameters()

    def reset_parameters(self):
        # Start as the identity transform.
        self.gamma.data.fill_(1)
        self.beta.data.zero_()

    def forward(self, y):
        """
        Args:
            y: [M, N, K], M is batch size, N is channel size, K is length
        Returns:
            gLN_y: [M, N, K]
        """
        # TODO: in torch 1.0, torch.mean() support dim list
        # Mean/variance over both channel and time dims, kept as [M, 1, 1].
        mu = y.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True)
        var = ((y - mu) ** 2).mean(dim=1, keepdim=True).mean(dim=2, keepdim=True)
        return self.gamma * (y - mu) / (var + EPS) ** 0.5 + self.beta
|
| 372 |
+
|
| 373 |
+
|
| 374 |
+
if __name__ == "__main__":
    # Smoke test with toy sizes.
    # NOTE(review): the original demo predates the multi-channel refactor in
    # this file — Encoder/Decoder now require `audio_channels` and ConvTasNet
    # takes `sources` as its first argument — so its calls raised TypeError /
    # misbound positionals. The calls below match the signatures defined above.
    torch.manual_seed(123)
    M, N, L, T = 2, 3, 4, 12
    K = 2 * T // L - 1
    B, H, P, X, R, C, norm_type, causal = 2, 3, 3, 3, 2, 2, "gLN", False
    audio_channels = 1
    # Conv1d needs a float input with an explicit channel dimension.
    mixture = torch.randint(3, (M, audio_channels, T)).float()

    # test Encoder
    encoder = Encoder(L, N, audio_channels)
    encoder.conv1d_U.weight.data = torch.randint(2, encoder.conv1d_U.weight.size()).float()
    mixture_w = encoder(mixture)
    print("mixture", mixture)
    print("U", encoder.conv1d_U.weight)
    print("mixture_w", mixture_w)
    print("mixture_w size", mixture_w.size())

    # test TemporalConvNet
    separator = TemporalConvNet(N, B, H, P, X, R, C, norm_type=norm_type, causal=causal)
    est_mask = separator(mixture_w)
    print("est_mask", est_mask)

    # test Decoder: est_mask must be [M, C, N, K] to match Decoder.forward.
    decoder = Decoder(N, L, audio_channels)
    est_mask = torch.randint(2, (M, C, N, K)).float()
    est_source = decoder(mixture_w, est_mask)
    print("est_source", est_source)

    # test Conv-TasNet: `sources` is a list; C masks are derived from its length.
    conv_tasnet = ConvTasNet(sources=list(range(C)), N=N, L=L, B=B, H=H, P=P, X=X, R=R, audio_channels=audio_channels, norm_type=norm_type)
    est_source = conv_tasnet(mixture)
    print("est_source", est_source)
    print("est_source size", est_source.size())
|
audio_separator/separator/uvr_lib_v5/demucs/transformer.py
ADDED
|
@@ -0,0 +1,675 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2019-present, Meta, Inc.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
# First author is Simon Rouard.
|
| 7 |
+
|
| 8 |
+
import random
|
| 9 |
+
import typing as tp
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
+
import numpy as np
|
| 15 |
+
import math
|
| 16 |
+
from einops import rearrange
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def create_sin_embedding(length: int, dim: int, shift: int = 0, device="cpu", max_period=10000):
    """Sinusoidal positional embedding in TBC layout.

    Returns a (length, 1, dim) tensor: cosines in the first half of the
    channels, sines in the second half, with periods growing geometrically
    from 1 up to ``max_period``.
    """
    assert dim % 2 == 0
    half_dim = dim // 2
    pos = shift + torch.arange(length, device=device).view(-1, 1, 1)
    adim = torch.arange(half_dim, device=device).view(1, 1, -1)
    phase = pos / (max_period ** (adim / (half_dim - 1)))
    return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def create_2d_sin_embedding(d_model, height, width, device="cpu", max_period=10000):
    """
    :param d_model: dimension of the model
    :param height: height of the positions
    :param width: width of the positions
    :return: d_model*height*width position matrix
    """
    # d_model must split evenly into (width-half, height-half) x (sin, cos).
    if d_model % 4 != 0:
        raise ValueError("Cannot use sin/cos positional encoding with " "odd dimension (got dim={:d})".format(d_model))
    pe = torch.zeros(d_model, height, width)
    # Each dimension use half of d_model
    d_model = int(d_model / 2)
    div_term = torch.exp(torch.arange(0.0, d_model, 2) * -(math.log(max_period) / d_model))
    pos_w = torch.arange(0.0, width).unsqueeze(1)
    pos_h = torch.arange(0.0, height).unsqueeze(1)
    # First half of the channels encodes the width (x) position, sin/cos
    # interleaved over the even/odd channel indices...
    pe[0:d_model:2, :, :] = torch.sin(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
    pe[1:d_model:2, :, :] = torch.cos(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
    # ...second half encodes the height (y) position the same way.
    pe[d_model::2, :, :] = torch.sin(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
    pe[d_model + 1 :: 2, :, :] = torch.cos(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)

    # Add a leading batch dimension: (1, d_model, height, width).
    return pe[None, :].to(device)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def create_sin_embedding_cape(
    length: int,
    dim: int,
    batch_size: int,
    mean_normalize: bool,
    augment: bool,  # True during training
    max_global_shift: float = 0.0,  # delta max
    max_local_shift: float = 0.0,  # epsilon max
    max_scale: float = 1.0,
    device: str = "cpu",
    max_period: float = 10000.0,
):
    """CAPE-style sinusoidal embedding in TBC layout with optional position
    augmentation (global shift, per-step jitter, random rescale).

    Returns a float tensor of shape (length, batch_size, dim).
    """
    assert dim % 2 == 0
    positions = 1.0 * torch.arange(length).view(-1, 1, 1)  # (length, 1, 1)
    positions = positions.repeat(1, batch_size, 1)  # (length, batch_size, 1)
    if mean_normalize:
        positions -= torch.nanmean(positions, dim=0, keepdim=True)

    if augment:
        # Augmentation noise is drawn from the numpy RNG.
        delta = np.random.uniform(-max_global_shift, +max_global_shift, size=[1, batch_size, 1])
        delta_local = np.random.uniform(-max_local_shift, +max_local_shift, size=[length, batch_size, 1])
        log_lambdas = np.random.uniform(-np.log(max_scale), +np.log(max_scale), size=[1, batch_size, 1])
        positions = (positions + delta + delta_local) * np.exp(log_lambdas)

    positions = positions.to(device)

    half_dim = dim // 2
    adim = torch.arange(half_dim, device=device).view(1, 1, -1)
    phase = positions / (max_period ** (adim / (half_dim - 1)))
    return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1).float()
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def get_causal_mask(length):
    """Boolean (length, length) mask, True strictly above the diagonal —
    i.e. entry [i, j] is True when position j lies in i's future."""
    idx = torch.arange(length)
    return idx > idx[:, None]
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def get_elementary_mask(T1, T2, mask_type, sparse_attn_window, global_window, mask_random_seed, sparsity, device):
    """
    When the input of the Decoder has length T1 and the output T2
    The mask matrix has shape (T2, T1); True marks an allowed attention link.
    """
    assert mask_type in ["diag", "jmask", "random", "global"]

    if mask_type == "global":
        mask = torch.zeros(T2, T1, dtype=torch.bool)
        # Every query attends to the first `global_window` keys, and the
        # proportionally-first queries attend to every key.
        mask[:, :global_window] = True
        line_window = int(global_window * T2 / T1)
        mask[:line_window, :] = True

    if mask_type == "diag":
        # Band of width 2 * sparse_attn_window + 1 around the rescaled diagonal.
        mask = torch.zeros(T2, T1, dtype=torch.bool)
        rows = torch.arange(T2)[:, None]
        offsets = torch.arange(-sparse_attn_window, sparse_attn_window + 1)
        cols = (T1 / T2 * rows + offsets).long().clamp(0, T1 - 1)
        mask[rows, cols] = True

    elif mask_type == "jmask":
        # Triangular-number offsets give a mask that thins out away from the
        # diagonal; pad by one on each side then crop to avoid edge clamping.
        mask = torch.zeros(T2 + 2, T1 + 2, dtype=torch.bool)
        rows = torch.arange(T2 + 2)[:, None]
        steps = torch.arange(0, int((2 * T1) ** 0.5 + 1))
        steps = (steps * (steps + 1) / 2).int()
        steps = torch.cat([-steps.flip(0)[:-1], steps])
        cols = (T1 / T2 * rows + steps).long().clamp(0, T1 + 1)
        mask[rows, cols] = True
        mask = mask[1:-1, 1:-1]

    elif mask_type == "random":
        # Seeded so the random pattern is reproducible across calls.
        generator = torch.Generator(device=device)
        generator.manual_seed(mask_random_seed)
        mask = torch.rand(T1 * T2, generator=generator, device=device).reshape(T2, T1) > sparsity

    return mask.to(device)
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def get_mask(T1, T2, mask_type, sparse_attn_window, global_window, mask_random_seed, sparsity, device):
    """
    Return a SparseCSRTensor mask that is a combination of elementary masks
    mask_type can be a combination of multiple masks: for instance "diag_jmask_random"
    """
    from xformers.sparse import SparseCSRTensor

    # One elementary mask per "_"-separated component.
    parts = [
        get_elementary_mask(T1, T2, kind, sparse_attn_window, global_window, mask_random_seed, sparsity, device)
        for kind in mask_type.split("_")
    ]

    # Union of the elementary masks.
    final_mask = torch.stack(parts).sum(axis=0) > 0

    return SparseCSRTensor.from_dense(final_mask[None])
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
class ScaledEmbedding(nn.Module):
    """Embedding whose weights are stored divided by `boost` and multiplied
    back at lookup time, which effectively boosts the learning rate of the
    embedding table by `boost` relative to the rest of the model."""

    def __init__(self, num_embeddings: int, embedding_dim: int, scale: float = 1.0, boost: float = 3.0):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        # Pre-divide so that `weight` (and forward) end up scaled by `scale`.
        self.embedding.weight.data *= scale / boost
        self.boost = boost

    @property
    def weight(self):
        # Effective (boosted) weights as seen by callers.
        return self.embedding.weight * self.boost

    def forward(self, x):
        return self.embedding(x) * self.boost
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
class LayerScale(nn.Module):
    """Layer scale from [Touvron et al 2021] (https://arxiv.org/pdf/2103.17239.pdf).
    This rescales diagonaly residual outputs close to 0 initially, then learnt.
    """

    def __init__(self, channels: int, init: float = 0, channel_last=False):
        """
        channel_last = False corresponds to (B, C, T) tensors
        channel_last = True corresponds to (T, B, C) tensors
        """
        super().__init__()
        self.channel_last = channel_last
        scale = nn.Parameter(torch.zeros(channels, requires_grad=True))
        scale.data.fill_(init)
        self.scale = scale

    def forward(self, x):
        # Broadcast the per-channel scale over the trailing dims (channel-first)
        # or directly over the last dim (channel-last).
        factor = self.scale if self.channel_last else self.scale[:, None]
        return factor * x
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
class MyGroupNorm(nn.GroupNorm):
    """GroupNorm that accepts channel-last (B, T, C) input instead of the
    channel-first layout expected by nn.GroupNorm."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, x):
        """
        x: (B, T, C)
        if num_groups=1: Normalisation on all T and C together for each B
        """
        channels_first = x.transpose(1, 2)
        normed = super().forward(channels_first)
        return normed.transpose(1, 2)
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
class MyTransformerEncoderLayer(nn.TransformerEncoderLayer):
    """Transformer encoder layer extending nn.TransformerEncoderLayer with:
    optional GroupNorm in place of LayerNorm, optional LayerScale on the
    residual branches, an optional output norm, and optional sparse attention
    (precomputed mask or xformers dynamic sparsity via the custom
    MultiheadAttention defined in this file)."""

    def __init__(
        self,
        d_model,
        nhead,
        dim_feedforward=2048,
        dropout=0.1,
        activation=F.relu,
        group_norm=0,
        norm_first=False,
        norm_out=False,
        layer_norm_eps=1e-5,
        layer_scale=False,
        init_values=1e-4,
        device=None,
        dtype=None,
        sparse=False,
        mask_type="diag",
        mask_random_seed=42,
        sparse_attn_window=500,
        global_window=50,
        auto_sparsity=False,
        sparsity=0.95,
        batch_first=False,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation=activation,
            layer_norm_eps=layer_norm_eps,
            batch_first=batch_first,
            norm_first=norm_first,
            device=device,
            dtype=dtype,
        )
        self.sparse = sparse
        self.auto_sparsity = auto_sparsity
        if sparse:
            # Static mask parameters are only needed when the mask is
            # precomputed (auto_sparsity uses the dynamic kernel instead).
            if not auto_sparsity:
                self.mask_type = mask_type
                self.sparse_attn_window = sparse_attn_window
                self.global_window = global_window
            self.sparsity = sparsity
        if group_norm:
            # Replace the parent's LayerNorms with channel-last GroupNorms.
            self.norm1 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
            self.norm2 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)

        self.norm_out = None
        # NOTE(review): bitwise `&` on bools — works, but `and` was likely meant.
        if self.norm_first & norm_out:
            self.norm_out = MyGroupNorm(num_groups=int(norm_out), num_channels=d_model)
        # LayerScale (per-channel learnt residual gain) or pass-through.
        self.gamma_1 = LayerScale(d_model, init_values, True) if layer_scale else nn.Identity()
        self.gamma_2 = LayerScale(d_model, init_values, True) if layer_scale else nn.Identity()

        if sparse:
            # Swap the parent's attention for this file's sparse-capable one.
            self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first, auto_sparsity=sparsity if auto_sparsity else 0)
            # __setattr__ keeps the tensor out of nn.Module's parameter/buffer
            # registration; it is a cached mask, not a learnable state.
            self.__setattr__("src_mask", torch.zeros(1, 1))
            self.mask_random_seed = mask_random_seed

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        """
        if batch_first = False, src shape is (T, B, C)
        the case where batch_first=True is not covered
        """
        device = src.device
        x = src
        T, B, C = x.shape
        if self.sparse and not self.auto_sparsity:
            assert src_mask is None
            src_mask = self.src_mask
            # Rebuild and cache the sparse mask when the sequence length changes.
            if src_mask.shape[-1] != T:
                src_mask = get_mask(T, T, self.mask_type, self.sparse_attn_window, self.global_window, self.mask_random_seed, self.sparsity, device)
                self.__setattr__("src_mask", src_mask)

        if self.norm_first:
            # Pre-norm: norm -> sublayer -> scaled residual add.
            x = x + self.gamma_1(self._sa_block(self.norm1(x), src_mask, src_key_padding_mask))
            x = x + self.gamma_2(self._ff_block(self.norm2(x)))

            if self.norm_out:
                x = self.norm_out(x)
        else:
            # Post-norm: sublayer -> residual add -> norm.
            x = self.norm1(x + self.gamma_1(self._sa_block(x, src_mask, src_key_padding_mask)))
            x = self.norm2(x + self.gamma_2(self._ff_block(x)))

        return x
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
class CrossTransformerEncoderLayer(nn.Module):
    """Cross-attention transformer layer: queries from one stream attend to
    keys/values from another. Mirrors MyTransformerEncoderLayer's options
    (GroupNorm, LayerScale, pre/post-norm, optional sparse attention)."""

    def __init__(
        self,
        d_model: int,
        nhead: int,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        activation=F.relu,
        layer_norm_eps: float = 1e-5,
        layer_scale: bool = False,
        init_values: float = 1e-4,
        norm_first: bool = False,
        group_norm: bool = False,
        norm_out: bool = False,
        sparse=False,
        mask_type="diag",
        mask_random_seed=42,
        sparse_attn_window=500,
        global_window=50,
        sparsity=0.95,
        auto_sparsity=None,
        device=None,
        dtype=None,
        batch_first=False,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()

        self.sparse = sparse
        self.auto_sparsity = auto_sparsity
        if sparse:
            # Static mask parameters only apply to the precomputed-mask path.
            if not auto_sparsity:
                self.mask_type = mask_type
                self.sparse_attn_window = sparse_attn_window
                self.global_window = global_window
            self.sparsity = sparsity

        self.cross_attn: nn.Module
        self.cross_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs)

        self.norm_first = norm_first
        self.norm1: nn.Module
        self.norm2: nn.Module
        self.norm3: nn.Module
        # norm1/norm2 normalize queries/keys, norm3 the feed-forward input.
        if group_norm:
            self.norm1 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
            self.norm2 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
            self.norm3 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
        else:
            self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
            self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
            self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)

        self.norm_out = None
        # NOTE(review): bitwise `&` on bools — works, but `and` was likely meant.
        if self.norm_first & norm_out:
            self.norm_out = MyGroupNorm(num_groups=int(norm_out), num_channels=d_model)

        self.gamma_1 = LayerScale(d_model, init_values, True) if layer_scale else nn.Identity()
        self.gamma_2 = LayerScale(d_model, init_values, True) if layer_scale else nn.Identity()

        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        # Legacy string support for activation function.
        if isinstance(activation, str):
            self.activation = self._get_activation_fn(activation)
        else:
            self.activation = activation

        if sparse:
            # Replace the dense attention with the sparse-capable variant.
            self.cross_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first, auto_sparsity=sparsity if auto_sparsity else 0)
            if not auto_sparsity:
                # __setattr__ avoids nn.Module buffer registration; this is a cache.
                self.__setattr__("mask", torch.zeros(1, 1))
                self.mask_random_seed = mask_random_seed

    def forward(self, q, k, mask=None):
        """
        Args:
            q: tensor of shape (T, B, C)
            k: tensor of shape (S, B, C)
            mask: tensor of shape (T, S)

        """
        device = q.device
        T, B, C = q.shape
        S, B, C = k.shape
        if self.sparse and not self.auto_sparsity:
            assert mask is None
            mask = self.mask
            # Rebuild and cache the mask when either sequence length changes.
            if mask.shape[-1] != S or mask.shape[-2] != T:
                mask = get_mask(S, T, self.mask_type, self.sparse_attn_window, self.global_window, self.mask_random_seed, self.sparsity, device)
                self.__setattr__("mask", mask)

        if self.norm_first:
            # Pre-norm: normalize q and k separately before cross-attention.
            x = q + self.gamma_1(self._ca_block(self.norm1(q), self.norm2(k), mask))
            x = x + self.gamma_2(self._ff_block(self.norm3(x)))
            if self.norm_out:
                x = self.norm_out(x)
        else:
            x = self.norm1(q + self.gamma_1(self._ca_block(q, k, mask)))
            x = self.norm2(x + self.gamma_2(self._ff_block(x)))

        return x

    # self-attention block
    def _ca_block(self, q, k, attn_mask=None):
        # Cross-attention: q attends over k (k used as both key and value).
        x = self.cross_attn(q, k, k, attn_mask=attn_mask, need_weights=False)[0]
        return self.dropout1(x)

    # feed forward block
    def _ff_block(self, x):
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout2(x)

    def _get_activation_fn(self, activation):
        # Map legacy string names to the corresponding functional activation.
        if activation == "relu":
            return F.relu
        elif activation == "gelu":
            return F.gelu

        raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
|
| 410 |
+
|
| 411 |
+
|
| 412 |
+
# ----------------- MULTI-BLOCKS MODELS: -----------------------
|
| 413 |
+
|
| 414 |
+
|
| 415 |
+
class CrossTransformerEncoder(nn.Module):
    """Dual-stream transformer interleaving self-attention layers (applied to
    each stream independently) with cross-attention layers (each stream
    attending to the other). One stream is the spectrogram branch
    (B, C, Fr, T1), the other the temporal/waveform branch (B, C, T2)."""

    def __init__(
        self,
        dim: int,
        emb: str = "sin",
        hidden_scale: float = 4.0,
        num_heads: int = 8,
        num_layers: int = 6,
        cross_first: bool = False,
        dropout: float = 0.0,
        max_positions: int = 1000,
        norm_in: bool = True,
        norm_in_group: bool = False,
        group_norm: int = False,
        norm_first: bool = False,
        norm_out: bool = False,
        max_period: float = 10000.0,
        weight_decay: float = 0.0,
        lr: tp.Optional[float] = None,
        layer_scale: bool = False,
        gelu: bool = True,
        sin_random_shift: int = 0,
        weight_pos_embed: float = 1.0,
        cape_mean_normalize: bool = True,
        cape_augment: bool = True,
        cape_glob_loc_scale: list = [5000.0, 1.0, 1.4],  # NOTE(review): mutable default; read-only here, but replacing with None+default inside would be safer
        sparse_self_attn: bool = False,
        sparse_cross_attn: bool = False,
        mask_type: str = "diag",
        mask_random_seed: int = 42,
        sparse_attn_window: int = 500,
        global_window: int = 50,
        auto_sparsity: bool = False,
        sparsity: float = 0.95,
    ):
        super().__init__()
        """
        emb: positional embedding kind for the temporal branch —
        "sin" (fixed sinusoidal, optional random shift), "cape", or "scaled"
        (learnt ScaledEmbedding up to `max_positions`).
        """
        assert dim % num_heads == 0

        hidden_dim = int(dim * hidden_scale)

        self.num_layers = num_layers
        # classic parity = 1 means that if idx%2 == 1 there is a
        # classical encoder else there is a cross encoder
        self.classic_parity = 1 if cross_first else 0
        self.emb = emb
        self.max_period = max_period
        self.weight_decay = weight_decay
        self.weight_pos_embed = weight_pos_embed
        self.sin_random_shift = sin_random_shift
        if emb == "cape":
            self.cape_mean_normalize = cape_mean_normalize
            self.cape_augment = cape_augment
            self.cape_glob_loc_scale = cape_glob_loc_scale
        if emb == "scaled":
            self.position_embeddings = ScaledEmbedding(max_positions, dim, scale=0.2)

        self.lr = lr

        activation: tp.Any = F.gelu if gelu else F.relu

        self.norm_in: nn.Module
        self.norm_in_t: nn.Module
        # Input normalization for the spectrogram (norm_in) and temporal
        # (norm_in_t) streams; LayerNorm takes priority over GroupNorm.
        if norm_in:
            self.norm_in = nn.LayerNorm(dim)
            self.norm_in_t = nn.LayerNorm(dim)
        elif norm_in_group:
            self.norm_in = MyGroupNorm(int(norm_in_group), dim)
            self.norm_in_t = MyGroupNorm(int(norm_in_group), dim)
        else:
            self.norm_in = nn.Identity()
            self.norm_in_t = nn.Identity()

        # spectrogram layers
        self.layers = nn.ModuleList()
        # temporal layers
        self.layers_t = nn.ModuleList()

        kwargs_common = {
            "d_model": dim,
            "nhead": num_heads,
            "dim_feedforward": hidden_dim,
            "dropout": dropout,
            "activation": activation,
            "group_norm": group_norm,
            "norm_first": norm_first,
            "norm_out": norm_out,
            "layer_scale": layer_scale,
            "mask_type": mask_type,
            "mask_random_seed": mask_random_seed,
            "sparse_attn_window": sparse_attn_window,
            "global_window": global_window,
            "sparsity": sparsity,
            "auto_sparsity": auto_sparsity,
            "batch_first": True,
        }

        kwargs_classic_encoder = dict(kwargs_common)
        kwargs_classic_encoder.update({"sparse": sparse_self_attn})
        kwargs_cross_encoder = dict(kwargs_common)
        kwargs_cross_encoder.update({"sparse": sparse_cross_attn})

        # Alternate self-attention and cross-attention layers; `classic_parity`
        # decides which comes first.
        for idx in range(num_layers):
            if idx % 2 == self.classic_parity:

                self.layers.append(MyTransformerEncoderLayer(**kwargs_classic_encoder))
                self.layers_t.append(MyTransformerEncoderLayer(**kwargs_classic_encoder))

            else:
                self.layers.append(CrossTransformerEncoderLayer(**kwargs_cross_encoder))

                self.layers_t.append(CrossTransformerEncoderLayer(**kwargs_cross_encoder))

    def forward(self, x, xt):
        """Run both streams through the interleaved layer stack.

        x: spectrogram stream (B, C, Fr, T1); xt: temporal stream (B, C, T2).
        Returns the two streams reshaped back to their input layouts.
        """
        B, C, Fr, T1 = x.shape
        pos_emb_2d = create_2d_sin_embedding(C, Fr, T1, x.device, self.max_period)  # (1, C, Fr, T1)
        pos_emb_2d = rearrange(pos_emb_2d, "b c fr t1 -> b (t1 fr) c")
        # Flatten (freq, time) into one sequence axis for attention.
        x = rearrange(x, "b c fr t1 -> b (t1 fr) c")
        x = self.norm_in(x)
        x = x + self.weight_pos_embed * pos_emb_2d

        B, C, T2 = xt.shape
        xt = rearrange(xt, "b c t2 -> b t2 c")  # now T2, B, C
        pos_emb = self._get_pos_embedding(T2, B, C, x.device)
        pos_emb = rearrange(pos_emb, "t2 b c -> b t2 c")
        xt = self.norm_in_t(xt)
        xt = xt + self.weight_pos_embed * pos_emb

        for idx in range(self.num_layers):
            if idx % 2 == self.classic_parity:
                # Self-attention layer: each stream processed independently.
                x = self.layers[idx](x)
                xt = self.layers_t[idx](xt)
            else:
                # Cross-attention: keep pre-update x so both directions see
                # the same snapshot of the other stream.
                old_x = x
                x = self.layers[idx](x, xt)
                xt = self.layers_t[idx](xt, old_x)

        x = rearrange(x, "b (t1 fr) c -> b c fr t1", t1=T1)
        xt = rearrange(xt, "b t2 c -> b c t2")
        return x, xt

    def _get_pos_embedding(self, T, B, C, device):
        # Build the (T, B, C) positional embedding for the temporal stream
        # according to `self.emb` ("sin", "cape", or "scaled").
        if self.emb == "sin":
            shift = random.randrange(self.sin_random_shift + 1)
            pos_emb = create_sin_embedding(T, C, shift=shift, device=device, max_period=self.max_period)
        elif self.emb == "cape":
            if self.training:
                pos_emb = create_sin_embedding_cape(
                    T,
                    C,
                    B,
                    device=device,
                    max_period=self.max_period,
                    mean_normalize=self.cape_mean_normalize,
                    augment=self.cape_augment,
                    max_global_shift=self.cape_glob_loc_scale[0],
                    max_local_shift=self.cape_glob_loc_scale[1],
                    max_scale=self.cape_glob_loc_scale[2],
                )
            else:
                # No augmentation at eval time.
                pos_emb = create_sin_embedding_cape(T, C, B, device=device, max_period=self.max_period, mean_normalize=self.cape_mean_normalize, augment=False)

        elif self.emb == "scaled":
            pos = torch.arange(T, device=device)
            pos_emb = self.position_embeddings(pos)[:, None]

        return pos_emb

    def make_optim_group(self):
        # Optimizer parameter group carrying this module's weight decay and
        # optional per-module learning rate.
        group = {"params": list(self.parameters()), "weight_decay": self.weight_decay}
        if self.lr is not None:
            group["lr"] = self.lr
        return group
|
| 589 |
+
|
| 590 |
+
|
| 591 |
+
# Attention Modules
|
| 592 |
+
|
| 593 |
+
|
| 594 |
+
class MultiheadAttention(nn.Module):
    """Multi-head attention with an optional dynamic-sparse backend.

    Uses xformers dynamic sparse attention when `auto_sparsity` is truthy,
    and dense scaled dot-product attention otherwise. Returns a
    `(output, None)` tuple to mimic nn.MultiheadAttention's interface;
    `key_padding_mask`, `need_weights` and `average_attn_weights` are
    accepted for signature compatibility but ignored.
    """

    def __init__(self, embed_dim, num_heads, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, batch_first=False, auto_sparsity=None):
        super().__init__()
        assert auto_sparsity is not None, "sanity check"
        self.num_heads = num_heads
        self.q = torch.nn.Linear(embed_dim, embed_dim, bias=bias)
        self.k = torch.nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v = torch.nn.Linear(embed_dim, embed_dim, bias=bias)
        self.attn_drop = torch.nn.Dropout(dropout)
        self.proj = torch.nn.Linear(embed_dim, embed_dim, bias)
        self.proj_drop = torch.nn.Dropout(dropout)
        self.batch_first = batch_first
        self.auto_sparsity = auto_sparsity

    def forward(self, query, key, value, key_padding_mask=None, need_weights=True, attn_mask=None, average_attn_weights=True):

        # Convert to batch-first internally.
        if not self.batch_first:  # N, B, C
            query = query.permute(1, 0, 2)  # B, N_q, C
            key = key.permute(1, 0, 2)  # B, N_k, C
            value = value.permute(1, 0, 2)  # B, N_k, C
        B, N_q, C = query.shape
        B, N_k, C = key.shape

        # Project and split into heads, folding batch and head dims together
        # so downstream kernels see (B * num_heads, N, head_dim).
        q = self.q(query).reshape(B, N_q, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        q = q.flatten(0, 1)
        k = self.k(key).reshape(B, N_k, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        k = k.flatten(0, 1)
        v = self.v(value).reshape(B, N_k, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        v = v.flatten(0, 1)

        if self.auto_sparsity:
            # Dynamic sparsity infers its own pattern; an explicit mask is unsupported.
            assert attn_mask is None
            x = dynamic_sparse_attention(q, k, v, sparsity=self.auto_sparsity)
        else:
            x = scaled_dot_product_attention(q, k, v, attn_mask, dropout=self.attn_drop)
        x = x.reshape(B, self.num_heads, N_q, C // self.num_heads)

        # Merge heads back and apply the output projection.
        x = x.transpose(1, 2).reshape(B, N_q, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        if not self.batch_first:
            x = x.permute(1, 0, 2)
        # Second element mirrors nn.MultiheadAttention's attention weights;
        # never computed here.
        return x, None
|
| 637 |
+
|
| 638 |
+
|
| 639 |
+
def scaled_query_key_softmax(q, k, att_mask):
    """Masked, scaled attention weights: softmax(q @ k^T / sqrt(d), mask)."""
    from xformers.ops import masked_matmul

    # Scale queries by 1/sqrt(d) before the masked matmul.
    scaled_q = q / (k.size(-1)) ** 0.5
    logits = masked_matmul(scaled_q, k.transpose(-2, -1), att_mask)
    return torch.nn.functional.softmax(logits, -1)
|
| 646 |
+
|
| 647 |
+
|
| 648 |
+
def scaled_dot_product_attention(q, k, v, att_mask, dropout):
    """Dense attention: dropout(softmax(qk^T/sqrt(d), mask)) @ v."""
    weights = scaled_query_key_softmax(q, k, att_mask=att_mask)
    weights = dropout(weights)
    return weights @ v
|
| 653 |
+
|
| 654 |
+
|
| 655 |
+
def _compute_buckets(x, R):
|
| 656 |
+
qq = torch.einsum("btf,bfhi->bhti", x, R)
|
| 657 |
+
qq = torch.cat([qq, -qq], dim=-1)
|
| 658 |
+
buckets = qq.argmax(dim=-1)
|
| 659 |
+
|
| 660 |
+
return buckets.permute(0, 2, 1).byte().contiguous()
|
| 661 |
+
|
| 662 |
+
|
| 663 |
+
def dynamic_sparse_attention(query, key, value, sparsity, infer_sparsity=True, attn_bias=None):
    """LSH-based dynamic sparse attention using xformers custom kernels.

    Queries and keys are hashed into buckets via shared random projections;
    attention is then only computed where buckets collide.
    """
    # assert False, "The code for the custom sparse kernel is not ready for release yet."
    from xformers.ops import find_locations, sparse_memory_efficient_attention

    n_hashes = 32
    proj_size = 4
    query, key, value = [x.contiguous() for x in [query, key, value]]
    # Bucket assignment is a heuristic for locating the sparse pattern, so it
    # does not need gradients.
    with torch.no_grad():
        # Fresh random projections each call; shared between query and key so
        # similar vectors land in the same buckets.
        R = torch.randn(1, query.shape[-1], n_hashes, proj_size // 2, device=query.device)
        bucket_query = _compute_buckets(query, R)
        bucket_key = _compute_buckets(key, R)
        row_offsets, column_indices = find_locations(bucket_query, bucket_key, sparsity, infer_sparsity)
    return sparse_memory_efficient_attention(query, key, value, row_offsets, column_indices, attn_bias)
|
audio_separator/separator/uvr_lib_v5/demucs/utils.py
ADDED
|
@@ -0,0 +1,496 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
from collections import defaultdict
|
| 8 |
+
from contextlib import contextmanager
|
| 9 |
+
import math
|
| 10 |
+
import os
|
| 11 |
+
import tempfile
|
| 12 |
+
import typing as tp
|
| 13 |
+
|
| 14 |
+
import errno
|
| 15 |
+
import functools
|
| 16 |
+
import hashlib
|
| 17 |
+
import inspect
|
| 18 |
+
import io
|
| 19 |
+
import os
|
| 20 |
+
import random
|
| 21 |
+
import socket
|
| 22 |
+
import tempfile
|
| 23 |
+
import warnings
|
| 24 |
+
import zlib
|
| 25 |
+
|
| 26 |
+
from diffq import UniformQuantizer, DiffQuantizer
|
| 27 |
+
import torch as th
|
| 28 |
+
import tqdm
|
| 29 |
+
from torch import distributed
|
| 30 |
+
from torch.nn import functional as F
|
| 31 |
+
|
| 32 |
+
import torch
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def unfold(a, kernel_size, stride):
    """Given input of size [*OT, T], output Tensor of size [*OT, F, K]
    with K the kernel size, by extracting frames with the given stride.

    This will pad the input so that `F = ceil(T / K)`.

    see https://github.com/pytorch/pytorch/issues/60466
    """
    *batch_shape, length = a.shape
    n_frames = math.ceil(length / stride)
    padded_length = (n_frames - 1) * stride + kernel_size
    a = F.pad(a, (0, padded_length - length))
    strides = list(a.stride())
    assert strides[-1] == 1, "data should be contiguous"
    # Overlapping frames are views into the same storage: frame axis advances
    # by `stride` elements, the within-frame axis by 1.
    frame_strides = strides[:-1] + [stride, 1]
    return a.as_strided([*batch_shape, n_frames, kernel_size], frame_strides)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def center_trim(tensor: torch.Tensor, reference: tp.Union[torch.Tensor, int]):
    """
    Center trim `tensor` with respect to `reference`, along the last dimension.
    `reference` can also be a number, representing the length to trim to.
    If the size difference != 0 mod 2, the extra sample is removed on the right side.
    """
    ref_size: int = reference.size(-1) if isinstance(reference, torch.Tensor) else reference
    delta = tensor.size(-1) - ref_size
    if delta < 0:
        raise ValueError(f"tensor must be larger than reference. Delta is {delta}.")
    if delta:
        left = delta // 2
        tensor = tensor[..., left : -(delta - left)]
    return tensor
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def pull_metric(history: tp.List[dict], name: str):
    """Extract the dotted-path metric `name` from every dict in `history`.

    E.g. name="valid.loss" reads entry["valid"]["loss"] for each entry.
    """
    path = name.split(".")
    values = []
    for metrics in history:
        node = metrics
        for part in path:
            node = node[part]
        values.append(node)
    return values
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def EMA(beta: float = 1):
    """
    Exponential Moving Average callback.
    Returns a single function that can be called to repeatidly update the EMA
    with a dict of metrics. The callback will return
    the new averaged dict of metrics.

    Note that for `beta=1`, this is just plain averaging.
    """
    # Running numerators and denominators, one entry per metric key.
    weighted_sums: tp.Dict[str, float] = defaultdict(float)
    weights: tp.Dict[str, float] = defaultdict(float)

    def _update(metrics: dict, weight: float = 1) -> dict:
        nonlocal weighted_sums, weights
        for key, value in metrics.items():
            weighted_sums[key] = weighted_sums[key] * beta + weight * float(value)
            weights[key] = weights[key] * beta + weight
        return {key: acc / weights[key] for key, acc in weighted_sums.items()}

    return _update
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def sizeof_fmt(num: float, suffix: str = "B"):
    """
    Given `num` bytes, return human readable size.
    Taken from https://stackoverflow.com/a/1094933
    """
    units = ["", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"]
    for unit in units:
        if abs(num) < 1024.0:
            return "%3.1f%s%s" % (num, unit, suffix)
        num /= 1024.0
    # Beyond zebi: report in yobi.
    return "%.1f%s%s" % (num, "Yi", suffix)
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
@contextmanager
def temp_filenames(count: int, delete=True):
    """Yield a list of `count` fresh temporary file names.

    The files are created on disk; when `delete` is true they are removed
    when the context exits (even on error).
    """
    created = []
    try:
        for _ in range(count):
            handle = tempfile.NamedTemporaryFile(delete=False)
            created.append(handle.name)
        yield created
    finally:
        if delete:
            for filename in created:
                os.unlink(filename)
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def average_metric(metric, count=1.0):
    """
    Average `metric` which should be a float across all hosts. `count` should be
    the weight for this particular host (i.e. number of examples).
    """
    # Pack [weight, weighted value] into one tensor so a single all_reduce
    # sums both quantities across hosts at once.
    packed = th.tensor([count, count * metric], dtype=th.float32, device="cuda")
    distributed.all_reduce(packed, op=distributed.ReduceOp.SUM)
    return packed[1].item() / packed[0].item()
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def free_port(host="", low=20000, high=40000):
    """
    Return a port number that is most likely free.
    This could suffer from a race condition although
    it should be quite rare.

    Fix over the previous version: the probe socket is now always closed.
    Previously the bound socket was never closed, so the returned "free"
    port actually stayed in use until the socket was garbage collected,
    and the socket object leaked on the non-EADDRINUSE error path.
    """
    while True:
        port = random.randint(low, high)
        # A fresh socket per attempt: rebinding a socket after a failed
        # bind is not reliable on all platforms.
        sock = socket.socket()
        try:
            try:
                sock.bind((host, port))
            except OSError as error:
                if error.errno == errno.EADDRINUSE:
                    continue  # port taken, pick another one
                raise
        finally:
            # Release the probe socket so the returned port is actually free.
            sock.close()
        return port
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def sizeof_fmt(num, suffix="B"):
    """
    Given `num` bytes, return human readable size.
    Taken from https://stackoverflow.com/a/1094933
    """
    # NOTE: duplicate of the earlier annotated sizeof_fmt in this module.
    prefixes = ("", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi")
    index = 0
    while index < len(prefixes) and abs(num) >= 1024.0:
        num /= 1024.0
        index += 1
    if index < len(prefixes):
        return "%3.1f%s%s" % (num, prefixes[index], suffix)
    return "%.1f%s%s" % (num, "Yi", suffix)
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def human_seconds(seconds, display=".2f"):
    """
    Given `seconds` seconds, return human readable duration.
    """
    # Start in microseconds and climb the unit ladder while the value
    # stays comfortably above the next unit (threshold 0.3).
    value = seconds * 1e6
    ratios = [1e3, 1e3, 60, 60, 24]
    names = ["us", "ms", "s", "min", "hrs", "days"]
    unit = names[0]
    for next_unit, ratio in zip(names[1:], ratios):
        if value / ratio < 0.3:
            break
        value /= ratio
        unit = next_unit
    return f"{format(value, display)} {unit}"
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
class TensorChunk:
    """A lightweight view over the trailing dimension of a tensor.

    Represents `tensor[..., offset:offset + length]` without copying, and can
    produce a zero-padded window of any larger target length via `padded`.
    """

    def __init__(self, tensor, offset=0, length=None):
        total = tensor.shape[-1]
        assert 0 <= offset < total
        # Clamp the requested length to what is actually available.
        if length is None:
            length = total - offset
        else:
            length = min(length, total - offset)

        self.tensor = tensor
        self.offset = offset
        self.length = length
        self.device = tensor.device

    @property
    def shape(self):
        # Same as the underlying tensor, but the last dim reflects the view.
        dims = list(self.tensor.shape)
        dims[-1] = self.length
        return dims

    def padded(self, target_length):
        """Return the chunk centered in a window of `target_length` samples,
        borrowing real neighbours from the underlying tensor where possible
        and zero-padding the rest."""
        missing = target_length - self.length
        total = self.tensor.shape[-1]
        assert missing >= 0

        begin = self.offset - missing // 2
        finish = begin + target_length

        clamped_begin = max(0, begin)
        clamped_finish = min(total, finish)

        left_pad = clamped_begin - begin
        right_pad = finish - clamped_finish

        result = F.pad(self.tensor[..., clamped_begin:clamped_finish], (left_pad, right_pad))
        assert result.shape[-1] == target_length
        return result
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
def tensor_chunk(tensor_or_chunk):
    """Coerce the argument to a TensorChunk.

    Existing chunks pass through unchanged; raw tensors are wrapped in a
    full-length TensorChunk.
    """
    if isinstance(tensor_or_chunk, TensorChunk):
        return tensor_or_chunk
    assert isinstance(tensor_or_chunk, th.Tensor)
    return TensorChunk(tensor_or_chunk)
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
def apply_model_v1(model, mix, shifts=None, split=False, progress=False, set_progress_bar=None):
    """
    Apply model to a given mixture (Demucs v1 inference helper).

    Args:
        shifts (int): if > 0, will shift in time `mix` by a random amount between 0 and 0.5 sec
            and apply the oppositve shift to the output. This is repeated `shifts` time and
            all predictions are averaged. This effectively makes the model time equivariant
            and improves SDR by up to 0.2 points.
        split (bool): if True, the input will be broken down in 8 seconds extracts
            and predictions will be performed individually on each and concatenated.
            Useful for model with large memory footprint like Tasnet.
        progress (bool): if True, show a progress bar (requires split=True)
        set_progress_bar: optional callback (base, fraction) used by the UVR UI
            to report progress.
    """

    channels, length = mix.size()
    device = mix.device
    progress_value = 0

    if split:
        # Process the mix in 10-second chunks, recursing with split=False.
        # NOTE(review): the output is hard-coded to 4 sources here.
        out = th.zeros(4, channels, length, device=device)
        shift = model.samplerate * 10
        offsets = range(0, length, shift)
        scale = 10
        if progress:
            offsets = tqdm.tqdm(offsets, unit_scale=scale, ncols=120, unit="seconds")
        for offset in offsets:
            chunk = mix[..., offset : offset + shift]
            if set_progress_bar:
                progress_value += 1
                set_progress_bar(0.1, (0.8 / len(offsets) * progress_value))
                chunk_out = apply_model_v1(model, chunk, shifts=shifts, set_progress_bar=set_progress_bar)
            else:
                chunk_out = apply_model_v1(model, chunk, shifts=shifts)
            out[..., offset : offset + shift] = chunk_out
            # This increment has no effect: `offset` is rebound by the loop.
            offset += shift
        return out
    elif shifts:
        # Randomized time-shift averaging (shift equivariance trick).
        max_shift = int(model.samplerate / 2)
        mix = F.pad(mix, (max_shift, max_shift))
        offsets = list(range(max_shift))
        random.shuffle(offsets)
        out = 0
        for offset in offsets[:shifts]:
            shifted = mix[..., offset : offset + length + max_shift]
            if set_progress_bar:
                shifted_out = apply_model_v1(model, shifted, set_progress_bar=set_progress_bar)
            else:
                shifted_out = apply_model_v1(model, shifted)
            # Undo the shift before accumulating.
            out += shifted_out[..., max_shift - offset : max_shift - offset + length]
        out /= shifts
        return out
    else:
        # Plain single forward pass: pad to the model's valid length and
        # trim the prediction back to the input length.
        valid_length = model.valid_length(length)
        delta = valid_length - length
        padded = F.pad(mix, (delta // 2, delta - delta // 2))
        with th.no_grad():
            out = model(padded.unsqueeze(0))[0]
        return center_trim(out, mix)
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
def apply_model_v2(model, mix, shifts=None, split=False, overlap=0.25, transition_power=1.0, progress=False, set_progress_bar=None):
    """
    Apply model to a given mixture (Demucs v2 inference helper with
    overlap-add chunking).

    Args:
        shifts (int): if > 0, will shift in time `mix` by a random amount between 0 and 0.5 sec
            and apply the oppositve shift to the output. This is repeated `shifts` time and
            all predictions are averaged. This effectively makes the model time equivariant
            and improves SDR by up to 0.2 points.
        split (bool): if True, the input will be broken down in 8 seconds extracts
            and predictions will be performed individually on each and concatenated.
            Useful for model with large memory footprint like Tasnet.
        overlap (float): fraction of overlap between consecutive segments when
            `split` is True.
        transition_power (float): exponent applied to the triangular
            cross-fade weights; larger values sharpen the transitions.
        progress (bool): if True, show a progress bar (requires split=True)
        set_progress_bar: optional callback (base, fraction) used by the UVR UI.
    """

    assert transition_power >= 1, "transition_power < 1 leads to weird behavior."
    device = mix.device
    channels, length = mix.shape
    progress_value = 0

    if split:
        # Weighted overlap-add over segments of model.segment_length samples.
        out = th.zeros(len(model.sources), channels, length, device=device)
        sum_weight = th.zeros(length, device=device)
        segment = model.segment_length
        stride = int((1 - overlap) * segment)
        offsets = range(0, length, stride)
        scale = stride / model.samplerate
        if progress:
            offsets = tqdm.tqdm(offsets, unit_scale=scale, ncols=120, unit="seconds")
        # We start from a triangle shaped weight, with maximal weight in the middle
        # of the segment. Then we normalize and take to the power `transition_power`.
        # Large values of transition power will lead to sharper transitions.
        weight = th.cat([th.arange(1, segment // 2 + 1), th.arange(segment - segment // 2, 0, -1)]).to(device)
        assert len(weight) == segment
        # If the overlap < 50%, this will translate to linear transition when
        # transition_power is 1.
        weight = (weight / weight.max()) ** transition_power
        for offset in offsets:
            chunk = TensorChunk(mix, offset, segment)
            if set_progress_bar:
                progress_value += 1
                set_progress_bar(0.1, (0.8 / len(offsets) * progress_value))
                chunk_out = apply_model_v2(model, chunk, shifts=shifts, set_progress_bar=set_progress_bar)
            else:
                chunk_out = apply_model_v2(model, chunk, shifts=shifts)
            chunk_length = chunk_out.shape[-1]
            # The final (shorter) chunk only uses the matching weight prefix.
            out[..., offset : offset + segment] += weight[:chunk_length] * chunk_out
            sum_weight[offset : offset + segment] += weight[:chunk_length]
            # No effect: `offset` is rebound by the loop.
            offset += segment
        assert sum_weight.min() > 0
        out /= sum_weight
        return out
    elif shifts:
        # Randomized time-shift averaging using zero-padded TensorChunk views.
        max_shift = int(0.5 * model.samplerate)
        mix = tensor_chunk(mix)
        padded_mix = mix.padded(length + 2 * max_shift)
        out = 0
        for _ in range(shifts):
            offset = random.randint(0, max_shift)
            shifted = TensorChunk(padded_mix, offset, length + max_shift - offset)

            if set_progress_bar:
                progress_value += 1
                shifted_out = apply_model_v2(model, shifted, set_progress_bar=set_progress_bar)
            else:
                shifted_out = apply_model_v2(model, shifted)
            # Undo the shift before accumulating.
            out += shifted_out[..., max_shift - offset :]
        out /= shifts
        return out
    else:
        # Single forward pass on a valid-length padded window.
        valid_length = model.valid_length(length)
        mix = tensor_chunk(mix)
        padded_mix = mix.padded(valid_length)
        with th.no_grad():
            out = model(padded_mix.unsqueeze(0))[0]
        return center_trim(out, length)
|
| 371 |
+
|
| 372 |
+
|
| 373 |
+
@contextmanager
def temp_filenames(count, delete=True):
    """Context manager yielding `count` temporary file names.

    NOTE: duplicate of the annotated temp_filenames defined earlier in this
    module. Files are unlinked on exit when `delete` is true.
    """
    paths = []
    try:
        paths.extend(tempfile.NamedTemporaryFile(delete=False).name for _ in range(count))
        yield paths
    finally:
        if delete:
            for p in paths:
                os.unlink(p)
|
| 384 |
+
|
| 385 |
+
|
| 386 |
+
def get_quantizer(model, args, optimizer=None):
    """Build the quantizer selected by `args`, or None if quantization is off.

    `args.diffq` picks a DiffQuantizer (optionally wired to `optimizer`);
    otherwise a truthy `args.qat` picks a UniformQuantizer with that many bits.
    """
    if args.diffq:
        quantizer = DiffQuantizer(model, min_size=args.q_min_size, group_size=8)
        if optimizer is not None:
            quantizer.setup_optimizer(optimizer)
        return quantizer
    if args.qat:
        return UniformQuantizer(model, bits=args.qat, min_size=args.q_min_size)
    return None
|
| 395 |
+
|
| 396 |
+
|
| 397 |
+
def load_model(path, strict=False):
    """Load a serialized Demucs model package saved by `save_model`.

    The package stores the model class, its constructor args/kwargs, the
    (possibly quantized) state and the training args. With strict=False,
    kwargs not accepted by the current class signature are dropped with a
    warning, so older checkpoints keep loading after API changes.
    """
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        load_from = path
        # NOTE(review): th.load on an untrusted checkpoint unpickles
        # arbitrary code — only load trusted model files.
        package = th.load(load_from, "cpu")

    klass = package["klass"]
    args = package["args"]
    kwargs = package["kwargs"]

    if strict:
        model = klass(*args, **kwargs)
    else:
        # Drop stored kwargs the constructor no longer accepts.
        sig = inspect.signature(klass)
        for key in list(kwargs):
            if key not in sig.parameters:
                warnings.warn("Dropping inexistant parameter " + key)
                del kwargs[key]
        model = klass(*args, **kwargs)

    state = package["state"]
    training_args = package["training_args"]
    # Rebuild the quantizer used at training time so quantized states
    # can be restored correctly.
    quantizer = get_quantizer(model, training_args)

    set_state(model, quantizer, state)
    return model
|
| 423 |
+
|
| 424 |
+
|
| 425 |
+
def get_state(model, quantizer):
    """Return a serializable snapshot of the model weights.

    Without a quantizer this is a plain CPU copy of the state dict; with one,
    the quantized state is torch-saved and zlib-compressed under the
    "compressed" key (the format `set_state` expects).
    """
    if quantizer is None:
        return {name: param.data.to("cpu") for name, param in model.state_dict().items()}
    buffer = io.BytesIO()
    th.save(quantizer.get_quantized_state(), buffer)
    return {"compressed": zlib.compress(buffer.getvalue())}
|
| 434 |
+
|
| 435 |
+
|
| 436 |
+
def set_state(model, quantizer, state):
    """Restore weights produced by `get_state` into `model`.

    Plain state dicts are loaded directly; compressed quantized states are
    inflated, torch-loaded and handed to the quantizer. Returns the
    (possibly decompressed) state that was applied.
    """
    if quantizer is None:
        model.load_state_dict(state)
        return state
    raw = io.BytesIO(zlib.decompress(state["compressed"]))
    state = th.load(raw, "cpu")
    quantizer.restore_quantized_state(state)
    return state
|
| 445 |
+
|
| 446 |
+
|
| 447 |
+
def save_state(state, path):
    """Serialize `state` to disk next to `path`, with a content hash in the name.

    The final filename is `<stem>-<8 hex chars of sha256><suffix>`, so distinct
    states never overwrite each other. `path` must be a pathlib.Path.
    """
    buffer = io.BytesIO()
    th.save(state, buffer)
    payload = buffer.getvalue()
    digest = hashlib.sha256(payload).hexdigest()[:8]

    target = path.parent / (path.stem + "-" + digest + path.suffix)
    target.write_bytes(payload)
|
| 454 |
+
|
| 455 |
+
|
| 456 |
+
def save_model(model, quantizer, training_args, path):
    """Package the model for `load_model`: class, captured constructor
    arguments (see `capture_init`), weights via `get_state`, and the
    training args, all torch-saved to `path`."""
    init_args, init_kwargs = model._init_args_kwargs
    package = {
        "klass": model.__class__,
        "args": init_args,
        "kwargs": init_kwargs,
        "state": get_state(model, quantizer),
        "training_args": training_args,
    }
    th.save(package, path)
|
| 465 |
+
|
| 466 |
+
|
| 467 |
+
def capture_init(init):
    """Decorator for `__init__` that records the call arguments on the
    instance as `_init_args_kwargs`, so the object can later be
    re-instantiated from a checkpoint (see `save_model`/`load_model`)."""

    @functools.wraps(init)
    def wrapped_init(self, *args, **kwargs):
        self._init_args_kwargs = (args, kwargs)
        init(self, *args, **kwargs)

    return wrapped_init
|
| 474 |
+
|
| 475 |
+
|
| 476 |
+
class DummyPoolExecutor:
    """Single-threaded stand-in for a concurrent.futures executor.

    `submit` does not run anything; the call is deferred until `.result()`
    is invoked, in the caller's thread. Useful to disable parallelism
    without changing call sites.
    """

    class DummyResult:
        """Future-like object that lazily evaluates the stored call."""

        def __init__(self, func, *args, **kwargs):
            self.func = func
            self.args = args
            self.kwargs = kwargs

        def result(self):
            # Run the deferred call now, synchronously.
            return self.func(*self.args, **self.kwargs)

    def __init__(self, workers=0):
        # `workers` is accepted only for signature compatibility.
        pass

    def submit(self, func, *args, **kwargs):
        return DummyPoolExecutor.DummyResult(func, *args, **kwargs)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        return
|
audio_separator/separator/uvr_lib_v5/mdxnet.py
ADDED
|
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from .modules import TFC_TDF
|
| 4 |
+
from pytorch_lightning import LightningModule
|
| 5 |
+
|
| 6 |
+
dim_s = 4
|
| 7 |
+
|
| 8 |
+
class AbstractMDXNet(LightningModule):
    """Base Lightning module for MDX-Net separators.

    Stores the STFT configuration (FFT size, hop length, kept frequency bins,
    time frames) plus shared non-trainable buffers: the analysis window and
    the zero-padding that restores the bins above `dim_f`.
    """

    def __init__(self, target_name, lr, optimizer, dim_c, dim_f, dim_t, n_fft, hop_length, overlap):
        super().__init__()
        self.target_name = target_name  # name of the stem this model extracts
        self.lr = lr
        self.optimizer = optimizer  # 'rmsprop' or 'adamw' (see get_optimizer)
        self.dim_c = dim_c  # channel dimension of the spectrogram input
        self.dim_f = dim_f  # number of frequency bins the network operates on
        self.dim_t = dim_t  # number of time frames per chunk
        self.n_fft = n_fft
        self.n_bins = n_fft // 2 + 1  # full one-sided STFT bin count
        self.hop_length = hop_length
        # Registered as requires_grad=False Parameters (not buffers) so they
        # still follow .to(device)/.half() with the module.
        self.window = nn.Parameter(torch.hann_window(window_length=self.n_fft, periodic=True), requires_grad=False)
        # Zeros filling the discarded bins (n_bins - dim_f) when mapping the
        # network output back to a full spectrum.
        self.freq_pad = nn.Parameter(torch.zeros([1, dim_c, self.n_bins - self.dim_f, self.dim_t]), requires_grad=False)

    def get_optimizer(self):
        """Return the optimizer named by `self.optimizer`.

        NOTE(review): implicitly returns None for any other value — confirm
        callers only pass 'rmsprop' or 'adamw'.
        """
        if self.optimizer == 'rmsprop':
            return torch.optim.RMSprop(self.parameters(), self.lr)

        if self.optimizer == 'adamw':
            return torch.optim.AdamW(self.parameters(), self.lr)
|
| 29 |
+
|
| 30 |
+
class ConvTDFNet(AbstractMDXNet):
    """U-Net style TFC-TDF network used by MDX-Net.

    `num_blocks // 2` encoder stages (TFC_TDF block + strided conv
    downsampling) mirror the same number of decoder stages (transposed-conv
    upsampling + TFC_TDF block), joined through a bottleneck TFC_TDF block.
    Skip connections are *multiplicative* (see forward).
    """

    def __init__(self, target_name, lr, optimizer, dim_c, dim_f, dim_t, n_fft, hop_length,
                 num_blocks, l, g, k, bn, bias, overlap):

        super(ConvTDFNet, self).__init__(
            target_name, lr, optimizer, dim_c, dim_f, dim_t, n_fft, hop_length, overlap)
        #self.save_hyperparameters()

        self.num_blocks = num_blocks
        self.l = l        # conv layers per TFC block
        self.g = g        # channel growth per encoder stage
        self.k = k        # conv kernel size
        self.bn = bn      # TDF bottleneck factor (None disables the TDF branch)
        self.bias = bias

        # Normalization choice is tied to the optimizer name.
        # NOTE(review): for any other optimizer string `norm` is unbound and
        # the constructor raises NameError — confirm intended.
        if optimizer == 'rmsprop':
            norm = nn.BatchNorm2d

        if optimizer == 'adamw':
            norm = lambda input:nn.GroupNorm(2, input)

        self.n = num_blocks // 2
        scale = (2, 2)  # down/up-sampling factor per stage

        # 1x1 conv lifting the input channels to the working width g.
        self.first_conv = nn.Sequential(
            nn.Conv2d(in_channels=self.dim_c, out_channels=g, kernel_size=(1, 1)),
            norm(g),
            nn.ReLU(),
        )

        f = self.dim_f  # current frequency resolution
        c = g           # current channel count
        self.encoding_blocks = nn.ModuleList()
        self.ds = nn.ModuleList()
        for i in range(self.n):
            self.encoding_blocks.append(TFC_TDF(c, l, f, k, bn, bias=bias, norm=norm))
            # Strided conv halves both spatial dims and adds g channels.
            self.ds.append(
                nn.Sequential(
                    nn.Conv2d(in_channels=c, out_channels=c + g, kernel_size=scale, stride=scale),
                    norm(c + g),
                    nn.ReLU()
                )
            )
            f = f // 2
            c += g

        self.bottleneck_block = TFC_TDF(c, l, f, k, bn, bias=bias, norm=norm)

        self.decoding_blocks = nn.ModuleList()
        self.us = nn.ModuleList()
        for i in range(self.n):
            # Transposed conv doubles both spatial dims and removes g channels.
            self.us.append(
                nn.Sequential(
                    nn.ConvTranspose2d(in_channels=c, out_channels=c - g, kernel_size=scale, stride=scale),
                    norm(c - g),
                    nn.ReLU()
                )
            )
            f = f * 2
            c -= g

            self.decoding_blocks.append(TFC_TDF(c, l, f, k, bn, bias=bias, norm=norm))

        # 1x1 conv mapping back to the input channel count.
        self.final_conv = nn.Sequential(
            nn.Conv2d(in_channels=c, out_channels=self.dim_c, kernel_size=(1, 1)),
        )

    def forward(self, x):

        x = self.first_conv(x)

        # The TFC_TDF blocks apply their Linear over the last dim, so swap
        # time and frequency here and swap back after decoding.
        x = x.transpose(-1, -2)

        ds_outputs = []
        for i in range(self.n):
            x = self.encoding_blocks[i](x)
            ds_outputs.append(x)
            x = self.ds[i](x)

        x = self.bottleneck_block(x)

        for i in range(self.n):
            x = self.us[i](x)
            # Multiplicative skip connection with the matching encoder output.
            x *= ds_outputs[-i - 1]
            x = self.decoding_blocks[i](x)

        x = x.transpose(-1, -2)

        x = self.final_conv(x)

        return x
|
| 121 |
+
|
| 122 |
+
class Mixer(nn.Module):
    """Learned linear mixer refining the separated stems.

    Takes the `dim_s` separated stems plus the original mix (each stereo,
    so (dim_s+1)*2 channels) and produces refined stereo stems (dim_s*2
    channels) via a single bias-free Linear applied per time step. Weights
    are loaded from the checkpoint at `mixer_path` in the constructor.
    """

    def __init__(self, device, mixer_path):

        super(Mixer, self).__init__()

        self.linear = nn.Linear((dim_s+1)*2, dim_s*2, bias=False)

        # NOTE(review): torch.load unpickles the checkpoint — only use
        # trusted mixer files.
        self.load_state_dict(
            torch.load(mixer_path, map_location=device)
        )

    def forward(self, x):
        # (dim_s+1, 2, T) -> (1, (dim_s+1)*2, T) -> (1, T, channels) so the
        # Linear mixes across channels per time step, then back to
        # (dim_s, 2, T).
        x = x.reshape(1,(dim_s+1)*2,-1).transpose(-1,-2)
        x = self.linear(x)
        return x.transpose(-1,-2).reshape(dim_s,2,-1)
|
audio_separator/separator/uvr_lib_v5/mixer.ckpt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ea781bd52c6a523b825fa6cdbb6189f52e318edd8b17e6fe404f76f7af8caa9c
|
| 3 |
+
size 1208
|
audio_separator/separator/uvr_lib_v5/modules.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class TFC(nn.Module):
    """Time-Frequency Convolution block: `l` sequential
    Conv2d(c -> c, k x k, same padding) + norm + ReLU layers."""

    def __init__(self, c, l, k, norm):
        super(TFC, self).__init__()

        layers = nn.ModuleList()
        for _ in range(l):
            layers.append(
                nn.Sequential(
                    nn.Conv2d(in_channels=c, out_channels=c, kernel_size=k, stride=1, padding=k // 2),
                    norm(c),
                    nn.ReLU(),
                )
            )
        # Keep the original attribute name so checkpoint state-dict keys match.
        self.H = layers

    def forward(self, x):
        for block in self.H:
            x = block(x)
        return x
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class DenseTFC(nn.Module):
    """Densely-connected TFC variant: each intermediate layer's output is
    concatenated with its input before the next layer; the last layer's
    output is returned alone.

    NOTE(review): every conv expects `c` input channels while concatenation
    grows the channel count, so l > 1 appears to break at runtime — looks
    only usable with l == 1; confirm against callers.
    """

    def __init__(self, c, l, k, norm):
        super(DenseTFC, self).__init__()

        stages = nn.ModuleList()
        for _ in range(l):
            stages.append(
                nn.Sequential(
                    nn.Conv2d(in_channels=c, out_channels=c, kernel_size=k, stride=1, padding=k // 2),
                    norm(c),
                    nn.ReLU(),
                )
            )
        # Keep the original attribute name so checkpoint state-dict keys match.
        self.conv = stages

    def forward(self, x):
        for stage in self.conv[:-1]:
            x = torch.cat([stage(x), x], 1)
        return self.conv[-1](x)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class TFC_TDF(nn.Module):
    """TFC block with an optional residual TDF (time-distributed
    fully-connected) branch over the frequency axis.

    `bn` controls the TDF bottleneck: None disables the branch entirely,
    0 uses a single full-width Linear, any other value inserts an
    f -> f//bn -> f bottleneck.
    """

    def __init__(self, c, l, f, k, bn, dense=False, bias=True, norm=nn.BatchNorm2d):

        super(TFC_TDF, self).__init__()

        self.use_tdf = bn is not None

        self.tfc = DenseTFC(c, l, k, norm) if dense else TFC(c, l, k, norm)

        if not self.use_tdf:
            return
        if bn == 0:
            # No bottleneck: one full-width Linear over the frequency dim.
            self.tdf = nn.Sequential(
                nn.Linear(f, f, bias=bias),
                norm(c),
                nn.ReLU()
            )
        else:
            # Bottlenecked: squeeze to f//bn then expand back to f.
            self.tdf = nn.Sequential(
                nn.Linear(f, f // bn, bias=bias),
                norm(c),
                nn.ReLU(),
                nn.Linear(f // bn, f, bias=bias),
                norm(c),
                nn.ReLU()
            )

    def forward(self, x):
        x = self.tfc(x)
        if self.use_tdf:
            return x + self.tdf(x)
        return x
|
| 74 |
+
|
audio_separator/separator/uvr_lib_v5/playsound.py
ADDED
|
@@ -0,0 +1,241 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
logger = logging.getLogger(__name__)
|
| 3 |
+
|
| 4 |
+
class PlaysoundException(Exception):
    """Raised when a platform backend fails to load or play a sound."""
    pass
|
| 6 |
+
|
| 7 |
+
def _canonicalizePath(path):
|
| 8 |
+
"""
|
| 9 |
+
Support passing in a pathlib.Path-like object by converting to str.
|
| 10 |
+
"""
|
| 11 |
+
import sys
|
| 12 |
+
if sys.version_info[0] >= 3:
|
| 13 |
+
return str(path)
|
| 14 |
+
else:
|
| 15 |
+
# On earlier Python versions, str is a byte string, so attempting to
|
| 16 |
+
# convert a unicode string to str will fail. Leave it alone in this case.
|
| 17 |
+
return path
|
| 18 |
+
|
| 19 |
+
def _playsoundWin(sound, block = True):
    '''
    Utilizes windll.winmm. Tested and known to work with MP3 and WAVE on
    Windows 7 with Python 2.7. Probably works with more file formats.
    Probably works on Windows XP thru Windows 10. Probably works with all
    versions of Python.

    Inspired by (but not copied from) Michael Gundlach <gundlach@gmail.com>'s mp3play:
    https://github.com/michaelgundlach/mp3play

    I never would have tried using windll.winmm without seeing his code.
    '''
    # Quote the path so MCI treats it as a single token.
    sound = '"' + _canonicalizePath(sound) + '"'

    from ctypes import create_unicode_buffer, windll, wintypes
    windll.winmm.mciSendStringW.argtypes = [wintypes.LPCWSTR, wintypes.LPWSTR, wintypes.UINT, wintypes.HANDLE]
    windll.winmm.mciGetErrorStringW.argtypes = [wintypes.DWORD, wintypes.LPWSTR, wintypes.UINT]

    def winCommand(*command):
        # Send one MCI command string; raise with the decoded MCI error
        # message on a non-zero return code.
        bufLen = 600
        buf = create_unicode_buffer(bufLen)
        command = ' '.join(command)
        errorCode = int(windll.winmm.mciSendStringW(command, buf, bufLen - 1, 0))  # use widestring version of the function
        if errorCode:
            errorBuffer = create_unicode_buffer(bufLen)
            windll.winmm.mciGetErrorStringW(errorCode, errorBuffer, bufLen - 1)  # use widestring version of the function
            exceptionMessage = ('\n    Error ' + str(errorCode) + ' for command:'
                                '\n        ' + command +
                                '\n    ' + errorBuffer.value)
            logger.error(exceptionMessage)
            raise PlaysoundException(exceptionMessage)
        return buf.value

    try:
        logger.debug('Starting')
        winCommand(u'open {}'.format(sound))
        # ' wait' makes MCI block until playback finishes.
        winCommand(u'play {}{}'.format(sound, ' wait' if block else ''))
        logger.debug('Returning')
    finally:
        try:
            winCommand(u'close {}'.format(sound))
        except PlaysoundException:
            logger.warning(u'Failed to close the file: {}'.format(sound))
            # If it fails, there's nothing more that can be done...
            pass
|
| 64 |
+
|
| 65 |
+
def _handlePathOSX(sound):
    """Normalize a path or URL into the percent-encoded URL form NSURL expects.

    Bare file paths are absolutized against the current directory and given
    a file:// scheme; spaces become %20 and non-ASCII path parts are
    percent-encoded.
    """
    sound = _canonicalizePath(sound)

    if '://' not in sound:
        # Bare path: make it absolute, then prefix with the file scheme.
        if not sound.startswith('/'):
            from os import getcwd
            sound = getcwd() + '/' + sound
        sound = 'file://' + sound

    try:
        # Don't double-encode it.
        sound.encode('ascii')
        return sound.replace(' ', '%20')
    except UnicodeEncodeError:
        try:
            from urllib.parse import quote  # Try the Python 3 import first...
        except ImportError:
            from urllib import quote  # Try using the Python 2 import before giving up entirely...

        scheme, remainder = sound.split('://', 1)
        return scheme + '://' + quote(remainder.encode('utf-8')).replace(' ', '%20')
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def _playsoundOSX(sound, block = True):
    '''
    Utilizes AppKit.NSSound. Tested and known to work with MP3 and WAVE on
    OS X 10.11 with Python 2.7. Probably works with anything QuickTime supports.
    Probably works on OS X 10.5 and newer. Probably works with all versions of
    Python.

    Inspired by (but not copied from) Aaron's Stack Overflow answer here:
    http://stackoverflow.com/a/34568298/901641

    I never would have tried using AppKit.NSSound without seeing his code.
    '''
    try:
        from AppKit import NSSound
    except ImportError:
        logger.warning("playsound could not find a copy of AppKit - falling back to using macOS's system copy.")
        # NOTE(review): relies on `sys` being imported at module level (it is
        # only imported in the Darwin branch of the module tail) — confirm.
        sys.path.append('/System/Library/Frameworks/Python.framework/Versions/2.7/Extras/lib/python/PyObjC')
        from AppKit import NSSound

    from Foundation import NSURL
    from time import sleep

    sound = _handlePathOSX(sound)
    url = NSURL.URLWithString_(sound)
    if not url:
        raise PlaysoundException('Cannot find a sound with filename: ' + sound)

    # Loading can fail transiently; retry a few times before giving up.
    for i in range(5):
        nssound = NSSound.alloc().initWithContentsOfURL_byReference_(url, True)
        if nssound:
            break
        else:
            logger.debug('Failed to load sound, although url was good... ' + sound)
    else:
        raise PlaysoundException('Could not load sound with filename, although URL was good... ' + sound)
    nssound.play()

    if block:
        # NSSound.play() is asynchronous; sleep for the clip's duration.
        sleep(nssound.duration())
|
| 127 |
+
|
| 128 |
+
def _playsoundNix(sound, block = True):
    """Play a sound using GStreamer.

    Inspired by this:
    https://gstreamer.freedesktop.org/documentation/tutorials/playback/playbin-usage.html
    """
    sound = _canonicalizePath(sound)

    # pathname2url escapes non-URL-safe characters
    from os.path import abspath, exists
    try:
        from urllib.request import pathname2url
    except ImportError:
        # python 2
        from urllib import pathname2url

    import gi
    gi.require_version('Gst', '1.0')
    from gi.repository import Gst

    Gst.init(None)

    playbin = Gst.ElementFactory.make('playbin', 'playbin')
    if sound.startswith(('http://', 'https://')):
        # Remote URLs are handed to playbin directly.
        playbin.props.uri = sound
    else:
        path = abspath(sound)
        if not exists(path):
            raise PlaysoundException(u'File not found: {}'.format(path))
        playbin.props.uri = 'file://' + pathname2url(path)


    set_result = playbin.set_state(Gst.State.PLAYING)
    if set_result != Gst.StateChangeReturn.ASYNC:
        raise PlaysoundException(
            "playbin.set_state returned " + repr(set_result))

    # FIXME: use some other bus method than poll() with block=False
    # https://lazka.github.io/pgi-docs/#Gst-1.0/classes/Bus.html
    logger.debug('Starting play')
    if block:
        # Wait for end-of-stream, then tear the pipeline down.
        # NOTE(review): with block=False the pipeline is never reset to
        # NULL here — confirm the non-blocking path cleans up elsewhere.
        bus = playbin.get_bus()
        try:
            bus.poll(Gst.MessageType.EOS, Gst.CLOCK_TIME_NONE)
        finally:
            playbin.set_state(Gst.State.NULL)

    logger.debug('Finishing play')
|
| 177 |
+
def _playsoundAnotherPython(otherPython, sound, block = True, macOS = False):
    '''
    Mostly written so that when this is run on python3 on macOS, it can invoke
    python2 on macOS... but maybe this idea could be useful on linux, too.
    '''
    from inspect import getsourcefile
    from os.path import abspath, exists
    from subprocess import check_call
    from threading import Thread

    sound = _canonicalizePath(sound)

    class PropogatingThread(Thread):
        # Thread subclass that re-raises any exception from the target in join().
        def run(self):
            self.exc = None
            try:
                self.ret = self._target(*self._args, **self._kwargs)
            except BaseException as e:
                self.exc = e

        def join(self, timeout = None):
            super().join(timeout)
            if self.exc:
                raise self.exc
            return self.ret

    # Check if the file exists...
    if not exists(abspath(sound)):
        raise PlaysoundException('Cannot find a sound with filename: ' + sound)

    # Re-invoke this very module as a script under the other interpreter.
    module_path = abspath(getsourcefile(lambda: 0))
    sound_arg = _handlePathOSX(sound) if macOS else sound
    worker = PropogatingThread(
        target = lambda: check_call([otherPython, module_path, sound_arg]))
    worker.start()
    if block:
        worker.join()
|
| 212 |
+
|
| 213 |
+
from platform import system
system = system()

# Select the platform-appropriate playback backend at import time.
if system == 'Windows':
    playsound_func = _playsoundWin
elif system == 'Darwin':
    playsound_func = _playsoundOSX
    import sys
    if sys.version_info[0] > 2:
        try:
            from AppKit import NSSound
        except ImportError:
            # PyObjC missing on python 3: fall back to spawning the system
            # python 2, which ships with PyObjC on macOS.
            logger.warning("playsound is relying on a python 2 subprocess. Please use `pip3 install PyObjC` if you want playsound to run more efficiently.")
            playsound_func = lambda sound, block = True: _playsoundAnotherPython('/System/Library/Frameworks/Python.framework/Versions/2.7/bin/python', sound, block, macOS = True)
else:
    playsound_func = _playsoundNix
    if __name__ != '__main__':  # Ensure we don't infinitely recurse trying to get another python instance.
        try:
            import gi
            gi.require_version('Gst', '1.0')
            from gi.repository import Gst
        except:
            # GStreamer bindings unavailable: delegate playback to a
            # subprocess running the system python 3.
            logger.warning("playsound is relying on another python subprocess. Please use `pip install pygobject` if you want playsound to run more efficiently.")
            playsound_func = lambda sound, block = True: _playsoundAnotherPython('/usr/bin/python3', sound, block, macOS = False)

del system
|
| 239 |
+
|
| 240 |
+
def play(audio_filepath):
    """Play the given audio file with the backend selected at import time (blocking by default)."""
    playsound_func(audio_filepath)
|
audio_separator/separator/uvr_lib_v5/pyrb.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
import subprocess
import tempfile
import six
import numpy as np
import soundfile as sf
import sys

# When frozen into a PyInstaller bundle, bundled resources live under
# sys._MEIPASS; otherwise resolve them relative to this source file.
if getattr(sys, 'frozen', False):
    BASE_PATH_RUB = sys._MEIPASS
else:
    BASE_PATH_RUB = os.path.dirname(os.path.abspath(__file__))

__all__ = ['time_stretch', 'pitch_shift']

# Path to the bundled rubberband command-line executable.
__RUBBERBAND_UTIL = os.path.join(BASE_PATH_RUB, 'rubberband')

# subprocess.DEVNULL exists only on python 3; on python 2 keep an os.devnull
# file handle open for the life of the module instead.
if six.PY2:
    DEVNULL = open(os.devnull, 'w')
else:
    DEVNULL = subprocess.DEVNULL
|
| 22 |
+
|
| 23 |
+
def __rubberband(y, sr, **kwargs):
    """Round-trip audio through the rubberband CLI.

    Writes *y* to a temp WAV, runs rubberband with *kwargs* as CLI flags,
    reads the result back, and cleans up the temp files.
    """
    assert sr > 0

    # Allocate input/output temp files (close the fds; only paths are needed).
    in_fd, in_path = tempfile.mkstemp(suffix='.wav')
    os.close(in_fd)
    out_fd, out_path = tempfile.mkstemp(suffix='.wav')
    os.close(out_fd)

    # dump the audio
    sf.write(in_path, y, sr)

    try:
        # Build the rubberband command line: flag/value pairs, then files.
        cli = [__RUBBERBAND_UTIL, '-q']
        for flag, setting in six.iteritems(kwargs):
            cli.extend([str(flag), str(setting)])
        cli.extend([in_path, out_path])

        subprocess.check_call(cli, stdout=DEVNULL, stderr=DEVNULL)

        # Load the processed audio.
        y_out, _ = sf.read(out_path, always_2d=True)

        # make sure that output dimensions matches input
        if y.ndim == 1:
            y_out = np.squeeze(y_out)

    except OSError as exc:
        six.raise_from(RuntimeError('Failed to execute rubberband. '
                                    'Please verify that rubberband-cli '
                                    'is installed.'),
                       exc)

    finally:
        # Remove temp files
        os.unlink(in_path)
        os.unlink(out_path)

    return y_out
|
| 67 |
+
|
| 68 |
+
def time_stretch(y, sr, rate, rbargs=None):
    """Stretch audio in time without changing pitch, via rubberband.

    Parameters
    ----------
    y : np.ndarray
        Audio buffer (mono or multi-channel).
    sr : int > 0
        Sampling rate of ``y``.
    rate : float > 0
        Stretch factor; values != 1.0 are forwarded as ``--tempo``.
    rbargs : dict, optional
        Extra rubberband CLI flags (flag -> value). Not mutated.

    Returns
    -------
    np.ndarray
        Time-stretched audio (``y`` itself when ``rate == 1.0``).

    Raises
    ------
    ValueError
        If ``rate`` is not strictly positive.
    """
    if rate <= 0:
        raise ValueError('rate must be strictly positive')

    if rate == 1.0:
        return y

    # Copy so setdefault below does not mutate the caller's dict.
    rbargs = dict(rbargs) if rbargs is not None else {}

    rbargs.setdefault('--tempo', rate)

    return __rubberband(y, sr, **rbargs)
|
| 81 |
+
|
| 82 |
+
def pitch_shift(y, sr, n_steps, rbargs=None):
    """Shift the pitch of audio by ``n_steps`` semitones, via rubberband.

    Parameters
    ----------
    y : np.ndarray
        Audio buffer (mono or multi-channel).
    sr : int > 0
        Sampling rate of ``y``.
    n_steps : float
        Semitones to shift; non-zero values are forwarded as ``--pitch``.
    rbargs : dict, optional
        Extra rubberband CLI flags (flag -> value). Not mutated.

    Returns
    -------
    np.ndarray
        Pitch-shifted audio (``y`` itself when ``n_steps == 0``).
    """
    if n_steps == 0:
        return y

    # Copy so setdefault below does not mutate the caller's dict.
    rbargs = dict(rbargs) if rbargs is not None else {}

    rbargs.setdefault('--pitch', n_steps)

    return __rubberband(y, sr, **rbargs)
|
audio_separator/separator/uvr_lib_v5/results.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
|
| 3 |
+
"""
|
| 4 |
+
Matchering - Audio Matching and Mastering Python Library
|
| 5 |
+
Copyright (C) 2016-2022 Sergree
|
| 6 |
+
|
| 7 |
+
This program is free software: you can redistribute it and/or modify
|
| 8 |
+
it under the terms of the GNU General Public License as published by
|
| 9 |
+
the Free Software Foundation, either version 3 of the License, or
|
| 10 |
+
(at your option) any later version.
|
| 11 |
+
|
| 12 |
+
This program is distributed in the hope that it will be useful,
|
| 13 |
+
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
| 14 |
+
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
| 15 |
+
GNU General Public License for more details.
|
| 16 |
+
|
| 17 |
+
You should have received a copy of the GNU General Public License
|
| 18 |
+
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
import os
|
| 22 |
+
import soundfile as sf
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class Result:
    """Describes a mastering output file: path, soundfile subtype, and flags.

    Validates (via soundfile) that the file extension is a supported format
    and that the format supports the requested subtype.
    """

    def __init__(
        self, file: str, subtype: str, use_limiter: bool = True, normalize: bool = True
    ):
        ext = os.path.splitext(file)[1][1:].upper()

        # Fail fast on formats/subtypes soundfile cannot write.
        if not sf.check_format(ext):
            raise TypeError(f"{ext} format is not supported")
        if not sf.check_format(ext, subtype):
            raise TypeError(f"{ext} format does not have {subtype} subtype")

        self.file = file                # destination path
        self.subtype = subtype          # soundfile subtype, e.g. "PCM_16"
        self.use_limiter = use_limiter  # apply limiter before writing
        self.normalize = normalize      # normalize before writing
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def pcm16(file: str) -> Result:
    """Return a Result that writes 16-bit PCM audio to *file*."""
    return Result(file, "PCM_16")
|
| 43 |
+
|
| 44 |
+
def pcm24(file: str) -> Result:
    # NOTE(review): despite the name, this requests the 32-bit "FLOAT"
    # subtype rather than "PCM_24" — appears deliberate in this fork of
    # matchering, but confirm against callers before changing.
    return Result(file, "FLOAT")
|
| 46 |
+
|
| 47 |
+
def save_audiofile(file: str, wav_set="PCM_16") -> Result:
    """Return a Result for *file* with the given subtype (default 16-bit PCM)."""
    return Result(file, wav_set)
|
audio_separator/separator/uvr_lib_v5/roformer/attend.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from functools import wraps
|
| 2 |
+
from packaging import version
|
| 3 |
+
from collections import namedtuple
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from torch import nn, einsum
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
|
| 9 |
+
from einops import rearrange, reduce
|
| 10 |
+
|
| 11 |
+
# constants
|
| 12 |
+
|
| 13 |
+
# Backend-selection flags passed to torch.backends.cuda.sdp_kernel.
FlashAttentionConfig = namedtuple("FlashAttentionConfig", ["enable_flash", "enable_math", "enable_mem_efficient"])
|
| 14 |
+
|
| 15 |
+
# helpers
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def exists(val):
    """Return True unless *val* is None."""
    return not (val is None)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def once(fn):
    """Wrap *fn* so that only its first invocation runs; later calls are no-ops."""
    has_run = False

    @wraps(fn)
    def inner(x):
        nonlocal has_run
        if not has_run:
            has_run = True
            return fn(x)

    return inner
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
# print() that fires only on the first call (used for one-shot warnings).
print_once = once(print)
|
| 37 |
+
|
| 38 |
+
# main class
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class Attend(nn.Module):
    """Scaled dot-product attention with optional PyTorch 2 SDPA kernels.

    Parameters
    ----------
    dropout : float
        Attention-weight dropout probability (applied only in training).
    flash : bool
        If True, dispatch to ``F.scaled_dot_product_attention`` (requires
        pytorch >= 2.0).
    scale : float, optional
        Custom softmax scale. When None, the standard ``dim_head ** -0.5``
        is used. NOTE: accepting this keyword fixes the crash where
        bs_roformer's ``LinearAttention`` constructs ``Attend(scale=...)``,
        which the previous signature rejected with a TypeError.
    """

    def __init__(self, dropout=0.0, flash=False, scale=None):
        super().__init__()
        self.dropout = dropout
        self.scale = scale
        self.attn_dropout = nn.Dropout(dropout)

        self.flash = flash
        assert not (flash and version.parse(torch.__version__) < version.parse("2.0.0")), "in order to use flash attention, you must be using pytorch 2.0 or above"

        # determine efficient attention configs for cuda and cpu

        self.cpu_config = FlashAttentionConfig(True, True, True)
        self.cuda_config = None

        if not torch.cuda.is_available() or not flash:
            return

        device_properties = torch.cuda.get_device_properties(torch.device("cuda"))

        if device_properties.major == 8 and device_properties.minor == 0:
            print_once("A100 GPU detected, using flash attention if input tensor is on cuda")
            self.cuda_config = FlashAttentionConfig(True, False, False)
        else:
            self.cuda_config = FlashAttentionConfig(False, True, True)

    def flash_attn(self, q, k, v):
        _, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device

        if self.scale is not None:
            # SDPA applies its own default 1/sqrt(d) scaling internally;
            # fold a custom scale in by rescaling q relative to that default.
            default_scale = q.shape[-1] ** -0.5
            q = q * (self.scale / default_scale)

        # Check if there is a compatible device for flash attention

        config = self.cuda_config if is_cuda else self.cpu_config

        # sdpa_flash kernel only supports float16 on sm80+ architecture gpu
        if is_cuda and q.dtype != torch.float16:
            config = FlashAttentionConfig(False, True, True)

        # pytorch 2.0 flash attn: q, k, v, mask, dropout, softmax_scale
        with torch.backends.cuda.sdp_kernel(**config._asdict()):
            out = F.scaled_dot_product_attention(q, k, v, dropout_p=self.dropout if self.training else 0.0)

        return out

    def forward(self, q, k, v):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """

        q_len, k_len, device = q.shape[-2], k.shape[-2], q.device

        # Custom scale when provided, otherwise the standard 1/sqrt(dim_head).
        scale = self.scale if self.scale is not None else q.shape[-1] ** -0.5

        if self.flash:
            return self.flash_attn(q, k, v)

        # similarity

        sim = einsum(f"b h i d, b h j d -> b h i j", q, k) * scale

        # attention

        attn = sim.softmax(dim=-1)
        attn = self.attn_dropout(attn)

        # aggregate values

        out = einsum(f"b h i j, b h j d -> b h i d", attn, v)

        return out
|
audio_separator/separator/uvr_lib_v5/roformer/bs_roformer.py
ADDED
|
@@ -0,0 +1,535 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from functools import partial
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch import nn, einsum, Tensor
|
| 5 |
+
from torch.nn import Module, ModuleList
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
|
| 8 |
+
from .attend import Attend
|
| 9 |
+
|
| 10 |
+
from beartype.typing import Tuple, Optional, List, Callable
|
| 11 |
+
from beartype import beartype
|
| 12 |
+
|
| 13 |
+
from rotary_embedding_torch import RotaryEmbedding
|
| 14 |
+
|
| 15 |
+
from einops import rearrange, pack, unpack
|
| 16 |
+
from einops.layers.torch import Rearrange
|
| 17 |
+
|
| 18 |
+
# helper functions
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def exists(val):
    """Return True unless *val* is None."""
    return not (val is None)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def default(v, d):
    """Return *v* unless it is None, in which case return the fallback *d*."""
    return d if v is None else v
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def pack_one(t, pattern):
    """Pack a single tensor with einops.pack, returning (packed, shape-spec)."""
    return pack([t], pattern)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def unpack_one(t, ps, pattern):
    """Inverse of pack_one: unpack and return the single tensor."""
    (restored,) = unpack(t, ps, pattern)
    return restored
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
# norm
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def l2norm(t):
    """L2-normalize *t* along its last dimension."""
    return F.normalize(t, p=2, dim=-1)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class RMSNorm(Module):
    """RMS normalization with a learned per-channel gain (gamma)."""

    def __init__(self, dim):
        super().__init__()
        self.scale = dim**0.5
        self.gamma = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # Keep the input on the same device as the learned gain.
        x = x.to(self.gamma.device)
        normed = F.normalize(x, dim=-1)
        return normed * self.scale * self.gamma
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
# attention
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class FeedForward(Module):
    """Pre-norm two-layer MLP: RMSNorm -> Linear -> GELU -> Linear, with dropout."""

    def __init__(self, dim, mult=4, dropout=0.0):
        super().__init__()
        hidden = int(dim * mult)
        self.net = nn.Sequential(
            RMSNorm(dim),
            nn.Linear(dim, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class Attention(Module):
    """Multi-head attention with optional rotary embeddings and per-head gating."""

    def __init__(self, dim, heads=8, dim_head=64, dropout=0.0, rotary_embed=None, flash=True):
        super().__init__()
        self.heads = heads
        self.scale = dim_head**-0.5
        inner_dim = heads * dim_head

        self.rotary_embed = rotary_embed
        self.attend = Attend(flash=flash, dropout=dropout)

        self.norm = RMSNorm(dim)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        self.to_gates = nn.Linear(dim, heads)
        self.to_out = nn.Sequential(nn.Linear(inner_dim, dim, bias=False), nn.Dropout(dropout))

    def forward(self, x):
        x = self.norm(x)

        # Split the fused projection into (q, k, v), each shaped (b, h, n, d).
        q, k, v = rearrange(self.to_qkv(x), "b n (qkv h d) -> qkv b h n d", qkv=3, h=self.heads)

        # Apply rotary positional embedding when configured.
        if self.rotary_embed is not None:
            q = self.rotary_embed.rotate_queries_or_keys(q)
            k = self.rotary_embed.rotate_queries_or_keys(k)

        attended = self.attend(q, k, v)

        # Per-head sigmoid gates computed from the normed input.
        gate = rearrange(self.to_gates(x), "b n h -> b h n 1").sigmoid()
        attended = attended * gate

        merged = rearrange(attended, "b h n d -> b n (h d)")
        return self.to_out(merged)
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
class LinearAttention(Module):
    """
    this flavor of linear attention proposed in https://arxiv.org/abs/2106.09681 by El-Nouby et al.
    """

    @beartype
    def __init__(self, *, dim, dim_head=32, heads=8, scale=8, flash=False, dropout=0.0):
        super().__init__()
        dim_inner = dim_head * heads
        self.norm = RMSNorm(dim)

        # Fused qkv projection reshaped to (qkv, batch, heads, dim, seq).
        self.to_qkv = nn.Sequential(nn.Linear(dim, dim_inner * 3, bias=False), Rearrange("b n (qkv h d) -> qkv b h d n", qkv=3, h=heads))

        # Learned per-head temperature applied to the l2-normalized queries.
        self.temperature = nn.Parameter(torch.ones(heads, 1, 1))

        # NOTE(review): passes scale= to Attend; Attend.__init__ must accept a
        # `scale` keyword or this raises TypeError at construction — confirm.
        self.attend = Attend(scale=scale, dropout=dropout, flash=flash)

        self.to_out = nn.Sequential(Rearrange("b h d n -> b n (h d)"), nn.Linear(dim_inner, dim, bias=False))

    def forward(self, x):
        x = self.norm(x)

        q, k, v = self.to_qkv(x)

        # Cosine-similarity style attention: unit-norm q/k, learned temperature.
        q, k = map(l2norm, (q, k))
        q = q * self.temperature.exp()

        out = self.attend(q, k, v)

        return self.to_out(out)
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
class Transformer(Module):
    """Stack of pre-norm residual (attention, feed-forward) layers."""

    def __init__(self, *, dim, depth, dim_head=64, heads=8, attn_dropout=0.0, ff_dropout=0.0, ff_mult=4, norm_output=True, rotary_embed=None, flash_attn=True, linear_attn=False):
        super().__init__()
        self.layers = ModuleList([])

        for _ in range(depth):
            # Choose the attention flavor once per layer.
            if linear_attn:
                attn_layer = LinearAttention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout, flash=flash_attn)
            else:
                attn_layer = Attention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout, rotary_embed=rotary_embed, flash=flash_attn)
            ff_layer = FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)
            self.layers.append(ModuleList([attn_layer, ff_layer]))

        self.norm = RMSNorm(dim) if norm_output else nn.Identity()

    def forward(self, x):
        # Residual connection around every sub-layer.
        for attn_layer, ff_layer in self.layers:
            x = attn_layer(x) + x
            x = ff_layer(x) + x

        return self.norm(x)
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
# bandsplit module
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
class BandSplit(Module):
    """Split the flattened frequency axis into bands and project each to `dim`."""

    @beartype
    def __init__(self, dim, dim_inputs: Tuple[int, ...]):
        super().__init__()
        self.dim_inputs = dim_inputs
        self.to_features = ModuleList([])

        for band_dim in dim_inputs:
            # Per-band RMSNorm followed by a linear projection to the shared dim.
            self.to_features.append(nn.Sequential(RMSNorm(band_dim), nn.Linear(band_dim, dim)))

    def forward(self, x):
        # Split the last axis into per-band chunks of the configured widths.
        bands = x.split(self.dim_inputs, dim=-1)

        features = [proj(band) for band, proj in zip(bands, self.to_features)]
        return torch.stack(features, dim=-2)
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
def MLP(dim_in, dim_out, dim_hidden=None, depth=1, activation=nn.Tanh):
    """Build an nn.Sequential MLP: `depth` linear layers with `activation` between them."""
    dim_hidden = dim_in if dim_hidden is None else dim_hidden

    dims = (dim_in, *((dim_hidden,) * (depth - 1)), dim_out)
    last_idx = len(dims) - 2

    layers = []
    for idx, (layer_in, layer_out) in enumerate(zip(dims[:-1], dims[1:])):
        layers.append(nn.Linear(layer_in, layer_out))
        # No activation after the final linear layer.
        if idx != last_idx:
            layers.append(activation())

    return nn.Sequential(*layers)
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
class MaskEstimator(Module):
    """Per-band mask estimator: one MLP + GLU head per frequency band.

    Parameters
    ----------
    dim : int
        Shared feature dimension of the band features.
    dim_inputs : Tuple[int, ...]
        Width (in flattened freq*complex units) of each band's output.
    depth : int
        Depth of each band's MLP.
    mlp_expansion_factor : int
        Hidden-layer width multiplier for the MLPs.
    """

    @beartype
    def __init__(self, dim, dim_inputs: Tuple[int, ...], depth, mlp_expansion_factor=4):
        super().__init__()
        self.dim_inputs = dim_inputs
        self.to_freqs = ModuleList([])
        dim_hidden = dim * mlp_expansion_factor

        # (removed dead local `net = []` — it was never used)
        for dim_in in dim_inputs:
            # Project to 2*dim_in, then GLU gates it down to dim_in.
            mlp = nn.Sequential(MLP(dim, dim_in * 2, dim_hidden=dim_hidden, depth=depth), nn.GLU(dim=-1))

            self.to_freqs.append(mlp)

    def forward(self, x):
        # x: (..., bands, dim) -> per-band outputs concatenated on the last axis.
        x = x.unbind(dim=-2)

        outs = []

        for band_features, mlp in zip(x, self.to_freqs):
            freq_out = mlp(band_features)
            outs.append(freq_out)

        return torch.cat(outs, dim=-1)
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
# main class
|
| 233 |
+
|
| 234 |
+
# Default band layout for a 2048-point STFT (1025 frequency bins): narrow
# 2-bin bands at low frequencies, widening toward the top (128/129-bin bands).
# The values sum to 1025 and must match the STFT bin count (asserted in
# BSRoformer.__init__).
DEFAULT_FREQS_PER_BANDS = (
    2, 2, 2, 2, 2, 2, 2, 2,
    2, 2, 2, 2, 2, 2, 2, 2,
    2, 2, 2, 2, 2, 2, 2, 2,
    4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
    12, 12, 12, 12, 12, 12, 12, 12,
    24, 24, 24, 24, 24, 24, 24, 24,
    48, 48, 48, 48, 48, 48, 48, 48,
    128, 129,
)
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
class BSRoformer(Module):
|
| 301 |
+
|
| 302 |
+
@beartype
|
| 303 |
+
def __init__(
|
| 304 |
+
self,
|
| 305 |
+
dim,
|
| 306 |
+
*,
|
| 307 |
+
depth,
|
| 308 |
+
stereo=False,
|
| 309 |
+
num_stems=1,
|
| 310 |
+
time_transformer_depth=2,
|
| 311 |
+
freq_transformer_depth=2,
|
| 312 |
+
linear_transformer_depth=0,
|
| 313 |
+
freqs_per_bands: Tuple[int, ...] = DEFAULT_FREQS_PER_BANDS,
|
| 314 |
+
# in the paper, they divide into ~60 bands, test with 1 for starters
|
| 315 |
+
dim_head=64,
|
| 316 |
+
heads=8,
|
| 317 |
+
attn_dropout=0.0,
|
| 318 |
+
ff_dropout=0.0,
|
| 319 |
+
flash_attn=True,
|
| 320 |
+
dim_freqs_in=1025,
|
| 321 |
+
stft_n_fft=2048,
|
| 322 |
+
stft_hop_length=512,
|
| 323 |
+
# 10ms at 44100Hz, from sections 4.1, 4.4 in the paper - @faroit recommends // 2 or // 4 for better reconstruction
|
| 324 |
+
stft_win_length=2048,
|
| 325 |
+
stft_normalized=False,
|
| 326 |
+
stft_window_fn: Optional[Callable] = None,
|
| 327 |
+
mask_estimator_depth=2,
|
| 328 |
+
multi_stft_resolution_loss_weight=1.0,
|
| 329 |
+
multi_stft_resolutions_window_sizes: Tuple[int, ...] = (4096, 2048, 1024, 512, 256),
|
| 330 |
+
multi_stft_hop_size=147,
|
| 331 |
+
multi_stft_normalized=False,
|
| 332 |
+
multi_stft_window_fn: Callable = torch.hann_window,
|
| 333 |
+
):
|
| 334 |
+
super().__init__()
|
| 335 |
+
|
| 336 |
+
self.stereo = stereo
|
| 337 |
+
self.audio_channels = 2 if stereo else 1
|
| 338 |
+
self.num_stems = num_stems
|
| 339 |
+
|
| 340 |
+
self.layers = ModuleList([])
|
| 341 |
+
|
| 342 |
+
transformer_kwargs = dict(dim=dim, heads=heads, dim_head=dim_head, attn_dropout=attn_dropout, ff_dropout=ff_dropout, flash_attn=flash_attn, norm_output=False)
|
| 343 |
+
|
| 344 |
+
time_rotary_embed = RotaryEmbedding(dim=dim_head)
|
| 345 |
+
freq_rotary_embed = RotaryEmbedding(dim=dim_head)
|
| 346 |
+
|
| 347 |
+
for _ in range(depth):
|
| 348 |
+
tran_modules = []
|
| 349 |
+
if linear_transformer_depth > 0:
|
| 350 |
+
tran_modules.append(Transformer(depth=linear_transformer_depth, linear_attn=True, **transformer_kwargs))
|
| 351 |
+
tran_modules.append(Transformer(depth=time_transformer_depth, rotary_embed=time_rotary_embed, **transformer_kwargs))
|
| 352 |
+
tran_modules.append(Transformer(depth=freq_transformer_depth, rotary_embed=freq_rotary_embed, **transformer_kwargs))
|
| 353 |
+
self.layers.append(nn.ModuleList(tran_modules))
|
| 354 |
+
|
| 355 |
+
self.final_norm = RMSNorm(dim)
|
| 356 |
+
|
| 357 |
+
self.stft_kwargs = dict(n_fft=stft_n_fft, hop_length=stft_hop_length, win_length=stft_win_length, normalized=stft_normalized)
|
| 358 |
+
|
| 359 |
+
self.stft_window_fn = partial(default(stft_window_fn, torch.hann_window), stft_win_length)
|
| 360 |
+
|
| 361 |
+
freqs = torch.stft(torch.randn(1, 4096), **self.stft_kwargs, return_complex=True).shape[1]
|
| 362 |
+
|
| 363 |
+
assert len(freqs_per_bands) > 1
|
| 364 |
+
assert sum(freqs_per_bands) == freqs, f"the number of freqs in the bands must equal {freqs} based on the STFT settings, but got {sum(freqs_per_bands)}"
|
| 365 |
+
|
| 366 |
+
freqs_per_bands_with_complex = tuple(2 * f * self.audio_channels for f in freqs_per_bands)
|
| 367 |
+
|
| 368 |
+
self.band_split = BandSplit(dim=dim, dim_inputs=freqs_per_bands_with_complex)
|
| 369 |
+
|
| 370 |
+
self.mask_estimators = nn.ModuleList([])
|
| 371 |
+
|
| 372 |
+
for _ in range(num_stems):
|
| 373 |
+
mask_estimator = MaskEstimator(dim=dim, dim_inputs=freqs_per_bands_with_complex, depth=mask_estimator_depth)
|
| 374 |
+
|
| 375 |
+
self.mask_estimators.append(mask_estimator)
|
| 376 |
+
|
| 377 |
+
# for the multi-resolution stft loss
|
| 378 |
+
|
| 379 |
+
self.multi_stft_resolution_loss_weight = multi_stft_resolution_loss_weight
|
| 380 |
+
self.multi_stft_resolutions_window_sizes = multi_stft_resolutions_window_sizes
|
| 381 |
+
self.multi_stft_n_fft = stft_n_fft
|
| 382 |
+
self.multi_stft_window_fn = multi_stft_window_fn
|
| 383 |
+
|
| 384 |
+
self.multi_stft_kwargs = dict(hop_length=multi_stft_hop_size, normalized=multi_stft_normalized)
|
| 385 |
+
|
| 386 |
+
def forward(self, raw_audio, target=None, return_loss_breakdown=False):
    """Separate stems from raw audio; if a target is given, return the training loss.

    einops dimension legend:

    b - batch
    f - freq
    t - time
    s - audio channel (1 for mono, 2 for stereo)
    n - number of 'stems'
    c - complex (2)
    d - feature dimension
    """

    original_device = raw_audio.device
    x_is_mps = True if original_device.type == "mps" else False

    # if x_is_mps:
    #     raw_audio = raw_audio.cpu()

    device = raw_audio.device

    # promote mono (b, t) input to an explicit channel axis (b, 1, t)
    if raw_audio.ndim == 2:
        raw_audio = rearrange(raw_audio, "b t -> b 1 t")

    channels = raw_audio.shape[1]
    assert (not self.stereo and channels == 1) or (
        self.stereo and channels == 2
    ), "stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). also need to be False if mono (channel dimension of 1)"

    # to stft

    # fold batch and channel together so torch.stft sees (N, t)
    raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, "* t")

    stft_window = self.stft_window_fn().to(device)

    stft_repr = torch.stft(raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True)
    stft_repr = torch.view_as_real(stft_repr)

    stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, "* f t c")
    stft_repr = rearrange(stft_repr, "b s f t c -> b (f s) t c")  # merge stereo / mono into the frequency, with frequency leading dimension, for band splitting

    # flatten (freq, complex) into the feature axis for band splitting
    x = rearrange(stft_repr, "b f t c -> b t (f c)")

    x = self.band_split(x)

    # axial / hierarchical attention

    for transformer_block in self.layers:

        if len(transformer_block) == 3:
            # optional linear-attention stage over the flattened (t, f) sequence
            linear_transformer, time_transformer, freq_transformer = transformer_block

            x, ft_ps = pack([x], "b * d")
            x = linear_transformer(x)
            (x,) = unpack(x, ft_ps, "b * d")
        else:
            time_transformer, freq_transformer = transformer_block

        # attend over time within each frequency band
        x = rearrange(x, "b t f d -> b f t d")
        x, ps = pack([x], "* t d")

        x = time_transformer(x)

        (x,) = unpack(x, ps, "* t d")
        # attend over frequency bands within each time step
        x = rearrange(x, "b f t d -> b t f d")
        x, ps = pack([x], "* f d")

        x = freq_transformer(x)

        (x,) = unpack(x, ps, "* f d")

    x = self.final_norm(x)

    # one complex mask per stem, stacked along a new stem axis n
    mask = torch.stack([fn(x) for fn in self.mask_estimators], dim=1)
    mask = rearrange(mask, "b n t (f c) -> b n f t c", c=2)

    # if x_is_mps:
    #     mask = mask.to('cpu')

    # modulate frequency representation

    stft_repr = rearrange(stft_repr, "b f t c -> b 1 f t c")

    # complex number multiplication

    stft_repr = torch.view_as_complex(stft_repr)
    mask = torch.view_as_complex(mask)

    stft_repr = stft_repr * mask

    # istft

    # split merged (f s) back apart and fold stems/channels into batch for istft
    stft_repr = rearrange(stft_repr, "b n (f s) t -> (b n s) f t", s=self.audio_channels)

    # istft runs on CPU for MPS inputs (presumably missing MPS kernel support), then moves back
    recon_audio = torch.istft(stft_repr.cpu() if x_is_mps else stft_repr, **self.stft_kwargs, window=stft_window.cpu() if x_is_mps else stft_window, return_complex=False).to(device)

    recon_audio = rearrange(recon_audio, "(b n s) t -> b n s t", s=self.audio_channels, n=self.num_stems)

    if self.num_stems == 1:
        recon_audio = rearrange(recon_audio, "b 1 s t -> b s t")

    # if a target is passed in, calculate loss for learning

    if not exists(target):
        return recon_audio

    if self.num_stems > 1:
        assert target.ndim == 4 and target.shape[1] == self.num_stems

    if target.ndim == 2:
        target = rearrange(target, "... t -> ... 1 t")

    # trim target to the reconstruction length (istft may shorten the signal)
    target = target[..., : recon_audio.shape[-1]]

    loss = F.l1_loss(recon_audio, target)

    multi_stft_resolution_loss = 0.0

    for window_size in self.multi_stft_resolutions_window_sizes:
        res_stft_kwargs = dict(
            n_fft=max(window_size, self.multi_stft_n_fft), win_length=window_size, return_complex=True, window=self.multi_stft_window_fn(window_size, device=device), **self.multi_stft_kwargs
        )

        recon_Y = torch.stft(rearrange(recon_audio, "... s t -> (... s) t"), **res_stft_kwargs)
        target_Y = torch.stft(rearrange(target, "... s t -> (... s) t"), **res_stft_kwargs)

        multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(recon_Y, target_Y)

    weighted_multi_resolution_loss = multi_stft_resolution_loss * self.multi_stft_resolution_loss_weight

    total_loss = loss + weighted_multi_resolution_loss

    if not return_loss_breakdown:
        # Move the result back to the original device if it was moved to CPU for MPS compatibility
        # if x_is_mps:
        #     total_loss = total_loss.to(original_device)
        return total_loss

    # For detailed loss breakdown, ensure all components are moved back to the original device for MPS
    # if x_is_mps:
    #     loss = loss.to(original_device)
    #     multi_stft_resolution_loss = multi_stft_resolution_loss.to(original_device)
    #     weighted_multi_resolution_loss = weighted_multi_resolution_loss.to(original_device)

    return total_loss, (loss, multi_stft_resolution_loss)

    # if not return_loss_breakdown:
    #     return total_loss

    # return total_loss, (loss, multi_stft_resolution_loss)
|
audio_separator/separator/uvr_lib_v5/roformer/mel_band_roformer.py
ADDED
|
@@ -0,0 +1,445 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from functools import partial
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch import nn, einsum, Tensor
|
| 5 |
+
from torch.nn import Module, ModuleList
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
|
| 8 |
+
from .attend import Attend
|
| 9 |
+
|
| 10 |
+
from beartype.typing import Tuple, Optional, List, Callable
|
| 11 |
+
from beartype import beartype
|
| 12 |
+
|
| 13 |
+
from rotary_embedding_torch import RotaryEmbedding
|
| 14 |
+
|
| 15 |
+
from einops import rearrange, pack, unpack, reduce, repeat
|
| 16 |
+
|
| 17 |
+
from librosa import filters
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def exists(val):
    """Return True when *val* carries a value (i.e. is not None)."""
    if val is None:
        return False
    return True
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def default(v, d):
    """Return *v* unless it is None, in which case fall back to *d*."""
    if v is None:
        return d
    return v
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def pack_one(t, pattern):
    """einops.pack a single tensor; returns (packed tensor, packed-shape info)."""
    packed, packed_shape = pack([t], pattern)
    return packed, packed_shape
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def unpack_one(t, ps, pattern):
    """Inverse of pack_one: einops.unpack and return the first (sole) tensor."""
    unpacked = unpack(t, ps, pattern)
    return unpacked[0]
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def pad_at_dim(t, pad, dim=-1, value=0.0):
    """Pad tensor *t* along a single dimension.

    F.pad consumes (left, right) pairs starting at the LAST dimension, so the
    requested pair is prefixed with (0, 0) pairs for every dimension to the
    right of *dim*.
    """
    if dim < 0:
        dims_from_right = -dim - 1
    else:
        dims_from_right = t.ndim - dim - 1
    prefix = (0, 0) * dims_from_right
    return F.pad(t, (*prefix, *pad), value=value)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class RMSNorm(Module):
    """Root-mean-square normalisation with a learned per-channel gain."""

    def __init__(self, dim):
        super().__init__()
        # sqrt(dim) restores the magnitude removed by unit-normalising features
        self.scale = dim**0.5
        self.gamma = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # keep the input on the same device as the learned gain
        x = x.to(self.gamma.device)
        normed = F.normalize(x, dim=-1)
        return normed * self.scale * self.gamma
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class FeedForward(Module):
    """Pre-norm feed-forward block: RMSNorm -> Linear -> GELU -> Dropout -> Linear -> Dropout."""

    def __init__(self, dim, mult=4, dropout=0.0):
        super().__init__()
        hidden = int(dim * mult)
        self.net = nn.Sequential(
            RMSNorm(dim),
            nn.Linear(dim, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class Attention(Module):
    """Multi-head self-attention with optional rotary embeddings and per-head output gating."""

    def __init__(self, dim, heads=8, dim_head=64, dropout=0.0, rotary_embed=None, flash=True):
        super().__init__()
        self.heads = heads
        self.scale = dim_head**-0.5
        inner_dim = heads * dim_head

        self.rotary_embed = rotary_embed

        # Attend wraps the (optionally flash) scaled-dot-product attention kernel
        self.attend = Attend(flash=flash, dropout=dropout)

        self.norm = RMSNorm(dim)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        # one sigmoid gate per head, derived from the normed input
        self.to_gates = nn.Linear(dim, heads)

        self.to_out = nn.Sequential(nn.Linear(inner_dim, dim, bias=False), nn.Dropout(dropout))

    def forward(self, x):
        x = self.norm(x)

        q, k, v = rearrange(self.to_qkv(x), "b n (qkv h d) -> qkv b h n d", qkv=3, h=self.heads)

        if self.rotary_embed is not None:
            q = self.rotary_embed.rotate_queries_or_keys(q)
            k = self.rotary_embed.rotate_queries_or_keys(k)

        attended = self.attend(q, k, v)

        gates = self.to_gates(x)
        attended = attended * rearrange(gates, "b n h -> b h n 1").sigmoid()

        merged = rearrange(attended, "b h n d -> b n (h d)")
        return self.to_out(merged)
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
class Transformer(Module):
    """Stack of residual attention + feed-forward layers, with optional final norm."""

    def __init__(self, *, dim, depth, dim_head=64, heads=8, attn_dropout=0.0, ff_dropout=0.0, ff_mult=4, norm_output=True, rotary_embed=None, flash_attn=True):
        super().__init__()
        blocks = []
        for _ in range(depth):
            attn = Attention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout, rotary_embed=rotary_embed, flash=flash_attn)
            ff = FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)
            blocks.append(ModuleList([attn, ff]))
        self.layers = ModuleList(blocks)

        self.norm = RMSNorm(dim) if norm_output else nn.Identity()

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return self.norm(x)
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
class BandSplit(Module):
    """Project each frequency band (of possibly differing width) to a common feature dim."""

    @beartype
    def __init__(self, dim, dim_inputs: Tuple[int, ...]):
        super().__init__()
        self.dim_inputs = dim_inputs
        # one norm + projection per band
        self.to_features = ModuleList(
            nn.Sequential(RMSNorm(width), nn.Linear(width, dim)) for width in dim_inputs
        )

    def forward(self, x):
        # split the flat feature axis into per-band chunks
        bands = x.split(self.dim_inputs, dim=-1)

        features = [proj(band) for band, proj in zip(bands, self.to_features)]

        # stack bands into a new axis just before the feature dim
        return torch.stack(features, dim=-2)
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def MLP(dim_in, dim_out, dim_hidden=None, depth=1, activation=nn.Tanh):
    """Build a plain MLP as nn.Sequential.

    *depth* hidden layers of width *dim_hidden* (defaulting to *dim_in*),
    with *activation* between layers and no activation after the final
    projection to *dim_out*.
    """
    if dim_hidden is None:
        dim_hidden = dim_in

    dims = (dim_in, *((dim_hidden,) * depth), dim_out)

    layers = []
    last = len(dims) - 2
    for idx, (fan_in, fan_out) in enumerate(zip(dims[:-1], dims[1:])):
        layers.append(nn.Linear(fan_in, fan_out))
        if idx != last:
            layers.append(activation())

    return nn.Sequential(*layers)
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
class MaskEstimator(Module):
    """Per-band mask head.

    For each band, an MLP widens the shared features to twice the band's
    (complex-interleaved) width and a GLU gates them back down to the band
    width; the per-band outputs are concatenated along the feature axis.
    """

    @beartype
    def __init__(self, dim, dim_inputs: Tuple[int, ...], depth, mlp_expansion_factor=4):
        super().__init__()
        self.dim_inputs = dim_inputs
        self.to_freqs = ModuleList([])
        dim_hidden = dim * mlp_expansion_factor

        for dim_in in dim_inputs:
            # MLP emits 2 * dim_in so the GLU gate halves it back to dim_in
            # (removed unused local `net = []` from the original)
            mlp = nn.Sequential(MLP(dim, dim_in * 2, dim_hidden=dim_hidden, depth=depth), nn.GLU(dim=-1))

            self.to_freqs.append(mlp)

    def forward(self, x):
        # x: (..., num_bands, dim) -> one tensor per band
        x = x.unbind(dim=-2)

        outs = []

        for band_features, mlp in zip(x, self.to_freqs):
            freq_out = mlp(band_features)
            outs.append(freq_out)

        return torch.cat(outs, dim=-1)
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
class MelBandRoformer(Module):
    """Mel-band RoFormer for music source separation.

    Bands come from a librosa mel filterbank, so band widths vary and a given
    STFT bin may belong to several bands; predicted per-band masks are
    scatter-added back onto bins and averaged by each bin's band count.
    """

    @beartype
    def __init__(
        self,
        dim,
        *,
        depth,
        stereo=False,
        num_stems=1,
        time_transformer_depth=2,
        freq_transformer_depth=2,
        num_bands=60,
        dim_head=64,
        heads=8,
        attn_dropout=0.1,
        ff_dropout=0.1,
        flash_attn=True,
        dim_freqs_in=1025,  # NOTE(review): not referenced in this constructor
        sample_rate=44100,  # used only to build the mel filter bank
        stft_n_fft=2048,
        stft_hop_length=512,
        stft_win_length=2048,
        stft_normalized=False,
        stft_window_fn: Optional[Callable] = None,
        mask_estimator_depth=1,
        multi_stft_resolution_loss_weight=1.0,
        multi_stft_resolutions_window_sizes: Tuple[int, ...] = (4096, 2048, 1024, 512, 256),
        multi_stft_hop_size=147,
        multi_stft_normalized=False,
        multi_stft_window_fn: Callable = torch.hann_window,
        match_input_audio_length=False,  # if True, istft is told to match the input length
    ):
        super().__init__()

        self.stereo = stereo
        self.audio_channels = 2 if stereo else 1
        self.num_stems = num_stems

        self.layers = ModuleList([])

        transformer_kwargs = dict(dim=dim, heads=heads, dim_head=dim_head, attn_dropout=attn_dropout, ff_dropout=ff_dropout, flash_attn=flash_attn)

        time_rotary_embed = RotaryEmbedding(dim=dim_head)
        freq_rotary_embed = RotaryEmbedding(dim=dim_head)

        # each depth level is a (time transformer, freq transformer) pair
        for _ in range(depth):
            self.layers.append(
                nn.ModuleList(
                    [
                        Transformer(depth=time_transformer_depth, rotary_embed=time_rotary_embed, **transformer_kwargs),
                        Transformer(depth=freq_transformer_depth, rotary_embed=freq_rotary_embed, **transformer_kwargs),
                    ]
                )
            )

        self.stft_window_fn = partial(default(stft_window_fn, torch.hann_window), stft_win_length)

        self.stft_kwargs = dict(n_fft=stft_n_fft, hop_length=stft_hop_length, win_length=stft_win_length, normalized=stft_normalized)

        # probe the STFT once to discover the number of frequency bins
        freqs = torch.stft(torch.randn(1, 4096), **self.stft_kwargs, return_complex=True).shape[1]

        mel_filter_bank_numpy = filters.mel(sr=sample_rate, n_fft=stft_n_fft, n_mels=num_bands)

        mel_filter_bank = torch.from_numpy(mel_filter_bank_numpy)

        # force the first and last bins to be covered by the edge bands
        mel_filter_bank[0][0] = 1.0

        mel_filter_bank[-1, -1] = 1.0

        # boolean (band, freq) membership mask
        freqs_per_band = mel_filter_bank > 0
        assert freqs_per_band.any(dim=0).all(), "all frequencies need to be covered by all bands for now"

        # flat list of the frequency index of every (band, freq) membership
        repeated_freq_indices = repeat(torch.arange(freqs), "f -> b f", b=num_bands)
        freq_indices = repeated_freq_indices[freqs_per_band]

        # for stereo, indices address the interleaved (f s) layout used in forward
        if stereo:
            freq_indices = repeat(freq_indices, "f -> f s", s=2)
            freq_indices = freq_indices * 2 + torch.arange(2)
            freq_indices = rearrange(freq_indices, "f s -> (f s)")

        self.register_buffer("freq_indices", freq_indices, persistent=False)
        self.register_buffer("freqs_per_band", freqs_per_band, persistent=False)

        num_freqs_per_band = reduce(freqs_per_band, "b f -> b", "sum")
        num_bands_per_freq = reduce(freqs_per_band, "b f -> f", "sum")

        self.register_buffer("num_freqs_per_band", num_freqs_per_band, persistent=False)
        self.register_buffer("num_bands_per_freq", num_bands_per_freq, persistent=False)

        # per-band input width: real+imag (2) per bin per audio channel
        freqs_per_bands_with_complex = tuple(2 * f * self.audio_channels for f in num_freqs_per_band.tolist())

        self.band_split = BandSplit(dim=dim, dim_inputs=freqs_per_bands_with_complex)

        self.mask_estimators = nn.ModuleList([])

        for _ in range(num_stems):
            mask_estimator = MaskEstimator(dim=dim, dim_inputs=freqs_per_bands_with_complex, depth=mask_estimator_depth)

            self.mask_estimators.append(mask_estimator)

        # multi-resolution STFT loss settings
        self.multi_stft_resolution_loss_weight = multi_stft_resolution_loss_weight
        self.multi_stft_resolutions_window_sizes = multi_stft_resolutions_window_sizes
        self.multi_stft_n_fft = stft_n_fft
        self.multi_stft_window_fn = multi_stft_window_fn

        self.multi_stft_kwargs = dict(hop_length=multi_stft_hop_size, normalized=multi_stft_normalized)

        self.match_input_audio_length = match_input_audio_length

    def forward(self, raw_audio, target=None, return_loss_breakdown=False):
        """Separate stems from raw audio; if a target is given, return the training loss.

        einops dimension legend:

        b - batch
        f - freq
        t - time
        s - audio channel (1 for mono, 2 for stereo)
        n - number of 'stems'
        c - complex (2)
        d - feature dimension
        """

        original_device = raw_audio.device
        x_is_mps = True if original_device.type == "mps" else False

        # some ops below are run on CPU for MPS inputs; results move back at the end
        if x_is_mps:
            raw_audio = raw_audio.cpu()

        device = raw_audio.device

        # promote mono (b, t) input to an explicit channel axis (b, 1, t)
        if raw_audio.ndim == 2:
            raw_audio = rearrange(raw_audio, "b t -> b 1 t")

        batch, channels, raw_audio_length = raw_audio.shape

        istft_length = raw_audio_length if self.match_input_audio_length else None

        assert (not self.stereo and channels == 1) or (
            self.stereo and channels == 2
        ), "stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). also need to be False if mono (channel dimension of 1)"

        # fold batch and channel together so torch.stft sees (N, t)
        raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, "* t")

        stft_window = self.stft_window_fn().to(device)

        stft_repr = torch.stft(raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True)
        stft_repr = torch.view_as_real(stft_repr)

        stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, "* f t c")
        stft_repr = rearrange(stft_repr, "b s f t c -> b (f s) t c")  # merge stereo / mono into the frequency, with frequency leading dimension, for band splitting

        batch_arange = torch.arange(batch, device=device)[..., None]

        # gather each band's member bins (a bin can appear in multiple bands)
        x = stft_repr[batch_arange, self.freq_indices.cpu()] if x_is_mps else stft_repr[batch_arange, self.freq_indices]

        x = rearrange(x, "b f t c -> b t (f c)")

        x = self.band_split(x)

        # axial attention: alternate attending over time, then over bands
        for time_transformer, freq_transformer in self.layers:
            x = rearrange(x, "b t f d -> b f t d")
            x, ps = pack([x], "* t d")

            x = time_transformer(x)

            (x,) = unpack(x, ps, "* t d")
            x = rearrange(x, "b f t d -> b t f d")
            x, ps = pack([x], "* f d")

            x = freq_transformer(x)

            (x,) = unpack(x, ps, "* f d")

        # one complex mask per stem, stacked along a new stem axis n
        masks = torch.stack([fn(x) for fn in self.mask_estimators], dim=1)
        masks = rearrange(masks, "b n t (f c) -> b n f t c", c=2)

        if x_is_mps:
            masks = masks.cpu()

        stft_repr = rearrange(stft_repr, "b f t c -> b 1 f t c")

        stft_repr = torch.view_as_complex(stft_repr)
        masks = torch.view_as_complex(masks)

        masks = masks.type(stft_repr.dtype)

        # scatter the per-band masks back onto their frequency bins
        if x_is_mps:
            scatter_indices = repeat(self.freq_indices.cpu(), "f -> b n f t", b=batch, n=self.num_stems, t=stft_repr.shape[-1])
        else:
            scatter_indices = repeat(self.freq_indices, "f -> b n f t", b=batch, n=self.num_stems, t=stft_repr.shape[-1])

        stft_repr_expanded_stems = repeat(stft_repr, "b 1 ... -> b n ...", n=self.num_stems)
        masks_summed = (
            torch.zeros_like(stft_repr_expanded_stems.cpu() if x_is_mps else stft_repr_expanded_stems)
            .scatter_add_(2, scatter_indices.cpu() if x_is_mps else scatter_indices, masks.cpu() if x_is_mps else masks)
            .to(device)
        )

        # average each bin's mask over the number of bands that contain it
        denom = repeat(self.num_bands_per_freq, "f -> (f r) 1", r=channels)

        if x_is_mps:
            denom = denom.cpu()

        masks_averaged = masks_summed / denom.clamp(min=1e-8)

        # modulate the complex STFT by the averaged masks
        stft_repr = stft_repr * masks_averaged

        # fold stems/channels into batch for istft
        stft_repr = rearrange(stft_repr, "b n (f s) t -> (b n s) f t", s=self.audio_channels)

        # istft runs on CPU for MPS inputs (presumably missing MPS kernel support)
        recon_audio = torch.istft(stft_repr.cpu() if x_is_mps else stft_repr, **self.stft_kwargs, window=stft_window.cpu() if x_is_mps else stft_window, return_complex=False, length=istft_length)

        recon_audio = rearrange(recon_audio, "(b n s) t -> b n s t", b=batch, s=self.audio_channels, n=self.num_stems)

        if self.num_stems == 1:
            recon_audio = rearrange(recon_audio, "b 1 s t -> b s t")

        # inference path: no target means just return the separated audio
        if not exists(target):
            return recon_audio

        if self.num_stems > 1:
            assert target.ndim == 4 and target.shape[1] == self.num_stems

        if target.ndim == 2:
            target = rearrange(target, "... t -> ... 1 t")

        # trim target to the reconstruction length (istft may shorten the signal)
        target = target[..., : recon_audio.shape[-1]]

        loss = F.l1_loss(recon_audio, target)

        multi_stft_resolution_loss = 0.0

        for window_size in self.multi_stft_resolutions_window_sizes:
            res_stft_kwargs = dict(
                n_fft=max(window_size, self.multi_stft_n_fft), win_length=window_size, return_complex=True, window=self.multi_stft_window_fn(window_size, device=device), **self.multi_stft_kwargs
            )

            recon_Y = torch.stft(rearrange(recon_audio, "... s t -> (... s) t"), **res_stft_kwargs)
            target_Y = torch.stft(rearrange(target, "... s t -> (... s) t"), **res_stft_kwargs)

            multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(recon_Y, target_Y)

        weighted_multi_resolution_loss = multi_stft_resolution_loss * self.multi_stft_resolution_loss_weight

        total_loss = loss + weighted_multi_resolution_loss

        # Move the total loss back to the original device if necessary
        if x_is_mps:
            total_loss = total_loss.to(original_device)

        if not return_loss_breakdown:
            return total_loss

        # If detailed loss breakdown is requested, ensure all components are on the original device
        return total_loss, (loss.to(original_device) if x_is_mps else loss, multi_stft_resolution_loss.to(original_device) if x_is_mps else multi_stft_resolution_loss)
|
audio_separator/separator/uvr_lib_v5/spec_utils.py
ADDED
|
@@ -0,0 +1,1327 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import audioread
|
| 2 |
+
import librosa
|
| 3 |
+
import numpy as np
|
| 4 |
+
import soundfile as sf
|
| 5 |
+
import math
|
| 6 |
+
import platform
|
| 7 |
+
import traceback
|
| 8 |
+
from audio_separator.separator.uvr_lib_v5 import pyrb
|
| 9 |
+
from scipy.signal import correlate, hilbert
|
| 10 |
+
import io
|
| 11 |
+
|
| 12 |
+
# Platform detection used below to choose librosa resampling backends.
OPERATING_SYSTEM = platform.system()
SYSTEM_ARCH = platform.platform()
SYSTEM_PROC = platform.processor()
ARM = "arm"

# Phase-selection mode labels.
AUTO_PHASE = "Automatic"
POSITIVE_PHASE = "Positive Phase"
NEGATIVE_PHASE = "Negative Phase"
# NOTE(review): the trailing commas make the next four constants 1-tuples,
# not strings — presumably intentional (e.g. UI option lists), but confirm
# against callers before relying on them as plain strings.
NONE_P = ("None",)
LOW_P = ("Shifts: Low",)
MED_P = ("Shifts: Medium",)
HIGH_P = ("Shifts: High",)
VHIGH_P = "Shifts: Very High"
MAXIMUM_P = "Shifts: Maximum"

# Module-level mutable state (progress reporting); written elsewhere.
progress_value = 0
last_update_time = 0
is_macos = False


if OPERATING_SYSTEM == "Darwin":
    # On Apple Silicon (ARM) macOS, prefer polyphase / kaiser_best resampling;
    # otherwise fall back to sinc_fastest like the other platforms.
    wav_resolution = "polyphase" if SYSTEM_PROC == ARM or ARM in SYSTEM_ARCH else "sinc_fastest"
    wav_resolution_float_resampling = "kaiser_best" if SYSTEM_PROC == ARM or ARM in SYSTEM_ARCH else wav_resolution
    is_macos = True
else:
    wav_resolution = "sinc_fastest"
    wav_resolution_float_resampling = wav_resolution

# Ensemble algorithm identifiers (see ensembling / ensemble_inputs below).
MAX_SPEC = "Max Spec"
MIN_SPEC = "Min Spec"
LIN_ENSE = "Linear Ensemble"

MAX_WAV = MAX_SPEC
MIN_WAV = MIN_SPEC

AVERAGE = "Average"
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def crop_center(h1, h2):
    """Center-crop ``h1`` along the last (time) axis to match ``h2``.

    Both inputs are 4-D tensors; ``h1`` must be at least as long as ``h2``
    in dimension 3. Returns ``h1`` unchanged when the lengths already match.
    """
    width_1 = h1.size()[3]
    width_2 = h2.size()[3]

    # Already aligned: nothing to do.
    if width_1 == width_2:
        return h1
    # A shorter h1 cannot be cropped to a longer target.
    if width_1 < width_2:
        raise ValueError("h1_shape[3] must be greater than h2_shape[3]")

    # Take the centered window of width_2 samples.
    start = (width_1 - width_2) // 2
    return h1[:, :, :, start : start + width_2]
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def preprocess(X_spec):
    """Split a complex spectrogram into its magnitude and phase components.

    Returns:
        tuple: ``(magnitude, phase)`` arrays of the same shape as ``X_spec``.
    """
    return np.abs(X_spec), np.angle(X_spec)
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def make_padding(width, cropsize, offset):
    """Compute padding so ``width`` tiles evenly into crops of ``cropsize``.

    Returns:
        tuple: ``(left, right, roi_size)`` where ``roi_size`` is the usable
        region per crop after removing ``offset`` from both sides.
    """
    left = offset
    roi_size = cropsize - 2 * offset
    # Degenerate case: offsets consume the whole crop — fall back to cropsize.
    if not roi_size:
        roi_size = cropsize
    right = roi_size - (width % roi_size) + left
    return left, right, roi_size
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def normalize(wave, max_peak=1.0, min_peak=None):
    """Normalize (or amplify) an audio waveform to a target peak level.

    Args:
        wave (array-like): Audio waveform (scaled in place when rescaling applies).
        max_peak (float): Peak above which the waveform is scaled down.
        min_peak (float | None): Optional floor; quieter signals are scaled up to it.

    Returns:
        array-like: The (possibly rescaled) waveform.
    """
    maxv = np.abs(wave).max()
    if maxv > max_peak:
        wave *= max_peak / maxv
    # Guard `0 < maxv` so an all-silent waveform is returned untouched instead
    # of producing a division by zero (NaN) in the amplification branch.
    elif min_peak is not None and 0 < maxv < min_peak:
        wave *= min_peak / maxv

    return wave
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def auto_transpose(audio_array: np.ndarray):
    """
    Ensure that the audio array is in the (channels, samples) format.

    Parameters:
        audio_array (ndarray): Input audio array.

    Returns:
        ndarray: Transposed audio array if necessary, otherwise the input.
    """
    # Only 2-D arrays laid out as (samples, 2) need transposing. Guarding on
    # ndim first also fixes the IndexError the old code raised on 1-D (mono)
    # input, which has no second dimension.
    if audio_array.ndim == 2 and audio_array.shape[1] == 2:
        return audio_array.T
    return audio_array
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def write_array_to_mem(audio_data, subtype):
    """Serialize an ndarray to an in-memory 44.1 kHz WAV buffer.

    Non-array input is passed through unchanged.
    """
    if not isinstance(audio_data, np.ndarray):
        return audio_data

    buffer = io.BytesIO()
    sf.write(buffer, audio_data, 44100, subtype=subtype, format="WAV")
    buffer.seek(0)  # rewind so callers can read from the start
    return buffer
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def spectrogram_to_image(spec, mode="magnitude"):
    """Render a (possibly complex) spectrogram as an 8-bit image array.

    Args:
        spec: Spectrogram, complex or real, 2-D or 3-D (channels first).
        mode: "magnitude" (log-power scale) or "phase".

    Returns:
        np.uint8 array scaled to [0, 255]; 3-D input is transposed to
        (H, W, C) with a max-over-channels layer prepended.

    Raises:
        ValueError: If ``mode`` is not "magnitude" or "phase" (previously
        this fell through and raised a confusing NameError).
    """
    if mode == "magnitude":
        y = np.abs(spec) if np.iscomplexobj(spec) else spec
        # Log-power scale; epsilon avoids log(0).
        y = np.log10(y**2 + 1e-8)
    elif mode == "phase":
        y = np.angle(spec) if np.iscomplexobj(spec) else spec
    else:
        raise ValueError(f"Unknown mode: {mode}")

    # Shift/scale into [0, 255] without mutating the caller's array.
    y = y - y.min()
    peak = y.max()
    if peak > 0:  # constant input previously divided by zero -> NaN cast
        y = y * (255 / peak)
    img = np.uint8(y)

    if y.ndim == 3:
        img = img.transpose(1, 2, 0)
        img = np.concatenate([np.max(img, axis=2, keepdims=True), img], axis=2)

    return img
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def reduce_vocal_aggressively(X, y, softmask):
    """Attenuate bins of ``y`` where the residual (X - y) dominates.

    Wherever the residual magnitude exceeds y's magnitude, y's magnitude is
    reduced by ``softmask`` times the residual magnitude (clipped at zero),
    while y's phase is preserved.
    """
    residual = X - y
    y_mag = np.abs(y)
    residual_mag = np.abs(residual)

    dominant = residual_mag > y_mag
    attenuated = np.clip(y_mag - residual_mag * dominant * softmask, 0, np.inf)

    return attenuated * np.exp(1.0j * np.angle(y))
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def merge_artifacts(y_mask, thres=0.01, min_range=64, fade_size=32):
    """Smooth long above-threshold runs in a separation mask.

    Finds contiguous time ranges where every frequency bin of ``y_mask``
    exceeds ``thres`` for longer than ``min_range`` frames, and blends the
    mask toward 1 over those ranges with linear fades of ``fade_size``
    frames at each edge. On any failure the original mask is returned
    unchanged (best-effort post-processing).

    Args:
        y_mask: Mask array shaped (channels, bins, frames); mutated in place
            when artifacts are found.
        thres: Per-frame minimum value for a frame to count as "active".
        min_range: Minimum run length (frames) to treat as an artifact.
        fade_size: Length (frames) of each linear fade edge.

    Returns:
        The (possibly modified) mask.
    """
    mask = y_mask

    try:
        if min_range < fade_size * 2:
            raise ValueError("min_range must be >= fade_size * 2")

        # Frames where ALL bins/channels exceed the threshold.
        idx = np.where(y_mask.min(axis=(0, 1)) > thres)[0]
        # Split `idx` into contiguous runs: a gap (diff != 1) starts a new run.
        start_idx = np.insert(idx[np.where(np.diff(idx) != 1)[0] + 1], 0, idx[0])
        end_idx = np.append(idx[np.where(np.diff(idx) != 1)[0]], idx[-1])
        # Keep only runs longer than min_range.
        artifact_idx = np.where(end_idx - start_idx > min_range)[0]
        weight = np.zeros_like(y_mask)
        if len(artifact_idx) > 0:
            start_idx = start_idx[artifact_idx]
            end_idx = end_idx[artifact_idx]
            old_e = None
            for s, e in zip(start_idx, end_idx):
                # Merge runs whose gap to the previous one is smaller than a fade.
                if old_e is not None and s - old_e < fade_size:
                    s = old_e - fade_size * 2

                if s != 0:
                    # Fade-in at the run start.
                    weight[:, :, s : s + fade_size] = np.linspace(0, 1, fade_size)
                else:
                    # Run touches the clip start: skip the fade-in.
                    s -= fade_size

                if e != y_mask.shape[2]:
                    # Fade-out at the run end.
                    weight[:, :, e - fade_size : e] = np.linspace(1, 0, fade_size)
                else:
                    # Run touches the clip end: skip the fade-out.
                    e += fade_size

                # Full weight between the fades.
                weight[:, :, s + fade_size : e - fade_size] = 1
                old_e = e

            # Blend the mask toward 1 where weight is non-zero.
            v_mask = 1 - y_mask
            y_mask += weight * v_mask

            mask = y_mask
    except Exception as e:
        # Best-effort: report and fall back to the untouched mask.
        error_name = f"{type(e).__name__}"
        traceback_text = "".join(traceback.format_tb(e.__traceback__))
        message = f'{error_name}: "{e}"\n{traceback_text}"'
        print("Post Process Failed: ", message)

    return mask
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
def align_wave_head_and_tail(a, b):
    """Trim two waveforms to the length of the shorter one."""
    # Per-channel sample count of the shorter waveform.
    l = min([a[0].size, b[0].size])

    # NOTE(review): slicing BOTH axes with `l` looks odd for (channels, samples)
    # arrays; since `l` is a sample count it normally exceeds the channel count,
    # so the first axis is unaffected in practice — confirm intent before changing.
    return a[:l, :l], b[:l, :l]
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
def convert_channels(spec, mp, band):
    """Apply the band's optional stereo channel conversion to a spectrogram.

    The conversion mode comes from ``mp.param["band"][band]["convert_channels"]``;
    unknown or missing modes return ``spec`` unchanged.
    """
    mode = mp.param["band"][band].get("convert_channels")

    if mode == "mid_side_c":
        left = spec[0] + spec[1] * 0.25
        right = spec[1] - spec[0] * 0.25
    elif mode == "mid_side":
        left = (spec[0] + spec[1]) / 2
        right = spec[0] - spec[1]
    elif mode == "stereo_n":
        left = (spec[0] + spec[1] * 0.25) / 0.9375
        right = (spec[1] + spec[0] * 0.25) / 0.9375
    else:
        return spec

    return np.asfortranarray([left, right])
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
def combine_spectrograms(specs, mp, is_v51_model=False):
    """Stack per-band spectrograms into one full-range spectrogram.

    Each band's cropped bin range is copied into consecutive rows of the
    combined array; an optional low-pass pre-filter is then applied above
    ``pre_filter_start``.

    Args:
        specs: Dict of band index (1-based) -> complex spectrogram.
        mp: Model parameters object exposing ``param`` (bins, band crops, filter).
        is_v51_model: Selects mask-based filtering instead of the legacy loop.

    Returns:
        Fortran-ordered complex64 array shaped (2, bins + 1, frames).

    Raises:
        ValueError: If the band crops sum to more rows than ``mp.param["bins"]``.
    """
    # Trim all bands to the shortest frame count.
    l = min([specs[i].shape[2] for i in specs])
    spec_c = np.zeros(shape=(2, mp.param["bins"] + 1, l), dtype=np.complex64)
    offset = 0
    bands_n = len(mp.param["band"])

    for d in range(1, bands_n + 1):
        h = mp.param["band"][d]["crop_stop"] - mp.param["band"][d]["crop_start"]
        spec_c[:, offset : offset + h, :l] = specs[d][:, mp.param["band"][d]["crop_start"] : mp.param["band"][d]["crop_stop"], :l]
        offset += h

    if offset > mp.param["bins"]:
        raise ValueError("Too much bins")

    # Low-pass filter above the pre-filter cutoff, when configured.
    if mp.param["pre_filter_start"] > 0:
        if is_v51_model:
            spec_c *= get_lp_filter_mask(spec_c.shape[1], mp.param["pre_filter_start"], mp.param["pre_filter_stop"])
        else:
            if bands_n == 1:
                spec_c = fft_lp_filter(spec_c, mp.param["pre_filter_start"], mp.param["pre_filter_stop"])
            else:
                # Legacy multi-band roll-off: progressively steeper dB attenuation.
                gp = 1
                for b in range(mp.param["pre_filter_start"] + 1, mp.param["pre_filter_stop"]):
                    g = math.pow(10, -(b - mp.param["pre_filter_start"]) * (3.5 - gp) / 20.0)
                    gp = g
                    spec_c[:, b, :] *= g

    return np.asfortranarray(spec_c)
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
def wave_to_spectrogram(wave, hop_length, n_fft, mp, band, is_v51_model=False):
    """STFT a (stereo) waveform, applying the model's channel pre-processing.

    For legacy (non-v5.1) models the channel transform (reverse / mid-side /
    mid-side-b2) is applied to the waveform before the STFT; v5.1 models apply
    their conversion on the spectrogram via convert_channels() instead.

    Args:
        wave: (2, samples) waveform; 1-D mono input is duplicated to stereo.
        hop_length / n_fft: STFT parameters.
        mp: Model parameters object exposing ``param``.
        band: Band index, forwarded to convert_channels() for v5.1 models.
        is_v51_model: Selects the v5.1 channel-conversion path.

    Returns:
        Fortran-ordered complex array shaped (2, bins, frames).
    """
    # Mono input: duplicate to two identical channels.
    if wave.ndim == 1:
        wave = np.asfortranarray([wave, wave])

    if not is_v51_model:
        if mp.param["reverse"]:
            wave_left = np.flip(np.asfortranarray(wave[0]))
            wave_right = np.flip(np.asfortranarray(wave[1]))
        elif mp.param["mid_side"]:
            # Mid = (L + R) / 2, Side = L - R.
            wave_left = np.asfortranarray(np.add(wave[0], wave[1]) / 2)
            wave_right = np.asfortranarray(np.subtract(wave[0], wave[1]))
        elif mp.param["mid_side_b2"]:
            wave_left = np.asfortranarray(np.add(wave[1], wave[0] * 0.5))
            wave_right = np.asfortranarray(np.subtract(wave[0], wave[1] * 0.5))
        else:
            wave_left = np.asfortranarray(wave[0])
            wave_right = np.asfortranarray(wave[1])
    else:
        wave_left = np.asfortranarray(wave[0])
        wave_right = np.asfortranarray(wave[1])

    spec_left = librosa.stft(wave_left, n_fft=n_fft, hop_length=hop_length)
    spec_right = librosa.stft(wave_right, n_fft=n_fft, hop_length=hop_length)

    spec = np.asfortranarray([spec_left, spec_right])

    # v5.1 models convert channels in the spectral domain.
    if is_v51_model:
        spec = convert_channels(spec, mp, band)

    return spec
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
def spectrogram_to_wave(spec, hop_length=1024, mp={}, band=0, is_v51_model=True):
    """Inverse-STFT a two-channel spectrogram, undoing channel pre-processing.

    This is the inverse of wave_to_spectrogram(): after the per-channel ISTFT,
    the model's channel transform (mid-side variants, reverse, ...) is undone.

    NOTE(review): ``mp={}`` is a mutable default argument; it is only read
    here, never mutated, so it is harmless — but a sentinel default would be
    safer if this signature is ever revisited.

    Args:
        spec: (2, bins, frames) complex spectrogram.
        hop_length: ISTFT hop length.
        mp: Model parameters object exposing ``param``.
        band: Band index used to look up the v5.1 channel-conversion mode.
        is_v51_model: Selects the v5.1 (per-band) inverse conversion path.

    Returns:
        Fortran-ordered (2, samples) waveform.
    """
    spec_left = np.asfortranarray(spec[0])
    spec_right = np.asfortranarray(spec[1])

    wave_left = librosa.istft(spec_left, hop_length=hop_length)
    wave_right = librosa.istft(spec_right, hop_length=hop_length)

    if is_v51_model:
        cc = mp.param["band"][band].get("convert_channels")
        # Each branch inverts the matching forward transform in convert_channels().
        if "mid_side_c" == cc:
            return np.asfortranarray([np.subtract(wave_left / 1.0625, wave_right / 4.25), np.add(wave_right / 1.0625, wave_left / 4.25)])
        elif "mid_side" == cc:
            return np.asfortranarray([np.add(wave_left, wave_right / 2), np.subtract(wave_left, wave_right / 2)])
        elif "stereo_n" == cc:
            return np.asfortranarray([np.subtract(wave_left, wave_right * 0.25), np.subtract(wave_right, wave_left * 0.25)])
    else:
        # Legacy models: invert the waveform-domain transform applied pre-STFT.
        if mp.param["reverse"]:
            return np.asfortranarray([np.flip(wave_left), np.flip(wave_right)])
        elif mp.param["mid_side"]:
            return np.asfortranarray([np.add(wave_left, wave_right / 2), np.subtract(wave_left, wave_right / 2)])
        elif mp.param["mid_side_b2"]:
            return np.asfortranarray([np.add(wave_right / 1.25, 0.4 * wave_left), np.subtract(wave_left / 1.25, 0.4 * wave_right)])

    # No conversion configured: plain stereo.
    return np.asfortranarray([wave_left, wave_right])
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
def cmb_spectrogram_to_wave(spec_m, mp, extra_bins_h=None, extra_bins=None, is_v51_model=False):
    """Reconstruct a waveform from a combined multi-band spectrogram.

    The inverse of combine_spectrograms(): each band's rows are extracted,
    band-pass filtered, converted to a waveform, resampled up to the next
    band's sample rate, and summed.

    Args:
        spec_m: Combined (2, bins, frames) spectrogram.
        mp: Model parameters object exposing ``param`` (per-band crops,
            filters, hop lengths, sample rates).
        extra_bins_h / extra_bins: Optional high-end bins to splice back into
            the top band (the --high_end_process bypass).
        is_v51_model: Selects mask-based filtering instead of the legacy loop.

    Returns:
        The reconstructed (2, samples) waveform at the top band's sample rate.
    """
    bands_n = len(mp.param["band"])
    offset = 0

    for d in range(1, bands_n + 1):
        bp = mp.param["band"][d]
        # Re-expand this band's rows into a full-height spectrogram.
        spec_s = np.zeros(shape=(2, bp["n_fft"] // 2 + 1, spec_m.shape[2]), dtype=complex)
        h = bp["crop_stop"] - bp["crop_start"]
        spec_s[:, bp["crop_start"] : bp["crop_stop"], :] = spec_m[:, offset : offset + h, :]

        offset += h
        if d == bands_n:  # higher
            if extra_bins_h:  # if --high_end_process bypass
                max_bin = bp["n_fft"] // 2
                spec_s[:, max_bin - extra_bins_h : max_bin, :] = extra_bins[:, :extra_bins_h, :]
            if bp["hpf_start"] > 0:
                if is_v51_model:
                    spec_s *= get_hp_filter_mask(spec_s.shape[1], bp["hpf_start"], bp["hpf_stop"] - 1)
                else:
                    spec_s = fft_hp_filter(spec_s, bp["hpf_start"], bp["hpf_stop"] - 1)
            if bands_n == 1:
                wave = spectrogram_to_wave(spec_s, bp["hl"], mp, d, is_v51_model)
            else:
                # Add the top band onto the accumulated lower bands.
                wave = np.add(wave, spectrogram_to_wave(spec_s, bp["hl"], mp, d, is_v51_model))
        else:
            # Lower/mid bands get resampled up to the NEXT band's rate before summing.
            sr = mp.param["band"][d + 1]["sr"]
            if d == 1:  # lower
                if is_v51_model:
                    spec_s *= get_lp_filter_mask(spec_s.shape[1], bp["lpf_start"], bp["lpf_stop"])
                else:
                    spec_s = fft_lp_filter(spec_s, bp["lpf_start"], bp["lpf_stop"])

                try:
                    wave = librosa.resample(spectrogram_to_wave(spec_s, bp["hl"], mp, d, is_v51_model), orig_sr=bp["sr"], target_sr=sr, res_type=wav_resolution)
                except ValueError as e:
                    # NOTE(review): on failure `wave` keeps its previous value
                    # (or stays unbound on the first band) — diagnostic only.
                    print(f"Error during resampling: {e}")
                    print(f"Spec_s shape: {spec_s.shape}, SR: {sr}, Res type: {wav_resolution}")

            else:  # mid
                if is_v51_model:
                    spec_s *= get_hp_filter_mask(spec_s.shape[1], bp["hpf_start"], bp["hpf_stop"] - 1)
                    spec_s *= get_lp_filter_mask(spec_s.shape[1], bp["lpf_start"], bp["lpf_stop"])
                else:
                    spec_s = fft_hp_filter(spec_s, bp["hpf_start"], bp["hpf_stop"] - 1)
                    spec_s = fft_lp_filter(spec_s, bp["lpf_start"], bp["lpf_stop"])

                wave2 = np.add(wave, spectrogram_to_wave(spec_s, bp["hl"], mp, d, is_v51_model))

                try:
                    wave = librosa.resample(wave2, orig_sr=bp["sr"], target_sr=sr, res_type=wav_resolution)
                except ValueError as e:
                    print(f"Error during resampling: {e}")
                    print(f"Spec_s shape: {spec_s.shape}, SR: {sr}, Res type: {wav_resolution}")

    return wave
|
| 396 |
+
|
| 397 |
+
|
| 398 |
+
def get_lp_filter_mask(n_bins, bin_start, bin_stop):
    """Build an (n_bins, 1) low-pass gain mask.

    Unity below ``bin_start``, a linear fade from 1 to 0 across
    [bin_start, bin_stop], and zero above ``bin_stop``.
    """
    passband = np.ones((bin_start - 1, 1))
    fade = np.linspace(1, 0, bin_stop - bin_start + 1)[:, None]
    stopband = np.zeros((n_bins - bin_stop, 1))
    return np.concatenate([passband, fade, stopband], axis=0)
|
| 402 |
+
|
| 403 |
+
|
| 404 |
+
def get_hp_filter_mask(n_bins, bin_start, bin_stop):
    """Build an (n_bins, 1) high-pass gain mask.

    Zero up to ``bin_stop``, a linear rise from 0 to 1 across the
    transition band, and unity above ``bin_start``. Note ``bin_start`` is
    expected to be greater than ``bin_stop``.
    """
    stopband = np.zeros((bin_stop + 1, 1))
    rise = np.linspace(0, 1, 1 + bin_start - bin_stop)[:, None]
    passband = np.ones((n_bins - bin_start - 2, 1))
    return np.concatenate([stopband, rise, passband], axis=0)
|
| 408 |
+
|
| 409 |
+
|
| 410 |
+
def fft_lp_filter(spec, bin_start, bin_stop):
    """In-place low-pass: linearly ramp gain 1 -> 0 over [bin_start, bin_stop)
    and zero out every bin from ``bin_stop`` upward.

    Returns the mutated spectrogram.
    """
    g = 1.0
    for b in range(bin_start, bin_stop):
        g -= 1 / (bin_stop - bin_start)
        spec[:, b, :] *= g

    spec[:, bin_stop:, :] *= 0

    return spec
|
| 419 |
+
|
| 420 |
+
|
| 421 |
+
def fft_hp_filter(spec, bin_start, bin_stop):
    """In-place high-pass: linearly ramp gain 1 -> 0 walking DOWN from
    ``bin_start`` to ``bin_stop``, then zero bins 0..bin_stop inclusive.

    Returns the mutated spectrogram.
    """
    g = 1.0
    for b in range(bin_start, bin_stop, -1):
        g -= 1 / (bin_start - bin_stop)
        spec[:, b, :] *= g

    spec[:, 0 : bin_stop + 1, :] *= 0

    return spec
|
| 430 |
+
|
| 431 |
+
|
| 432 |
+
def spectrogram_to_wave_old(spec, hop_length=1024):
    """Inverse-STFT for a mono (2-D) or stereo (3-D) spectrogram."""
    if spec.ndim == 2:
        wave = librosa.istft(spec, hop_length=hop_length)
    elif spec.ndim == 3:
        # Stereo: invert each channel independently, then restack.
        per_channel = [librosa.istft(np.asfortranarray(ch), hop_length=hop_length) for ch in (spec[0], spec[1])]
        wave = np.asfortranarray(per_channel)

    return wave
|
| 444 |
+
|
| 445 |
+
|
| 446 |
+
def wave_to_spectrogram_old(wave, hop_length, n_fft):
    """STFT each channel of a stereo waveform into a (2, bins, frames) array."""
    per_channel = [librosa.stft(np.asfortranarray(ch), n_fft=n_fft, hop_length=hop_length) for ch in (wave[0], wave[1])]
    return np.asfortranarray(per_channel)
|
| 456 |
+
|
| 457 |
+
|
| 458 |
+
def mirroring(a, spec_m, input_high_end, mp):
    """Synthesize high-end content by mirroring bins below the pre-filter cutoff.

    Mode "mirroring" reuses the input's phase with the mirrored magnitude;
    "mirroring2" boosts the product of the mirror and the input. In both
    modes, whichever of the input and the candidate has the SMALLER
    magnitude is kept per bin. Other modes return None.
    """
    cutoff = mp.param["pre_filter_start"] - 10

    if a == "mirroring":
        reflected = np.flip(np.abs(spec_m[:, cutoff - input_high_end.shape[1] : cutoff, :]), 1)
        candidate = reflected * np.exp(1.0j * np.angle(input_high_end))
        return np.where(np.abs(input_high_end) <= np.abs(candidate), input_high_end, candidate)

    if a == "mirroring2":
        reflected = np.flip(np.abs(spec_m[:, cutoff - input_high_end.shape[1] : cutoff, :]), 1)
        candidate = reflected * (input_high_end * 1.7)
        return np.where(np.abs(input_high_end) <= np.abs(candidate), input_high_end, candidate)
|
| 470 |
+
|
| 471 |
+
|
| 472 |
+
def adjust_aggr(mask, is_non_accom_stem, aggressiveness):
    """Shape a separation mask by raising it to an aggressiveness exponent.

    Bins below ``split_bin`` get a softened exponent (aggr / 3); bins above
    get the full exponent. Per-channel corrections are applied when present.
    The mask is modified in place and returned.
    """
    aggr = aggressiveness["value"] * 2

    # Zero aggressiveness: leave the mask untouched.
    if aggr == 0:
        return mask

    if is_non_accom_stem:
        aggr = 1 - aggr

    if np.any(aggr > 10) or np.any(aggr < -10):
        print(f"Warning: Extreme aggressiveness values detected: {aggr}")

    aggr = [aggr, aggr]

    correction = aggressiveness["aggr_correction"]
    if correction is not None:
        aggr[0] += correction["left"]
        aggr[1] += correction["right"]

    split = aggressiveness["split_bin"]
    for ch in range(2):
        mask[ch, :split] = np.power(mask[ch, :split], 1 + aggr[ch] / 3)
        mask[ch, split:] = np.power(mask[ch, split:], 1 + aggr[ch])

    return mask
|
| 493 |
+
|
| 494 |
+
|
| 495 |
+
def stft(wave, nfft, hl):
    """Two-channel STFT helper: returns a (2, bins, frames) complex array."""
    channels = [np.asfortranarray(ch) for ch in (wave[0], wave[1])]
    specs = [librosa.stft(ch, n_fft=nfft, hop_length=hl) for ch in channels]
    return np.asfortranarray(specs)
|
| 503 |
+
|
| 504 |
+
|
| 505 |
+
def istft(spec, hl):
    """Two-channel inverse-STFT helper: returns a (2, samples) array."""
    channels = [np.asfortranarray(ch) for ch in (spec[0], spec[1])]
    waves = [librosa.istft(ch, hop_length=hl) for ch in channels]
    return np.asfortranarray(waves)
|
| 513 |
+
|
| 514 |
+
|
| 515 |
+
def spec_effects(wave, algorithm="Default", value=None):
    """Blend two stereo waveforms using a spectral or linear algorithm.

    Args:
        wave: Pair of (2, samples) waveforms, indexed [0] and [1].
        algorithm: "Min_Mag" / "Max_Mag" (per-bin magnitude selection),
            "Default" (linear crossfade weighted by ``value``), or
            "Invert_p" (phase-aware subtraction).
        value: Blend weight, used only by "Default".

    Returns:
        The blended waveform.
        NOTE(review): an unrecognized algorithm returns the input `wave`
        unchanged — confirm that is intended rather than an error.
    """
    # Diagnostic only: flag non-finite input, which propagates through the STFT.
    if np.isnan(wave).any() or np.isinf(wave).any():
        print(f"Warning: Detected NaN or infinite values in wave input. Shape: {wave.shape}")

    spec = [stft(wave[0], 2048, 1024), stft(wave[1], 2048, 1024)]
    if algorithm == "Min_Mag":
        # Keep whichever source has the smaller magnitude per bin.
        v_spec_m = np.where(np.abs(spec[1]) <= np.abs(spec[0]), spec[1], spec[0])
        wave = istft(v_spec_m, 1024)
    elif algorithm == "Max_Mag":
        # Keep whichever source has the larger magnitude per bin.
        v_spec_m = np.where(np.abs(spec[1]) >= np.abs(spec[0]), spec[1], spec[0])
        wave = istft(v_spec_m, 1024)
    elif algorithm == "Default":
        # Simple time-domain crossfade; `value` weights the second source.
        wave = (wave[1] * value) + (wave[0] * (1 - value))
    elif algorithm == "Invert_p":
        # Subtract the max-magnitude envelope (with source-0 phase) from source 1.
        X_mag = np.abs(spec[0])
        y_mag = np.abs(spec[1])
        max_mag = np.where(X_mag >= y_mag, X_mag, y_mag)
        v_spec = spec[1] - max_mag * np.exp(1.0j * np.angle(spec[0]))
        wave = istft(v_spec, 1024)

    return wave
|
| 536 |
+
|
| 537 |
+
|
| 538 |
+
def spectrogram_to_wave_no_mp(spec, n_fft=2048, hop_length=1024):
    """Inverse-STFT without model parameters; mono output is duplicated to stereo."""
    wave = librosa.istft(spec, n_fft=n_fft, hop_length=hop_length)
    return np.asfortranarray([wave, wave]) if wave.ndim == 1 else wave
|
| 545 |
+
|
| 546 |
+
|
| 547 |
+
def wave_to_spectrogram_no_mp(wave):
    """Fixed-parameter STFT (n_fft=2048, hop=1024); a 1-D result is duplicated to two channels."""
    spec = librosa.stft(wave, n_fft=2048, hop_length=1024)

    if spec.ndim == 1:
        spec = np.asfortranarray([spec, spec])

    return spec
|
| 555 |
+
|
| 556 |
+
|
| 557 |
+
def invert_audio(specs, invert_p=True):
    """Derive the complement of ``specs[1]`` relative to ``specs[0]``.

    Both spectrograms are first trimmed to the shorter frame count. With
    ``invert_p`` (default), a phase-aware subtraction against the per-bin
    max magnitude is used; otherwise the second spectrogram is softened via
    reduce_vocal_aggressively() and subtracted directly (mutating specs[1]).
    """
    ln = min(specs[0].shape[2], specs[1].shape[2])
    specs[0] = specs[0][:, :, :ln]
    specs[1] = specs[1][:, :, :ln]

    if invert_p:
        mag_0 = np.abs(specs[0])
        mag_1 = np.abs(specs[1])
        max_mag = np.where(mag_0 >= mag_1, mag_0, mag_1)
        return specs[1] - max_mag * np.exp(1.0j * np.angle(specs[0]))

    specs[1] = reduce_vocal_aggressively(specs[0], specs[1], 0.2)
    return specs[0] - specs[1]
|
| 573 |
+
|
| 574 |
+
|
| 575 |
+
def invert_stem(mixture, stem):
    """Subtract a stem from the mixture in the spectral domain.

    Returns the negated, transposed residual waveform.
    """
    mix_spec = wave_to_spectrogram_no_mp(mixture)
    stem_spec = wave_to_spectrogram_no_mp(stem)
    residual = spectrogram_to_wave_no_mp(invert_audio([mix_spec, stem_spec]))
    return -residual.T
|
| 581 |
+
|
| 582 |
+
|
| 583 |
+
def ensembling(a, inputs, is_wavs=False):
    """Combine spectrograms (or waveforms) by per-element magnitude selection.

    Args:
        a: Algorithm identifier — MIN_SPEC keeps the smaller-magnitude value
            per element, MAX_SPEC keeps the larger.
        inputs: List of arrays; spectrograms are (2, bins, frames), waveforms
            (when ``is_wavs``) are (2, samples). Entries are trimmed in place
            to the shortest common length as they are folded in.
        is_wavs: Whether ``inputs`` holds waveforms instead of spectrograms.

    Returns:
        The combined array. A single-element list now returns its only entry
        (previously this raised NameError because the accumulator was only
        assigned inside the loop).
    """
    # Initialize the accumulator before the loop so len(inputs) == 1 works.
    # Also avoids shadowing the builtin `input`.
    combined = inputs[0]

    for i in range(1, len(inputs)):
        # Trim both operands to the shorter length on the time axis.
        if is_wavs:
            ln = min([combined.shape[1], inputs[i].shape[1]])
            combined = combined[:, :ln]
            inputs[i] = inputs[i][:, :ln]
        else:
            ln = min([combined.shape[2], inputs[i].shape[2]])
            combined = combined[:, :, :ln]
            inputs[i] = inputs[i][:, :, :ln]

        if MIN_SPEC == a:
            combined = np.where(np.abs(inputs[i]) <= np.abs(combined), inputs[i], combined)
        if MAX_SPEC == a:
            combined = np.where(np.abs(inputs[i]) >= np.abs(combined), inputs[i], combined)

    return combined
|
| 609 |
+
|
| 610 |
+
|
| 611 |
+
def ensemble_for_align(waves):
    """Min-magnitude spectral ensemble of waveforms, used during alignment.

    Each waveform is transformed to a spectrogram, combined with MIN_SPEC,
    inverted back to a waveform, and padded/trimmed to match ``waves[1]``.
    """
    spectra = [wave_to_spectrogram_no_mp(w.T) for w in waves]
    wav_aligned = spectrogram_to_wave_no_mp(ensembling(MIN_SPEC, spectra)).T
    return match_array_shapes(wav_aligned, waves[1], is_swap=True)
|
| 623 |
+
|
| 624 |
+
|
| 625 |
+
def ensemble_inputs(audio_input, algorithm, is_normalization, wav_type_set, save_path, is_wave=False, is_array=False):
    """Load several audio files, ensemble them, and write the result to disk.

    Args:
        audio_input: List of input file paths.
        algorithm: Ensemble algorithm (AVERAGE, MIN_SPEC, MAX_SPEC, ...).
        is_normalization: Forwarded to normalize() as its second argument.
            NOTE(review): normalize()'s second parameter is `max_peak`
            (a float), but a boolean is passed here — confirm intent.
        wav_type_set: Soundfile subtype for the output WAV.
        save_path: Output file path.
        is_wave: Ensemble in the waveform domain instead of the spectral domain.
        is_array: Unused here; kept for interface compatibility.
    """
    wavs_ = []

    if algorithm == AVERAGE:
        output = average_audio(audio_input)
        samplerate = 44100
    else:
        specs = []

        for i in range(len(audio_input)):
            # Load at a fixed 44.1 kHz so all inputs are comparable.
            wave, samplerate = librosa.load(audio_input[i], mono=False, sr=44100)
            wavs_.append(wave)
            spec = wave if is_wave else wave_to_spectrogram_no_mp(wave)
            specs.append(spec)

        # Pad the ensemble result up to the longest input's length.
        wave_shapes = [w.shape[1] for w in wavs_]
        target_shape = wavs_[wave_shapes.index(max(wave_shapes))]

        if is_wave:
            output = ensembling(algorithm, specs, is_wavs=True)
        else:
            output = spectrogram_to_wave_no_mp(ensembling(algorithm, specs))

        output = to_shape(output, target_shape.shape)

    sf.write(save_path, normalize(output.T, is_normalization), samplerate, subtype=wav_type_set)
|
| 652 |
+
|
| 653 |
+
|
| 654 |
+
def to_shape(x, target_shape):
    """Zero-pad ``x`` at the end of each axis so its shape matches ``target_shape``."""
    pad_widths = tuple((0, target - current) for current, target in zip(x.shape, target_shape))
    return np.pad(x, pad_widths, mode="constant")
|
| 662 |
+
|
| 663 |
+
|
| 664 |
+
def to_shape_minimize(x: np.ndarray, target_shape):
    """Zero-pad ``x`` at the trailing edge of each axis up to ``target_shape``.

    NOTE: behaviorally identical to to_shape(); kept for interface
    compatibility with existing callers.
    """
    pad_spec = []
    for current_dim, target_dim in zip(x.shape, target_shape):
        pad_spec.append((0, target_dim - current_dim))

    return np.pad(x, tuple(pad_spec), mode="constant")
|
| 673 |
+
|
| 674 |
+
|
| 675 |
+
def detect_leading_silence(audio, sr, silence_threshold=0.007, frame_length=1024):
    """
    Detect silence at the beginning of an audio signal.

    :param audio: np.array, audio signal
    :param sr: int, sample rate
    :param silence_threshold: float, magnitude threshold below which is considered silence
    :param frame_length: int, the number of samples to consider for each check

    :return: float, duration of the leading silence in milliseconds
    """
    if len(audio.shape) == 2:
        # Stereo: inspect only the channel carrying more energy.
        loudest = np.argmax(np.sum(np.abs(audio), axis=1))
        audio = audio[loudest]

    # Scan frame by frame until a frame's peak exceeds the threshold.
    total = len(audio)
    for start in range(0, total, frame_length):
        if np.max(np.abs(audio[start : start + frame_length])) > silence_threshold:
            return (start / sr) * 1000

    # Entirely silent signal.
    return (total / sr) * 1000
|
| 697 |
+
|
| 698 |
+
|
| 699 |
+
def adjust_leading_silence(target_audio, reference_audio, silence_threshold=0.01, frame_length=1024):
    """
    Adjust the leading silence of the target_audio to match the leading silence of the reference_audio.

    :param target_audio: np.array, audio signal that will have its silence adjusted
    :param reference_audio: np.array, audio signal used as a reference
    :param silence_threshold: float, magnitude threshold below which is considered silence
    :param frame_length: int, the number of samples to consider for each check

    :return: np.array, target_audio adjusted to have the same leading silence as reference_audio
    """

    def find_silence_end(audio):
        # Return the sample index of the first frame whose peak exceeds the
        # threshold (i.e. where the leading silence ends).
        if len(audio.shape) == 2:
            # If stereo, pick the channel with more energy to determine the silence
            channel = np.argmax(np.sum(np.abs(audio), axis=1))
            audio_mono = audio[channel]
        else:
            audio_mono = audio

        for i in range(0, len(audio_mono), frame_length):
            if np.max(np.abs(audio_mono[i : i + frame_length])) > silence_threshold:
                return i
        return len(audio_mono)

    ref_silence_end = find_silence_end(reference_audio)
    target_silence_end = find_silence_end(target_audio)
    # Positive: target needs more leading silence; negative: it has too much.
    silence_difference = ref_silence_end - target_silence_end

    # Debug print of the difference in milliseconds; assumes 44.1 kHz audio.
    # NOTE(review): failures are deliberately swallowed (diagnostic only).
    try:
        ref_silence_end_p = (ref_silence_end / 44100) * 1000
        target_silence_end_p = (target_silence_end / 44100) * 1000
        silence_difference_p = ref_silence_end_p - target_silence_end_p
        print("silence_difference: ", silence_difference_p)
    except Exception as e:
        pass

    if silence_difference > 0:  # Add silence to target_audio
        if len(target_audio.shape) == 2:  # stereo
            silence_to_add = np.zeros((target_audio.shape[0], silence_difference))
        else:  # mono
            silence_to_add = np.zeros(silence_difference)
        return np.hstack((silence_to_add, target_audio))
    elif silence_difference < 0:  # Remove silence from target_audio
        if len(target_audio.shape) == 2:  # stereo
            return target_audio[:, -silence_difference:]
        else:  # mono
            return target_audio[-silence_difference:]
    else:  # No adjustment needed
        return target_audio
|
| 750 |
+
|
| 751 |
+
|
| 752 |
+
def match_array_shapes(array_1: np.ndarray, array_2: np.ndarray, is_swap=False):
    """Trim or zero-pad array_1 along its second axis so it matches array_2.

    With is_swap=True both inputs are transposed first (and the result transposed
    back), so the length adjustment happens along axis 0 instead.
    """
    if is_swap:
        array_1, array_2 = array_1.T, array_2.T

    target_len = array_2.shape[1]
    current_len = array_1.shape[1]

    if current_len > target_len:
        # Too long: drop the surplus columns.
        array_1 = array_1[:, :target_len]
    elif current_len < target_len:
        # Too short: append zero columns.
        array_1 = np.pad(array_1, ((0, 0), (0, target_len - current_len)), "constant", constant_values=0)

    if is_swap:
        array_1, array_2 = array_1.T, array_2.T

    return array_1
|
| 770 |
+
|
| 771 |
+
|
| 772 |
+
def match_mono_array_shapes(array_1: np.ndarray, array_2: np.ndarray):
    """Trim or zero-pad 1-D array_1 so its length matches array_2."""
    target_len = len(array_2)
    if len(array_1) > target_len:
        return array_1[:target_len]
    if len(array_1) < target_len:
        return np.pad(array_1, (0, target_len - len(array_1)), "constant", constant_values=0)
    return array_1
|
| 781 |
+
|
| 782 |
+
|
| 783 |
+
def change_pitch_semitones(y, sr, semitone_shift):
    """Pitch-shift by resampling: each channel of `y` is resampled by a factor of
    2**(semitone_shift / 12), and the matching new sample rate is returned.

    :return: tuple of (np.array of shifted channels, new sample rate)
    """
    # One semitone corresponds to a 2**(1/12) frequency ratio.
    factor = 2 ** (semitone_shift / 12)
    shifted_channels = [librosa.resample(channel, orig_sr=sr, target_sr=sr * factor, res_type=wav_resolution_float_resampling) for channel in y]
    return np.array(shifted_channels), sr * factor
|
| 791 |
+
|
| 792 |
+
|
| 793 |
+
def augment_audio(export_path, audio_file, rate, is_normalization, wav_type_set, save_format=None, is_pitch=False, is_time_correction=True):
    """
    Pitch-shift or time-stretch an audio file and write the result to export_path.

    :param export_path: str, destination path for the processed audio
    :param audio_file: str, path of the audio file to process
    :param rate: float, semitone shift (is_pitch=True) or stretch rate
    :param is_normalization: bool, forwarded to normalize() before writing
    :param wav_type_set: str, soundfile subtype for the output
    :param save_format: optional callable invoked with export_path after writing
    :param is_pitch: bool, pitch-shift instead of time-stretch
    :param is_time_correction: bool, when False use plain resampling (no time correction)
    """

    wav, sr = librosa.load(audio_file, sr=44100, mono=False)

    if wav.ndim == 1:
        # Duplicate mono input into two identical channels.
        wav = np.asfortranarray([wav, wav])

    if not is_time_correction:
        wav_mix = change_pitch_semitones(wav, 44100, semitone_shift=-rate)[0]
    else:
        if is_pitch:
            wav_1 = pyrb.pitch_shift(wav[0], sr, rate, rbargs=None)
            wav_2 = pyrb.pitch_shift(wav[1], sr, rate, rbargs=None)
        else:
            wav_1 = pyrb.time_stretch(wav[0], sr, rate, rbargs=None)
            wav_2 = pyrb.time_stretch(wav[1], sr, rate, rbargs=None)

        # rubberband may return slightly different lengths per channel; pad the shorter one.
        if wav_1.shape > wav_2.shape:
            wav_2 = to_shape(wav_2, wav_1.shape)
        if wav_1.shape < wav_2.shape:
            wav_1 = to_shape(wav_1, wav_2.shape)

        wav_mix = np.asfortranarray([wav_1, wav_2])

    sf.write(export_path, normalize(wav_mix.T, is_normalization), sr, subtype=wav_type_set)
    # BUG FIX: save_format defaults to None; calling it unconditionally raised TypeError.
    if save_format is not None:
        save_format(export_path)
|
| 819 |
+
|
| 820 |
+
|
| 821 |
+
def average_audio(audio):
    """Load every file in `audio` at 44.1 kHz and return their sample-wise mean,
    zero-padding shorter tracks up to the longest one via to_shape()."""
    loaded = []
    lengths = []

    for path in audio:
        data = librosa.load(path, sr=44100, mono=False)[0]
        loaded.append(data)
        lengths.append(data.shape[1])

    # The longest track defines the reference shape.
    longest_idx = lengths.index(max(lengths))
    reference = loaded.pop(longest_idx)

    stacked = [reference]
    for other in loaded:
        stacked.append(to_shape(other, reference.shape))

    return sum(stacked) / len(audio)
|
| 845 |
+
|
| 846 |
+
|
| 847 |
+
def average_dual_sources(wav_1, wav_2, value):
    """Weighted blend of two stems: value * wav_1 + (1 - value) * wav_2.

    Shapes are reconciled first (the smaller array is padded via to_shape).
    """
    if wav_1.shape > wav_2.shape:
        wav_2 = to_shape(wav_2, wav_1.shape)
    if wav_1.shape < wav_2.shape:
        wav_1 = to_shape(wav_1, wav_2.shape)

    return wav_1 * value + wav_2 * (1 - value)
|
| 857 |
+
|
| 858 |
+
|
| 859 |
+
def reshape_sources(wav_1: np.ndarray, wav_2: np.ndarray):
    """Return wav_2 padded (when smaller than wav_1) or trimmed so it matches
    wav_1's length along axis 1."""
    if wav_1.shape > wav_2.shape:
        wav_2 = to_shape(wav_2, wav_1.shape)
    if wav_1.shape < wav_2.shape:
        wav_2 = wav_2[:, : min(wav_1.shape[1], wav_2.shape[1])]

    # Final trim to the common length; only wav_2 is returned.
    common = min(wav_1.shape[1], wav_2.shape[1])
    wav_1 = wav_1[:, :common]
    wav_2 = wav_2[:, :common]

    return wav_2
|
| 872 |
+
|
| 873 |
+
|
| 874 |
+
def reshape_sources_ref(wav_1_shape, wav_2: np.ndarray):
    """Pad wav_2 out to wav_1_shape when the reference shape is the larger one."""
    needs_padding = wav_1_shape > wav_2.shape
    return to_shape(wav_2, wav_1_shape) if needs_padding else wav_2
|
| 880 |
+
|
| 881 |
+
|
| 882 |
+
def combine_arrarys(audio_sources, is_swap=False):
    """Sum a list of audio arrays into one array sized like the largest source.

    Each source is trimmed/padded with match_array_shapes before accumulation.
    """
    combined = np.zeros_like(max(audio_sources, key=np.size))

    for source in audio_sources:
        combined += match_array_shapes(source, combined, is_swap=is_swap)

    return combined
|
| 890 |
+
|
| 891 |
+
|
| 892 |
+
def combine_audio(paths: list, audio_file_base=None, wav_type_set="FLOAT", save_format=None):
    """
    Sum the audio files in `paths` and write the mix to "<audio_file_base>_combined.wav".

    :param paths: list of audio file paths to combine
    :param audio_file_base: base name used to build the output path
    :param wav_type_set: soundfile subtype for the output
    :param save_format: optional callable invoked with the saved path
    """
    source = combine_arrarys([load_audio(i) for i in paths])
    save_path = f"{audio_file_base}_combined.wav"
    sf.write(save_path, source.T, 44100, subtype=wav_type_set)
    # BUG FIX: save_format defaults to None; guard before calling it.
    if save_format is not None:
        save_format(save_path)
|
| 898 |
+
|
| 899 |
+
|
| 900 |
+
def reduce_mix_bv(inst_source, voc_source, reduction_rate=0.9):
    """Mix vocals over an attenuated instrumental (volume scaled by 1 - reduction_rate)."""
    attenuated_inst = inst_source * (1 - reduction_rate)
    return combine_arrarys([attenuated_inst, voc_source], is_swap=True)
|
| 907 |
+
|
| 908 |
+
|
| 909 |
+
def organize_inputs(inputs):
    """Classify input paths by filename convention into target/reference/reverb/inst slots.

    Unmatched paths are ignored; missing slots stay None.
    """
    organized = {"target": None, "reference": None, "reverb": None, "inst": None}

    for path in inputs:
        if path.endswith("_(Vocals).wav"):
            organized["reference"] = path
        elif "_RVC_" in path:
            organized["target"] = path
        elif path.endswith("reverbed_stem.wav"):
            organized["reverb"] = path
        elif path.endswith("_(Instrumental).wav"):
            organized["inst"] = path

    return organized
|
| 923 |
+
|
| 924 |
+
|
| 925 |
+
def check_if_phase_inverted(wav1, wav2, is_mono=False):
    """Heuristic phase check: the two signals are considered phase-inverted when
    the correlation coefficient of their first 1000 samples is negative."""
    if not is_mono:
        # Collapse stereo to mono before correlating.
        wav1 = np.mean(wav1, axis=0)
        wav2 = np.mean(wav2, axis=0)

    correlation_matrix = np.corrcoef(wav1[:1000], wav2[:1000])
    return correlation_matrix[0, 1] < 0
|
| 935 |
+
|
| 936 |
+
|
| 937 |
+
def align_audio(
    file1,
    file2,
    file2_aligned,
    file_subtracted,
    wav_type_set,
    is_save_aligned,
    command_Text,
    save_format,
    align_window: list,
    align_intro_val: list,
    db_analysis: tuple,
    set_progress_bar,
    phase_option,
    phase_shifts,
    is_match_silence,
    is_spec_match,
):
    """Align file2 against file1, subtract it, and write the residual.

    Loads both files at 44.1 kHz, optionally flips file2's phase and matches its
    leading silence, estimates candidate sample offsets via cross-correlation at
    several positions (align_intro_val), builds one subtraction result per unique
    candidate (optionally refined per-window by time_correction when align_window
    is non-empty), ensembles the results, and writes `file_subtracted` (and
    `file2_aligned` when is_save_aligned). `command_Text` / `set_progress_bar`
    are UI callbacks; `save_format` post-processes each written file.
    """

    # Module-level progress counter shared with the nested progress_bar closure.
    global progress_value
    progress_value = 0
    is_mono = False

    def get_diff(a, b):
        # Lag (in samples) of the cross-correlation peak between a and b.
        corr = np.correlate(a, b, "full")
        diff = corr.argmax() - (b.shape[0] - 1)

        return diff

    def progress_bar(length):
        global progress_value
        progress_value += 1

        # Clamp so the reported progress never exceeds 90%.
        if (0.90 / length * progress_value) >= 0.9:
            length = progress_value + 1

        set_progress_bar(0.1, (0.9 / length * progress_value))

    # read tracks
    # On macOS, mp3 durations are read via audioread first to avoid over-reading.
    if file1.endswith(".mp3") and is_macos:
        length1 = rerun_mp3(file1)
        wav1, sr1 = librosa.load(file1, duration=length1, sr=44100, mono=False)
    else:
        wav1, sr1 = librosa.load(file1, sr=44100, mono=False)

    if file2.endswith(".mp3") and is_macos:
        length2 = rerun_mp3(file2)
        wav2, sr2 = librosa.load(file2, duration=length2, sr=44100, mono=False)
    else:
        wav2, sr2 = librosa.load(file2, sr=44100, mono=False)

    # Work in mono only when both inputs are mono; otherwise duplicate the mono one.
    if wav1.ndim == 1 and wav2.ndim == 1:
        is_mono = True
    elif wav1.ndim == 1:
        wav1 = np.asfortranarray([wav1, wav1])
    elif wav2.ndim == 1:
        wav2 = np.asfortranarray([wav2, wav2])

    # Check if phase is inverted
    if phase_option == AUTO_PHASE:
        if check_if_phase_inverted(wav1, wav2, is_mono=is_mono):
            wav2 = -wav2
    elif phase_option == POSITIVE_PHASE:
        wav2 = +wav2
    elif phase_option == NEGATIVE_PHASE:
        wav2 = -wav2

    if is_match_silence:
        wav2 = adjust_leading_silence(wav2, wav1)

    wav1_length = int(librosa.get_duration(y=wav1, sr=44100))
    wav2_length = int(librosa.get_duration(y=wav2, sr=44100))

    if not is_mono:
        # Switch to (samples, channels) layout for the windowed processing below.
        wav1 = wav1.transpose()
        wav2 = wav2.transpose()

    wav2_org = wav2.copy()

    command_Text("Processing files... \n")
    seconds_length = min(wav1_length, wav2_length)

    wav2_aligned_sources = []

    for sec_len in align_intro_val:
        # pick a position at 1 second in and get diff
        sec_seg = 1 if sec_len == 1 else int(seconds_length // sec_len)
        index = sr1 * sec_seg  # 1 second in, assuming sr1 = sr2 = 44100

        if is_mono:
            samp1, samp2 = wav1[index : index + sr1], wav2[index : index + sr1]
            diff = get_diff(samp1, samp2)
            # print(f"Estimated difference: {diff}\n")
        else:
            index = sr1 * sec_seg  # 1 second in, assuming sr1 = sr2 = 44100
            samp1, samp2 = wav1[index : index + sr1, 0], wav2[index : index + sr1, 0]
            samp1_r, samp2_r = wav1[index : index + sr1, 1], wav2[index : index + sr1, 1]
            diff, diff_r = get_diff(samp1, samp2), get_diff(samp1_r, samp2_r)
            # print(f"Estimated difference Left Channel: {diff}\nEstimated difference Right Channel: {diff_r}\n")

        # make aligned track 2
        # NOTE(review): only the left-channel `diff` is used for stereo; diff_r
        # is computed but ignored — confirm this is intentional.
        if diff > 0:
            zeros_to_append = np.zeros(diff) if is_mono else np.zeros((diff, 2))
            wav2_aligned = np.append(zeros_to_append, wav2_org, axis=0)
        elif diff < 0:
            wav2_aligned = wav2_org[-diff:]
        else:
            wav2_aligned = wav2_org
            # command_Text(f"Audio files already aligned.\n")

        # Keep only distinct alignment candidates.
        if not any(np.array_equal(wav2_aligned, source) for source in wav2_aligned_sources):
            wav2_aligned_sources.append(wav2_aligned)

    # print("Unique Sources: ", len(wav2_aligned_sources))

    unique_sources = len(wav2_aligned_sources)

    # Maps mean |residual| -> residual waveform; min key is the best subtraction.
    sub_mapper_big_mapper = {}

    for s in wav2_aligned_sources:
        wav2_aligned = match_mono_array_shapes(s, wav1) if is_mono else match_array_shapes(s, wav1, is_swap=True)

        if align_window:
            # Fine-grained, per-window alignment.
            wav_sub = time_correction(
                wav1, wav2_aligned, seconds_length, align_window=align_window, db_analysis=db_analysis, progress_bar=progress_bar, unique_sources=unique_sources, phase_shifts=phase_shifts
            )
            wav_sub_size = np.abs(wav_sub).mean()
            sub_mapper_big_mapper = {**sub_mapper_big_mapper, **{wav_sub_size: wav_sub}}
        else:
            # Whole-track subtraction swept over a range of dB adjustments.
            wav2_aligned = wav2_aligned * np.power(10, db_analysis[0] / 20)
            db_range = db_analysis[1]

            for db_adjustment in db_range:
                # Adjust the dB of track2
                s_adjusted = wav2_aligned * (10 ** (db_adjustment / 20))
                wav_sub = wav1 - s_adjusted
                wav_sub_size = np.abs(wav_sub).mean()
                sub_mapper_big_mapper = {**sub_mapper_big_mapper, **{wav_sub_size: wav_sub}}

    # print(sub_mapper_big_mapper.keys(), min(sub_mapper_big_mapper.keys()))

    sub_mapper_value_list = list(sub_mapper_big_mapper.values())

    if is_spec_match and len(sub_mapper_value_list) >= 2:
        # print("using spec ensemble with align")
        wav_sub = ensemble_for_align(list(sub_mapper_big_mapper.values()))
    else:
        # print("using linear ensemble with align")
        wav_sub = ensemble_wav(list(sub_mapper_big_mapper.values()))

    # print(f"Mix Mean: {np.abs(wav1).mean()}\nInst Mean: {np.abs(wav2).mean()}")
    # print('Final: ', np.abs(wav_sub).mean())
    wav_sub = np.clip(wav_sub, -1, +1)

    command_Text(f"Saving inverted track... ")

    if is_save_aligned or is_spec_match:
        # Reconstruct the aligned track from the chosen residual.
        wav1 = match_mono_array_shapes(wav1, wav_sub) if is_mono else match_array_shapes(wav1, wav_sub, is_swap=True)
        wav2_aligned = wav1 - wav_sub

        if is_spec_match:
            if wav1.ndim == 1 and wav2.ndim == 1:
                wav2_aligned = np.asfortranarray([wav2_aligned, wav2_aligned]).T
                wav1 = np.asfortranarray([wav1, wav1]).T

            wav2_aligned = ensemble_for_align([wav2_aligned, wav1])
            wav_sub = wav1 - wav2_aligned

        if is_save_aligned:
            sf.write(file2_aligned, wav2_aligned, sr1, subtype=wav_type_set)
            save_format(file2_aligned)

    sf.write(file_subtracted, wav_sub, sr1, subtype=wav_type_set)
    save_format(file_subtracted)
|
| 1112 |
+
|
| 1113 |
+
|
| 1114 |
+
def phase_shift_hilbert(signal, degree):
    """Shift the phase of `signal` by `degree` degrees via its analytic (Hilbert) representation."""
    analytic = hilbert(signal)
    theta = np.radians(degree)
    # Real part of the analytic signal rotated by -theta gives the phase-shifted signal.
    return np.cos(theta) * analytic.real - np.sin(theta) * analytic.imag
|
| 1117 |
+
|
| 1118 |
+
|
| 1119 |
+
def get_phase_shifted_tracks(track, phase_shift):
    """Return the track plus phase-rotated copies at multiples of `phase_shift`
    degrees (positive and negative), always including the plain and fully
    inverted versions."""
    if phase_shift == 180:
        return [track, -track]

    # Stop at 180 when the step divides it evenly; otherwise sweep up to 180 inclusive.
    stop = 180 - (180 % phase_shift) if 180 % phase_shift == 0 else 181

    variants = [track, -track]
    for degrees in range(phase_shift, stop, phase_shift):
        variants.append(phase_shift_hilbert(track, degrees))
        variants.append(phase_shift_hilbert(track, -degrees))

    return variants
|
| 1132 |
+
|
| 1133 |
+
|
| 1134 |
+
def time_correction(mix: np.ndarray, instrumental: np.ndarray, seconds_length, align_window, db_analysis, sr=44100, progress_bar=None, unique_sources=None, phase_shifts=NONE_P):
    # Function to align two tracks using cross-correlation
    """Windowed subtraction of `instrumental` from `mix` with per-window alignment.

    For each window size in `align_window` (seconds), the signals are processed
    in half-overlapping Hann windows; inside each window the instrumental is
    swept over dB adjustments (db_analysis) and optional phase rotations,
    cross-correlated against the mix, shifted, and the variant leaving the
    smallest residual is subtracted. The per-window-size residuals are then
    combined with ensemble_wav. Both inputs must share the same shape
    (mono 1-D, or (samples, channels)).
    """

    def align_tracks(track1, track2):
        # A dictionary to store each version of track2_shifted and its mean absolute value
        shifted_tracks = {}

        # Loop to adjust dB of track2
        track2 = track2 * np.power(10, db_analysis[0] / 20)
        db_range = db_analysis[1]

        # NOTE(review): 190 appears to be the "no phase shift" sentinel
        # (presumably the value of NONE_P) — confirm against the constants module.
        if phase_shifts == 190:
            track2_flipped = [track2]
        else:
            track2_flipped = get_phase_shifted_tracks(track2, phase_shifts)

        for db_adjustment in db_range:
            for t in track2_flipped:
                # Adjust the dB of track2
                track2_adjusted = t * (10 ** (db_adjustment / 20))
                corr = correlate(track1, track2_adjusted)
                delay = np.argmax(np.abs(corr)) - (len(track1) - 1)
                track2_shifted = np.roll(track2_adjusted, shift=delay)

                # Compute the mean absolute value of track2_shifted
                track2_shifted_sub = track1 - track2_shifted
                mean_abs_value = np.abs(track2_shifted_sub).mean()

                # Store track2_shifted and its mean absolute value in the dictionary
                shifted_tracks[mean_abs_value] = track2_shifted

        # Return the version of track2_shifted with the smallest mean absolute value
        return shifted_tracks[min(shifted_tracks.keys())]

    # Make sure the audio files have the same shape
    assert mix.shape == instrumental.shape, f"Audio files must have the same shape - Mix: {mix.shape}, Inst: {instrumental.shape}"

    seconds_length = seconds_length // 2

    # Maps mean |residual| -> residual, one entry per window size.
    sub_mapper = {}

    progress_update_interval = 120
    total_iterations = 0

    if len(align_window) > 2:
        progress_update_interval = 320

    # First pass: count iterations so progress_bar gets a stable total.
    for secs in align_window:
        step = secs / 2
        window_size = int(sr * secs)
        step_size = int(sr * step)

        if len(mix.shape) == 1:
            total_mono = (len(range(0, len(mix) - window_size, step_size)) // progress_update_interval) * unique_sources
            total_iterations += total_mono
        else:
            total_stereo_ = len(range(0, len(mix[:, 0]) - window_size, step_size)) * 2
            total_stereo = (total_stereo_ // progress_update_interval) * unique_sources
            total_iterations += total_stereo

    # print(total_iterations)

    for secs in align_window:
        sub = np.zeros_like(mix)
        divider = np.zeros_like(mix)
        step = secs / 2
        window_size = int(sr * secs)
        step_size = int(sr * step)
        window = np.hanning(window_size)

        # For the mono case:
        if len(mix.shape) == 1:
            # The files are mono
            counter = 0
            for i in range(0, len(mix) - window_size, step_size):
                counter += 1
                if counter % progress_update_interval == 0:
                    progress_bar(total_iterations)
                window_mix = mix[i : i + window_size] * window
                window_instrumental = instrumental[i : i + window_size] * window
                window_instrumental_aligned = align_tracks(window_mix, window_instrumental)
                sub[i : i + window_size] += window_mix - window_instrumental_aligned
                divider[i : i + window_size] += window
        else:
            # The files are stereo
            counter = 0
            for ch in range(mix.shape[1]):
                for i in range(0, len(mix[:, ch]) - window_size, step_size):
                    counter += 1
                    if counter % progress_update_interval == 0:
                        progress_bar(total_iterations)
                    window_mix = mix[i : i + window_size, ch] * window
                    window_instrumental = instrumental[i : i + window_size, ch] * window
                    window_instrumental_aligned = align_tracks(window_mix, window_instrumental)
                    sub[i : i + window_size, ch] += window_mix - window_instrumental_aligned
                    divider[i : i + window_size, ch] += window

        # Normalize the result by the overlap count
        sub = np.where(divider > 1e-6, sub / divider, sub)
        sub_size = np.abs(sub).mean()
        sub_mapper = {**sub_mapper, **{sub_size: sub}}

    # print("SUB_LEN", len(list(sub_mapper.values())))

    sub = ensemble_wav(list(sub_mapper.values()), split_size=12)

    return sub
|
| 1243 |
+
|
| 1244 |
+
|
| 1245 |
+
def ensemble_wav(waveforms, split_size=240):
    """Build a composite waveform by slicing every input into `split_size` chunks
    and, for each chunk position, keeping the chunk with the lowest mean
    magnitude (i.e. the least residual energy)."""
    # Pre-split each waveform into equal-ish chunks.
    chunked = {idx: np.array_split(wf, split_size) for idx, wf in enumerate(waveforms)}

    quietest_chunks = []
    for chunk_idx in range(split_size):
        # Mean |amplitude| of this chunk position across every candidate waveform.
        chunk_energies = [np.abs(chunked[idx][chunk_idx]).mean() for idx in range(len(waveforms))]
        best = np.argmin(chunk_energies)
        quietest_chunks.append(chunked[best][chunk_idx])

    return np.concatenate(quietest_chunks)
|
| 1267 |
+
|
| 1268 |
+
|
| 1269 |
+
def ensemble_wav_min(waveforms):
    """
    Combine waveforms sample-by-sample, keeping the value with the smallest
    absolute amplitude at each position (inputs are trimmed to the shortest length).

    :param waveforms: non-empty list of 1-D np.arrays
    :return: np.array of per-sample minimum-magnitude values
    :raises ValueError: if `waveforms` is empty
    """
    # BUG FIX: the original loop never ran for 0 or 1 inputs, leaving `wave`
    # unbound and raising UnboundLocalError on return.
    if not waveforms:
        raise ValueError("ensemble_wav_min requires at least one waveform")
    if len(waveforms) == 1:
        return waveforms[0]

    wave = waveforms[0]
    for i in range(1, len(waveforms)):
        ln = min(len(wave), len(waveforms[i]))
        wave = wave[:ln]
        waveforms[i] = waveforms[i][:ln]
        # Keep whichever sample is closer to zero.
        wave = np.where(np.abs(waveforms[i]) <= np.abs(wave), waveforms[i], wave)

    return wave
|
| 1281 |
+
|
| 1282 |
+
|
| 1283 |
+
def align_audio_test(wav1, wav2, sr1=44100):
    """Cross-correlate a one-second window of both (channels, samples) tracks and
    return wav2 shifted (zero-padded or trimmed) so it lines up with wav1."""

    def estimate_offset(a, b):
        # Lag of the cross-correlation peak between a and b.
        full_corr = np.correlate(a, b, "full")
        return full_corr.argmax() - (b.shape[0] - 1)

    # Work in (samples, channels) layout.
    wav1 = wav1.transpose()
    wav2 = wav2.transpose()

    wav2_org = wav2.copy()

    # Compare a one-sample-rate-long window starting one second in (left channel only).
    start = sr1
    offset = estimate_offset(wav1[start : start + sr1, 0], wav2[start : start + sr1, 0])

    if offset > 0:
        # wav2 lags: prepend silence.
        return np.append(np.zeros((offset, 1)), wav2_org, axis=0)
    if offset < 0:
        # wav2 leads: drop the surplus leading samples.
        return wav2_org[-offset:]
    return wav2_org
|
| 1312 |
+
|
| 1313 |
+
|
| 1314 |
+
def load_audio(audio_file):
    """Load a file at 44.1 kHz; mono input is duplicated into two identical channels."""
    wav = librosa.load(audio_file, sr=44100, mono=False)[0]
    return np.asfortranarray([wav, wav]) if wav.ndim == 1 else wav
|
| 1321 |
+
|
| 1322 |
+
|
| 1323 |
+
def rerun_mp3(audio_file):
    """Return the track length of `audio_file` in whole seconds via audioread."""
    with audioread.audio_open(audio_file) as audio:
        return int(audio.duration)
|
audio_separator/separator/uvr_lib_v5/stft.py
ADDED
|
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch


class STFT:
    """
    This class performs the Short-Time Fourier Transform (STFT) and its inverse (ISTFT).
    These functions are essential for converting the audio between the time domain and the frequency domain,
    which is a crucial aspect of audio processing in neural networks.

    The forward transform packs real/imaginary parts into the channel axis
    (channels * 2), and `inverse` undoes that packing.
    """

    def __init__(self, logger, n_fft, hop_length, dim_f, device):
        # logger: kept for callers; not used by the transform itself here.
        self.logger = logger
        # n_fft: FFT size; hop_length: stride between frames.
        self.n_fft = n_fft
        self.hop_length = hop_length
        # dim_f: number of frequency bins retained from the forward transform.
        self.dim_f = dim_f
        # device: target device for results produced on a non-standard device path.
        self.device = device
        # Create a Hann window tensor for use in the STFT.
        self.hann_window = torch.hann_window(window_length=self.n_fft, periodic=True)

    def __call__(self, input_tensor):
        """Forward STFT: (..., channels, time) -> (..., channels * 2, dim_f, frames),
        with real and imaginary parts interleaved along the channel axis."""
        # Determine if the input tensor's device is not a standard computing device (i.e., not CPU or CUDA).
        # NOTE(review): presumably this targets devices like MPS where torch.stft
        # support is limited — confirm.
        is_non_standard_device = not input_tensor.device.type in ["cuda", "cpu"]

        # If on a non-standard device, temporarily move the tensor to CPU for processing.
        if is_non_standard_device:
            input_tensor = input_tensor.cpu()

        # Transfer the pre-defined window tensor to the same device as the input tensor.
        stft_window = self.hann_window.to(input_tensor.device)

        # Extract batch dimensions (all dimensions except the last two which are channel and time).
        batch_dimensions = input_tensor.shape[:-2]

        # Extract channel and time dimensions (last two dimensions of the tensor).
        channel_dim, time_dim = input_tensor.shape[-2:]

        # Reshape the tensor to merge batch and channel dimensions for STFT processing.
        reshaped_tensor = input_tensor.reshape([-1, time_dim])

        # Perform the Short-Time Fourier Transform (STFT) on the reshaped tensor.
        # NOTE(review): return_complex=False is deprecated/removed in recent
        # torch versions — verify against the pinned torch release.
        stft_output = torch.stft(reshaped_tensor, n_fft=self.n_fft, hop_length=self.hop_length, window=stft_window, center=True, return_complex=False)

        # Rearrange the dimensions of the STFT output to bring the frequency dimension forward.
        permuted_stft_output = stft_output.permute([0, 3, 1, 2])

        # Reshape the output to restore the original batch and channel dimensions, while keeping the newly formed frequency and time dimensions.
        final_output = permuted_stft_output.reshape([*batch_dimensions, channel_dim, 2, -1, permuted_stft_output.shape[-1]]).reshape(
            [*batch_dimensions, channel_dim * 2, -1, permuted_stft_output.shape[-1]]
        )

        # If the original tensor was on a non-standard device, move the processed tensor back to that device.
        if is_non_standard_device:
            final_output = final_output.to(self.device)

        # Return the transformed tensor, sliced to retain only the required frequency dimension (`dim_f`).
        return final_output[..., : self.dim_f, :]

    def pad_frequency_dimension(self, input_tensor, batch_dimensions, channel_dim, freq_dim, time_dim, num_freq_bins):
        """
        Adds zero padding to the frequency dimension of the input tensor.
        """
        # Create a padding tensor for the frequency dimension
        freq_padding = torch.zeros([*batch_dimensions, channel_dim, num_freq_bins - freq_dim, time_dim]).to(input_tensor.device)

        # Concatenate the padding to the input tensor along the frequency dimension.
        padded_tensor = torch.cat([input_tensor, freq_padding], -2)

        return padded_tensor

    def calculate_inverse_dimensions(self, input_tensor):
        """Split a (..., channels, freq, time) tensor's shape into the pieces the
        inverse transform needs, plus the full bin count n_fft // 2 + 1."""
        # Extract batch dimensions and frequency-time dimensions.
        batch_dimensions = input_tensor.shape[:-3]
        channel_dim, freq_dim, time_dim = input_tensor.shape[-3:]

        # Calculate the number of frequency bins for the inverse STFT.
        num_freq_bins = self.n_fft // 2 + 1

        return batch_dimensions, channel_dim, freq_dim, time_dim, num_freq_bins

    def prepare_for_istft(self, padded_tensor, batch_dimensions, channel_dim, num_freq_bins, time_dim):
        """
        Prepares the tensor for Inverse Short-Time Fourier Transform (ISTFT) by reshaping
        and creating a complex tensor from the real and imaginary parts.
        """
        # Reshape the tensor to separate real and imaginary parts and prepare for ISTFT.
        reshaped_tensor = padded_tensor.reshape([*batch_dimensions, channel_dim // 2, 2, num_freq_bins, time_dim])

        # Flatten batch dimensions and rearrange for ISTFT.
        flattened_tensor = reshaped_tensor.reshape([-1, 2, num_freq_bins, time_dim])

        # Rearrange the dimensions of the tensor to bring the frequency dimension forward.
        permuted_tensor = flattened_tensor.permute([0, 2, 3, 1])

        # Combine real and imaginary parts into a complex tensor.
        complex_tensor = permuted_tensor[..., 0] + permuted_tensor[..., 1] * 1.0j

        return complex_tensor

    def inverse(self, input_tensor):
        """Inverse STFT: undo __call__'s packing and return a (..., 2, time) tensor."""
        # Determine if the input tensor's device is not a standard computing device (i.e., not CPU or CUDA).
        is_non_standard_device = not input_tensor.device.type in ["cuda", "cpu"]

        # If on a non-standard device, temporarily move the tensor to CPU for processing.
        if is_non_standard_device:
            input_tensor = input_tensor.cpu()

        # Transfer the pre-defined Hann window tensor to the same device as the input tensor.
        stft_window = self.hann_window.to(input_tensor.device)

        batch_dimensions, channel_dim, freq_dim, time_dim, num_freq_bins = self.calculate_inverse_dimensions(input_tensor)

        padded_tensor = self.pad_frequency_dimension(input_tensor, batch_dimensions, channel_dim, freq_dim, time_dim, num_freq_bins)

        complex_tensor = self.prepare_for_istft(padded_tensor, batch_dimensions, channel_dim, num_freq_bins, time_dim)

        # Perform the Inverse Short-Time Fourier Transform (ISTFT).
        istft_result = torch.istft(complex_tensor, n_fft=self.n_fft, hop_length=self.hop_length, window=stft_window, center=True)

        # Reshape ISTFT result to restore original batch and channel dimensions.
        final_output = istft_result.reshape([*batch_dimensions, 2, -1])

        # If the original tensor was on a non-standard device, move the processed tensor back to that device.
        if is_non_standard_device:
            final_output = final_output.to(self.device)

        return final_output
|
audio_separator/separator/uvr_lib_v5/tfc_tdf_v3.py
ADDED
|
@@ -0,0 +1,253 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from functools import partial
|
| 4 |
+
|
| 5 |
+
class STFT:
    """Forward/inverse short-time Fourier transform helper.

    The forward call packs the real and imaginary parts of the spectrum into
    the channel axis and truncates the frequency axis to ``dim_f`` bins;
    ``inverse`` undoes both steps. Devices other than CPU/CUDA (e.g. MPS) are
    handled by computing on the CPU and moving the result to ``self.device``.
    """

    def __init__(self, n_fft, hop_length, dim_f, device):
        self.n_fft = n_fft
        self.hop_length = hop_length
        # Periodic Hann window, shared by forward and inverse transforms.
        self.window = torch.hann_window(window_length=self.n_fft, periodic=True)
        self.dim_f = dim_f
        self.device = device

    def __call__(self, x):
        on_unsupported_device = x.device.type not in ("cuda", "cpu")
        if on_unsupported_device:
            x = x.cpu()

        window = self.window.to(x.device)
        leading = x.shape[:-2]
        channels, samples = x.shape[-2:]

        # Collapse batch+channel so torch.stft sees a flat batch of 1-D signals.
        spec = torch.stft(
            x.reshape([-1, samples]),
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            window=window,
            center=True,
            return_complex=False,
        )
        # (batch, freq, time, 2) -> (batch, 2, freq, time), then fold the
        # real/imag axis into the channel axis.
        spec = spec.permute([0, 3, 1, 2])
        spec = spec.reshape([*leading, channels, 2, -1, spec.shape[-1]]).reshape([*leading, channels * 2, -1, spec.shape[-1]])

        if on_unsupported_device:
            spec = spec.to(self.device)

        # Keep only the first dim_f frequency bins.
        return spec[..., : self.dim_f, :]

    def inverse(self, x):
        on_unsupported_device = x.device.type not in ("cuda", "cpu")
        if on_unsupported_device:
            x = x.cpu()

        window = self.window.to(x.device)
        leading = x.shape[:-3]
        channels, freq_bins, frames = x.shape[-3:]
        full_bins = self.n_fft // 2 + 1

        # Zero-pad the truncated spectrum back to the full one-sided bin count.
        padding = torch.zeros([*leading, channels, full_bins - freq_bins, frames]).to(x.device)
        spec = torch.cat([x, padding], -2)

        # Unfold real/imag from the channel axis and build a complex tensor.
        spec = spec.reshape([*leading, channels // 2, 2, full_bins, frames]).reshape([-1, 2, full_bins, frames])
        spec = spec.permute([0, 2, 3, 1])
        spec = spec[..., 0] + spec[..., 1] * 1.0j

        audio = torch.istft(spec, n_fft=self.n_fft, hop_length=self.hop_length, window=window, center=True)
        # Restore batch axes; each item is a stereo (2, samples) waveform.
        audio = audio.reshape([*leading, 2, -1])

        if on_unsupported_device:
            audio = audio.to(self.device)

        return audio
|
| 54 |
+
|
| 55 |
+
def get_norm(norm_type):
    """Return a factory that maps a channel count to the configured norm layer.

    Supported values: 'BatchNorm', 'InstanceNorm', 'GroupNorm<g>' (where <g>
    is the group count, e.g. 'GroupNorm4'); anything else yields nn.Identity.
    """

    def make_norm(c, norm_type):
        if norm_type == "BatchNorm":
            return nn.BatchNorm2d(c)
        if norm_type == "InstanceNorm":
            return nn.InstanceNorm2d(c, affine=True)
        if "GroupNorm" in norm_type:
            # The group count is encoded as a numeric suffix, e.g. 'GroupNorm4'.
            group_count = int(norm_type.replace("GroupNorm", ""))
            return nn.GroupNorm(num_groups=group_count, num_channels=c)
        return nn.Identity()

    return partial(make_norm, norm_type=norm_type)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def get_act(act_type):
    """Return an activation module for the given config string.

    Supported values: 'gelu', 'relu', and 'elu<alpha>' (e.g. 'elu0.5'), where
    the suffix after 'elu' is parsed as the ELU alpha parameter.

    Raises:
        ValueError: if act_type is not one of the supported activations.
    """
    if act_type == 'gelu':
        return nn.GELU()
    elif act_type == 'relu':
        return nn.ReLU()
    elif act_type[:3] == 'elu':
        # The numeric suffix after 'elu' is the alpha parameter.
        alpha = float(act_type.replace('elu', ''))
        return nn.ELU(alpha)
    else:
        # Was a bare `raise Exception`; a named error with the offending value
        # makes config typos diagnosable. ValueError is still an Exception,
        # so existing broad handlers keep working.
        raise ValueError(f"Unsupported activation type: {act_type!r}")
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
class Upscale(nn.Module):
    """Pre-activation upsampling block: norm -> activation -> ConvTranspose2d.

    kernel_size and stride are both set to `scale`, so spatial dimensions grow
    by exactly that factor.
    """

    def __init__(self, in_c, out_c, scale, norm, act):
        super().__init__()
        layers = [
            norm(in_c),
            act,
            nn.ConvTranspose2d(in_channels=in_c, out_channels=out_c, kernel_size=scale, stride=scale, bias=False),
        ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
class Downscale(nn.Module):
    """Pre-activation downsampling block: norm -> activation -> strided Conv2d.

    kernel_size and stride are both set to `scale`, so spatial dimensions
    shrink by exactly that factor.
    """

    def __init__(self, in_c, out_c, scale, norm, act):
        super().__init__()
        layers = [
            norm(in_c),
            act,
            nn.Conv2d(in_channels=in_c, out_channels=out_c, kernel_size=scale, stride=scale, bias=False),
        ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
class TFC_TDF(nn.Module):
    """Stack of residual TFC-TDF blocks.

    Each block applies a time-frequency convolution (tfc1), adds a
    frequency-axis bottleneck MLP (tdf) as a residual, applies a second
    convolution (tfc2), and finally adds a 1x1-projected shortcut of the
    block's input.

    Args:
        in_c: channels of the incoming feature map (consumed by the first block).
        c: channels used by every block.
        l: number of stacked blocks.
        f: size of the frequency axis (the last tensor dimension).
        bn: bottleneck factor for the tdf linear layers (hidden size f // bn).
        norm: factory mapping a channel count to a normalization module.
        act: activation module shared by all sub-layers.
    """

    def __init__(self, in_c, c, l, f, bn, norm, act):
        super().__init__()

        self.blocks = nn.ModuleList()
        channels_in = in_c
        for _ in range(l):
            block = nn.Module()

            block.tfc1 = nn.Sequential(
                norm(channels_in),
                act,
                nn.Conv2d(channels_in, c, 3, 1, 1, bias=False),
            )
            # Frequency-axis bottleneck: Linear layers act on the last dim (f).
            block.tdf = nn.Sequential(
                norm(c),
                act,
                nn.Linear(f, f // bn, bias=False),
                norm(c),
                act,
                nn.Linear(f // bn, f, bias=False),
            )
            block.tfc2 = nn.Sequential(
                norm(c),
                act,
                nn.Conv2d(c, c, 3, 1, 1, bias=False),
            )
            # 1x1 projection so the residual add matches the block's channels.
            block.shortcut = nn.Conv2d(channels_in, c, 1, 1, 0, bias=False)

            self.blocks.append(block)
            # After the first block every subsequent block consumes c channels.
            channels_in = c

    def forward(self, x):
        for block in self.blocks:
            residual = block.shortcut(x)
            x = block.tfc1(x)
            x = x + block.tdf(x)
            x = block.tfc2(x)
            x = x + residual
        return x
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
class TFC_TDF_net(nn.Module):
    """U-Net-style source separation network built from TFC-TDF blocks.

    The model runs an STFT, folds frequency subbands into channels, passes the
    result through an encoder / bottleneck / decoder with skip connections, and
    inverts the STFT to produce per-instrument waveforms.

    NOTE: attribute names and module construction order define the state-dict
    keys for pretrained checkpoints — do not rename or reorder submodules.
    """

    def __init__(self, config, device):
        # config: object with .model / .audio / .training sections (see reads below).
        # device: target device handed to the STFT helper for MPS round-trips.
        super().__init__()
        self.config = config
        self.device = device

        # Factories for the normalization and activation layers used throughout.
        norm = get_norm(norm_type=config.model.norm)
        act = get_act(act_type=config.model.act)

        # Single-target mode collapses the output to one instrument.
        self.num_target_instruments = 1 if config.training.target_instrument else len(config.training.instruments)
        self.num_subbands = config.model.num_subbands

        # Channels after subband folding: subbands * audio channels * (real, imag).
        dim_c = self.num_subbands * config.audio.num_channels * 2
        n = config.model.num_scales
        scale = config.model.scale
        l = config.model.num_blocks_per_scale
        c = config.model.num_channels
        g = config.model.growth
        bn = config.model.bottleneck_factor
        # Frequency bins per subband.
        f = config.audio.dim_f // self.num_subbands

        self.first_conv = nn.Conv2d(dim_c, c, 1, 1, 0, bias=False)

        # Encoder: each scale processes features then downsamples, growing
        # channels by g and shrinking the frequency axis by scale[1].
        self.encoder_blocks = nn.ModuleList()
        for i in range(n):
            block = nn.Module()
            block.tfc_tdf = TFC_TDF(c, c, l, f, bn, norm, act)
            block.downscale = Downscale(c, c + g, scale, norm, act)
            f = f // scale[1]
            c += g
            self.encoder_blocks.append(block)

        self.bottleneck_block = TFC_TDF(c, c, l, f, bn, norm, act)

        # Decoder mirrors the encoder; 2 * c input channels account for the
        # concatenated skip connection from the matching encoder scale.
        self.decoder_blocks = nn.ModuleList()
        for i in range(n):
            block = nn.Module()
            block.upscale = Upscale(c, c - g, scale, norm, act)
            f = f * scale[1]
            c -= g
            block.tfc_tdf = TFC_TDF(2 * c, c, l, f, bn, norm, act)
            self.decoder_blocks.append(block)

        # Projects concatenated (input spectrogram, decoder output) to one
        # dim_c-sized spectrogram per target instrument.
        self.final_conv = nn.Sequential(
            nn.Conv2d(c + dim_c, c, 1, 1, 0, bias=False),
            act,
            nn.Conv2d(c, self.num_target_instruments * dim_c, 1, 1, 0, bias=False)
        )

        self.stft = STFT(config.audio.n_fft, config.audio.hop_length, config.audio.dim_f, self.device)

    def cac2cws(self, x):
        """Fold k frequency subbands into the channel axis: (b, c, f, t) -> (b, c*k, f//k, t)."""
        k = self.num_subbands
        b, c, f, t = x.shape
        x = x.reshape(b, c, k, f // k, t)
        x = x.reshape(b, c * k, f // k, t)
        return x

    def cws2cac(self, x):
        """Inverse of cac2cws: unfold subbands back onto the frequency axis."""
        k = self.num_subbands
        b, c, f, t = x.shape
        x = x.reshape(b, c // k, k, f, t)
        x = x.reshape(b, c // k, f * k, t)
        return x

    def forward(self, x):

        # Waveform -> real-valued spectrogram with real/imag packed in channels.
        x = self.stft(x)

        # Keep the subband-folded input around for the final skip concat.
        mix = x = self.cac2cws(x)

        # Keep the first projection for the elementwise gating below.
        first_conv_out = x = self.first_conv(x)

        # Swap to (..., time, freq) so the TDF Linear layers act on frequency.
        x = x.transpose(-1, -2)

        encoder_outputs = []
        for block in self.encoder_blocks:
            x = block.tfc_tdf(x)
            encoder_outputs.append(x)
            x = block.downscale(x)

        x = self.bottleneck_block(x)

        # Decoder consumes encoder outputs in reverse (innermost scale first).
        for block in self.decoder_blocks:
            x = block.upscale(x)
            x = torch.cat([x, encoder_outputs.pop()], 1)
            x = block.tfc_tdf(x)

        # Back to (..., freq, time).
        x = x.transpose(-1, -2)

        x = x * first_conv_out  # reduce artifacts

        x = self.final_conv(torch.cat([mix, x], 1))

        x = self.cws2cac(x)

        # Split the channel axis into one spectrogram per target instrument.
        if self.num_target_instruments > 1:
            b, c, f, t = x.shape
            x = x.reshape(b, self.num_target_instruments, -1, f, t)

        x = self.stft.inverse(x)

        return x
|
| 252 |
+
|
| 253 |
+
|
audio_separator/separator/uvr_lib_v5/vr_network/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# VR init.
|
audio_separator/separator/uvr_lib_v5/vr_network/layers.py
ADDED
|
@@ -0,0 +1,294 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
from audio_separator.separator.uvr_lib_v5 import spec_utils
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class Conv2DBNActiv(nn.Module):
    """Conv2d -> BatchNorm2d -> activation, bundled as one callable unit.

    Args:
        nin: number of input channels.
        nout: number of output channels.
        ksize: convolution kernel size. Defaults to 3.
        stride: convolution stride. Defaults to 1.
        pad: padding applied on all sides. Defaults to 1.
        dilation: spacing between kernel elements. Defaults to 1.
        activ: activation class instantiated after the batch norm. Defaults to nn.ReLU.
    """

    def __init__(self, nin, nout, ksize=3, stride=1, pad=1, dilation=1, activ=nn.ReLU):
        super(Conv2DBNActiv, self).__init__()

        # Bias is omitted from the convolution because BatchNorm2d supplies
        # its own learned shift immediately afterwards.
        self.conv = nn.Sequential(
            nn.Conv2d(nin, nout, kernel_size=ksize, stride=stride, padding=pad, dilation=dilation, bias=False),
            nn.BatchNorm2d(nout),
            activ(),
        )

    def __call__(self, input_tensor):
        # Run the stacked conv / norm / activation pipeline.
        return self.conv(input_tensor)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class SeperableConv2DBNActiv(nn.Module):
    """Depthwise-separable convolution followed by batch norm and activation.

    The convolution is factored into a depthwise pass (one filter per input
    channel, via groups=nin) and a 1x1 pointwise pass that mixes channels.
    This keeps the receptive field of a full convolution while using fewer
    parameters and less compute.

    Args:
        nin: number of input channels.
        nout: number of output channels (produced by the pointwise conv).
        ksize: kernel size of the depthwise convolution. Defaults to 3.
        stride: stride of the depthwise convolution. Defaults to 1.
        pad: padding of the depthwise convolution. Defaults to 1.
        dilation: dilation of the depthwise convolution. Defaults to 1.
        activ: activation class instantiated after the batch norm. Defaults to nn.ReLU.
    """

    def __init__(self, nin, nout, ksize=3, stride=1, pad=1, dilation=1, activ=nn.ReLU):
        super(SeperableConv2DBNActiv, self).__init__()

        # Depthwise: groups=nin means each input channel gets its own filter.
        depthwise = nn.Conv2d(
            nin,
            nin,
            kernel_size=ksize,
            stride=stride,
            padding=pad,
            dilation=dilation,
            groups=nin,
            bias=False,  # BatchNorm2d provides the shift, so no conv bias.
        )
        # Pointwise: 1x1 convolution combines the per-channel outputs.
        pointwise = nn.Conv2d(nin, nout, kernel_size=1, bias=False)

        self.conv = nn.Sequential(depthwise, pointwise, nn.BatchNorm2d(nout), activ())

    def __call__(self, input_tensor):
        # Depthwise conv -> pointwise conv -> batch norm -> activation.
        return self.conv(input_tensor)
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
class Encoder(nn.Module):
    """Two-convolution encoder stage producing a feature map plus a skip tensor.

    The first convolution runs at full resolution and its output is returned
    as a skip connection for a matching decoder stage; the second convolution
    applies the requested stride to produce the downstream feature map.

    Args:
        nin: input channels of the first convolution.
        nout: output channels of both convolutions.
        ksize: kernel size for both convolutions. Defaults to 3.
        stride: stride of the second convolution only. Defaults to 1.
        pad: padding for both convolutions. Defaults to 1.
        activ: activation class used by both convolutions. Defaults to nn.LeakyReLU.
    """

    def __init__(self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.LeakyReLU):
        super(Encoder, self).__init__()

        # Full-resolution conv; its output doubles as the skip connection.
        self.conv1 = Conv2DBNActiv(nin, nout, ksize, 1, pad, activ=activ)

        # Strided conv that performs the actual downsampling.
        self.conv2 = Conv2DBNActiv(nout, nout, ksize, stride, pad, activ=activ)

    def __call__(self, input_tensor):
        skip = self.conv1(input_tensor)
        hidden = self.conv2(skip)
        # (downsampled features, full-resolution skip connection)
        return hidden, skip
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
class Decoder(nn.Module):
    """Upsampling decoder stage: bilinear 2x upsample, optional skip concat, conv.

    Args:
        nin: input channels of the convolution (after any skip concatenation).
        nout: output channels of the convolution.
        ksize: kernel size of the convolution. Defaults to 3.
        stride: unused by the convolution here (it always runs at stride 1).
        pad: padding of the convolution. Defaults to 1.
        activ: activation class for the convolution. Defaults to nn.ReLU.
        dropout: when True, applies nn.Dropout2d(0.1) to the output.
    """

    def __init__(self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.ReLU, dropout=False):
        super(Decoder, self).__init__()

        self.conv = Conv2DBNActiv(nin, nout, ksize, 1, pad, activ=activ)

        # Optional spatial dropout for regularization.
        self.dropout = nn.Dropout2d(0.1) if dropout else None

    def __call__(self, input_tensor, skip=None):
        # Double the spatial resolution before convolving.
        upsampled = F.interpolate(input_tensor, scale_factor=2, mode="bilinear", align_corners=True)

        if skip is not None:
            # Align the skip tensor to the upsampled size, then stack on channels.
            cropped_skip = spec_utils.crop_center(skip, upsampled)
            upsampled = torch.cat([upsampled, cropped_skip], dim=1)

        output_tensor = self.conv(upsampled)

        if self.dropout is not None:
            output_tensor = self.dropout(output_tensor)

        return output_tensor
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
class ASPPModule(nn.Module):
    """
    Atrous Spatial Pyramid Pooling (ASPP) Module is designed for capturing multi-scale context by applying
    atrous convolution at multiple rates. This is particularly useful in segmentation tasks where capturing
    objects at various scales is beneficial. The module applies several parallel dilated convolutions with
    different dilation rates to the input feature map, allowing it to efficiently capture information at
    multiple scales.

    Attributes:
        conv1 (nn.Sequential): Applies adaptive average pooling followed by a 1x1 convolution.
        nn_architecture (int): Identifier for the neural network architecture being used.
        six_layer (list): List containing architecture identifiers that require six layers.
        seven_layer (list): List containing architecture identifiers that require seven layers.
        conv2-conv7 (nn.Module): Convolutional layers with varying dilation rates for multi-scale feature extraction.
        bottleneck (nn.Sequential): A 1x1 convolutional layer that combines all features followed by dropout for regularization.
    """

    def __init__(self, nn_architecture, nin, nout, dilations=(4, 8, 16), activ=nn.ReLU):
        """
        Initializes the ASPP module with specified parameters.

        Args:
            nn_architecture (int): Identifier for the neural network architecture.
            nin (int): Number of input channels.
            nout (int): Number of output channels.
            dilations (tuple): Tuple of dilation rates for the atrous convolutions.
            activ (callable): Activation class to use after convolutional layers.
        """
        super(ASPPModule, self).__init__()

        # Adaptive average pooling collapses the height axis to 1 (width kept),
        # focusing on global context, followed by a 1x1 convolution to project
        # back to the desired channel dimension.
        self.conv1 = nn.Sequential(nn.AdaptiveAvgPool2d((1, None)), Conv2DBNActiv(nin, nin, 1, 1, 0, activ=activ))

        self.nn_architecture = nn_architecture
        # Architecture identifiers for models requiring additional layers.
        self.six_layer = [129605]
        self.seven_layer = [537238, 537227, 33966]

        # Extra convolutional layer used for six and seven layer configurations.
        # NOTE(review): in the seven-layer configuration conv6 and conv7 are
        # bound to this *same* instance, i.e. they share weights. Preserved
        # as-is for checkpoint compatibility — confirm upstream intent before
        # changing.
        extra_conv = SeperableConv2DBNActiv(nin, nin, 3, 1, dilations[2], dilations[2], activ=activ)

        # Standard 1x1 convolution for channel reduction.
        self.conv2 = Conv2DBNActiv(nin, nin, 1, 1, 0, activ=activ)

        # Separable convolutions with different dilation rates for multi-scale feature extraction.
        self.conv3 = SeperableConv2DBNActiv(nin, nin, 3, 1, dilations[0], dilations[0], activ=activ)
        self.conv4 = SeperableConv2DBNActiv(nin, nin, 3, 1, dilations[1], dilations[1], activ=activ)
        self.conv5 = SeperableConv2DBNActiv(nin, nin, 3, 1, dilations[2], dilations[2], activ=activ)

        # Depending on the architecture, include the extra convolutional layers;
        # nin_x counts the parallel branches feeding the bottleneck.
        if self.nn_architecture in self.six_layer:
            self.conv6 = extra_conv
            nin_x = 6
        elif self.nn_architecture in self.seven_layer:
            self.conv6 = extra_conv
            self.conv7 = extra_conv
            nin_x = 7
        else:
            nin_x = 5

        # Bottleneck layer combines all the multi-scale features into the desired number of output channels.
        self.bottleneck = nn.Sequential(Conv2DBNActiv(nin * nin_x, nout, 1, 1, 0, activ=activ), nn.Dropout2d(0.1))

    def forward(self, input_tensor):
        """
        Forward pass of the ASPP module.

        Args:
            input_tensor (Tensor): Input tensor of shape (batch, channels, height, width).

        Returns:
            Tensor: Output tensor after applying ASPP.
        """
        _, _, h, w = input_tensor.size()

        # Apply the global-context branch and upsample back to the input resolution.
        feat1 = F.interpolate(self.conv1(input_tensor), size=(h, w), mode="bilinear", align_corners=True)

        # Apply the remaining convolutions directly on the input.
        feat2 = self.conv2(input_tensor)
        feat3 = self.conv3(input_tensor)
        feat4 = self.conv4(input_tensor)
        feat5 = self.conv5(input_tensor)

        # Concatenate features from all branches. Depending on the architecture, include the extra features.
        if self.nn_architecture in self.six_layer:
            feat6 = self.conv6(input_tensor)
            out = torch.cat((feat1, feat2, feat3, feat4, feat5, feat6), dim=1)
        elif self.nn_architecture in self.seven_layer:
            # conv6 and conv7 are the same module instance (see __init__ note),
            # so feat6 and feat7 are computed by identical weights.
            feat6 = self.conv6(input_tensor)
            feat7 = self.conv7(input_tensor)
            out = torch.cat((feat1, feat2, feat3, feat4, feat5, feat6, feat7), dim=1)
        else:
            out = torch.cat((feat1, feat2, feat3, feat4, feat5), dim=1)

        # Apply the bottleneck layer to combine and reduce the channel dimensions.
        bottleneck_output = self.bottleneck(out)
        return bottleneck_output
|
audio_separator/separator/uvr_lib_v5/vr_network/layers_new.py
ADDED
|
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
from audio_separator.separator.uvr_lib_v5 import spec_utils
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class Conv2DBNActiv(nn.Module):
    """
    Conv2DBNActiv Class:
    This class implements a convolutional layer followed by batch normalization and an activation function.
    It is a fundamental building block for constructing neural networks, especially useful in image and audio processing tasks.
    The class encapsulates the pattern of applying a convolution, normalizing the output, and then applying a non-linear activation.
    """

    def __init__(self, nin, nout, ksize=3, stride=1, pad=1, dilation=1, activ=nn.ReLU):
        """
        Args:
            nin (int): Number of input channels.
            nout (int): Number of output channels.
            ksize (int): Convolution kernel size.
            stride (int): Convolution stride.
            pad (int): Zero-padding added to both sides of the input.
            dilation (int): Spacing between kernel elements.
            activ: Activation layer class, instantiated with no arguments.
        """
        super(Conv2DBNActiv, self).__init__()

        # Sequential model combining Conv2D, BatchNorm, and activation function into a single module.
        # The convolution has no bias because BatchNorm provides its own learnable shift.
        self.conv = nn.Sequential(nn.Conv2d(nin, nout, kernel_size=ksize, stride=stride, padding=pad, dilation=dilation, bias=False), nn.BatchNorm2d(nout), activ())

    def forward(self, input_tensor):
        """Apply conv -> batch norm -> activation to the input tensor."""
        # Defining forward (rather than overriding __call__) keeps nn.Module's
        # hook machinery working; instances are still invoked as module(x).
        return self.conv(input_tensor)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class Encoder(nn.Module):
    """
    Encoder Class:
    This class defines an encoder module typically used in autoencoder architectures.
    It consists of two convolutional layers, each followed by batch normalization and an activation function.
    """

    def __init__(self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.LeakyReLU):
        """
        Args:
            nin (int): Number of input channels.
            nout (int): Number of output channels.
            ksize (int): Kernel size for both convolutions.
            stride (int): Stride of the first convolution (the second always uses stride 1).
            pad (int): Padding for both convolutions.
            activ: Activation layer class used by both convolutions.
        """
        super(Encoder, self).__init__()

        # First convolutional layer of the encoder (may downsample via stride).
        self.conv1 = Conv2DBNActiv(nin, nout, ksize, stride, pad, activ=activ)
        # Second convolutional layer of the encoder (always stride 1).
        self.conv2 = Conv2DBNActiv(nout, nout, ksize, 1, pad, activ=activ)

    def forward(self, input_tensor):
        """Apply the first and then the second convolutional layers."""
        # Defining forward (rather than overriding __call__) keeps nn.Module hooks working.
        hidden = self.conv1(input_tensor)
        hidden = self.conv2(hidden)

        return hidden
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class Decoder(nn.Module):
    """
    Decoder Class:
    This class defines a decoder module, which is the counterpart of the Encoder class in autoencoder architectures.
    It upsamples the input, optionally concatenates a cropped skip connection, then applies a convolutional layer
    followed by batch normalization and an activation function, with an optional dropout layer for regularization.
    """

    def __init__(self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.ReLU, dropout=False):
        """
        Args:
            nin (int): Number of input channels (including any concatenated skip channels).
            nout (int): Number of output channels.
            ksize (int): Convolution kernel size.
            stride (int): Unused; kept for signature compatibility (the convolution uses stride 1).
            pad (int): Convolution padding.
            activ: Activation layer class.
            dropout (bool): If True, apply Dropout2d(0.1) after the convolution.
        """
        super(Decoder, self).__init__()
        # Convolutional layer with optional dropout for regularization.
        self.conv1 = Conv2DBNActiv(nin, nout, ksize, 1, pad, activ=activ)
        self.dropout = nn.Dropout2d(0.1) if dropout else None

    def forward(self, input_tensor, skip=None):
        """Upsample, optionally merge the skip connection, then convolve."""
        # Defining forward (rather than overriding __call__) keeps nn.Module hooks working.
        # Double the spatial resolution before convolving.
        input_tensor = F.interpolate(input_tensor, scale_factor=2, mode="bilinear", align_corners=True)

        if skip is not None:
            # Crop the skip feature map to the upsampled size, then concatenate on channels.
            skip = spec_utils.crop_center(skip, input_tensor)
            input_tensor = torch.cat([input_tensor, skip], dim=1)

        hidden = self.conv1(input_tensor)

        if self.dropout is not None:
            hidden = self.dropout(hidden)

        return hidden
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
class ASPPModule(nn.Module):
    """
    ASPPModule Class:
    Atrous Spatial Pyramid Pooling (ASPP) block, useful for semantic segmentation tasks.
    It gathers multi-scale contextual information by running parallel convolutions at
    several dilation rates plus a pooled global-context branch, then fuses everything
    through a 1x1 bottleneck convolution.
    """

    def __init__(self, nin, nout, dilations=(4, 8, 12), activ=nn.ReLU, dropout=False):
        super(ASPPModule, self).__init__()

        # Global-context branch: pool one spatial axis, then project channels with a 1x1 conv.
        self.conv1 = nn.Sequential(nn.AdaptiveAvgPool2d((1, None)), Conv2DBNActiv(nin, nout, 1, 1, 0, activ=activ))
        # Pointwise branch plus three dilated 3x3 branches at increasing rates.
        self.conv2 = Conv2DBNActiv(nin, nout, 1, 1, 0, activ=activ)
        self.conv3 = Conv2DBNActiv(nin, nout, 3, 1, dilations[0], dilations[0], activ=activ)
        self.conv4 = Conv2DBNActiv(nin, nout, 3, 1, dilations[1], dilations[1], activ=activ)
        self.conv5 = Conv2DBNActiv(nin, nout, 3, 1, dilations[2], dilations[2], activ=activ)
        # 1x1 bottleneck fuses the five concatenated branches back down to nout channels.
        self.bottleneck = Conv2DBNActiv(nout * 5, nout, 1, 1, 0, activ=activ)
        self.dropout = nn.Dropout2d(0.1) if dropout else None

    def forward(self, input_tensor):
        """Fuse multi-scale context features extracted from the input tensor."""
        _, _, height, width = input_tensor.size()

        # The pooled branch collapses a spatial axis, so upsample it back to the
        # input resolution before mixing it with the full-size branches.
        branches = [
            F.interpolate(self.conv1(input_tensor), size=(height, width), mode="bilinear", align_corners=True),
            self.conv2(input_tensor),
            self.conv3(input_tensor),
            self.conv4(input_tensor),
            self.conv5(input_tensor),
        ]
        fused = self.bottleneck(torch.cat(branches, dim=1))

        if self.dropout is not None:
            fused = self.dropout(fused)

        return fused
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
class LSTMModule(nn.Module):
    """
    LSTMModule Class:
    Combines a 1x1 convolutional squeeze with a bidirectional LSTM so the model can
    reason over temporal structure in its input, which is useful for speech and
    audio processing, then projects the LSTM output back to the input bin dimension.
    """

    def __init__(self, nin_conv, nin_lstm, nout_lstm):
        super(LSTMModule, self).__init__()
        # Collapse the channel dimension down to 1 before sequence modelling.
        self.conv = Conv2DBNActiv(nin_conv, 1, 1, 1, 0)

        # Bidirectional LSTM; each direction gets half of nout_lstm hidden units.
        self.lstm = nn.LSTM(input_size=nin_lstm, hidden_size=nout_lstm // 2, bidirectional=True)

        # Project the LSTM features back to nin_lstm bins, with normalization.
        self.dense = nn.Sequential(nn.Linear(nout_lstm, nin_lstm), nn.BatchNorm1d(nin_lstm), nn.ReLU())

    def forward(self, input_tensor):
        """Run the conv + BiLSTM + dense pipeline and restore the input layout."""
        batch, _, nbins, nframes = input_tensor.size()

        # Squeeze channels and reorder to (time, batch, features) for the LSTM.
        sequence = self.conv(input_tensor)[:, 0]  # batch, nbins, nframes
        sequence = sequence.permute(2, 0, 1)  # nframes, batch, nbins
        sequence, _ = self.lstm(sequence)

        # Project every timestep, then reshape back to (batch, 1, nbins, nframes).
        projected = self.dense(sequence.reshape(-1, sequence.size()[-1]))  # nframes * batch, nbins
        projected = projected.reshape(nframes, batch, 1, nbins)
        projected = projected.permute(1, 2, 3, 0)

        return projected
|
audio_separator/separator/uvr_lib_v5/vr_network/model_param_init.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json

# Baseline parameter values; -1 marks fields that must be supplied by the model config.
default_param = {
    "bins": -1,
    "unstable_bins": -1,  # training only
    "stable_bins": -1,  # training only
    "sr": 44100,
    "pre_filter_start": -1,
    "pre_filter_stop": -1,
    "band": {},
}

# JSON key whose presence overrides the "bins" setting.
N_BINS = "n_bins"
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def int_keys(d):
    """
    Convert digit-string keys in a sequence of (key, value) pairs into int keys.

    Intended for use as a ``json.loads`` ``object_pairs_hook``: JSON objects can
    only have string keys, so keys that were originally integers come back as
    strings like "1". This restores such keys to ints, leaving other keys as-is,
    so the data can be used consistently with its original representation.

    Args:
        d (list of tuples): (key, value) pairs whose keys are strings that may
            represent integers.

    Returns:
        dict: A dictionary with digit-only keys converted to integers.
    """
    # str.isdigit() is only True for all-digit strings, so int() cannot fail here.
    return {int(key) if key.isdigit() else key: value for key, value in d}
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class ModelParameters(object):
    """
    A class to manage model parameters, including loading from a configuration file.

    Attributes:
        param (dict): Dictionary holding all parameters for the model.
    """

    def __init__(self, config_path=""):
        """
        Initializes the ModelParameters object by loading parameters from a JSON configuration file.

        Args:
            config_path (str): Path to the JSON configuration file.

        Raises:
            OSError: If the configuration file cannot be opened.
            json.JSONDecodeError: If the file is not valid JSON.
        """

        # Load parameters from the given configuration file path.
        # Explicit UTF-8 avoids depending on the platform default encoding;
        # int_keys restores integer band indices that JSON stores as strings.
        with open(config_path, "r", encoding="utf-8") as f:
            self.param = json.load(f, object_pairs_hook=int_keys)

        # Ensure certain parameters are set to False if not specified in the configuration.
        for k in ["mid_side", "mid_side_b", "mid_side_b2", "stereo_w", "stereo_n", "reverse"]:
            if k not in self.param:
                self.param[k] = False

        # If 'n_bins' is specified in the parameters, it's used as the value for 'bins'.
        if N_BINS in self.param:
            self.param["bins"] = self.param[N_BINS]
|
audio_separator/separator/uvr_lib_v5/vr_network/modelparams/1band_sr16000_hl512.json
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"bins": 1024,
|
| 3 |
+
"unstable_bins": 0,
|
| 4 |
+
"reduction_bins": 0,
|
| 5 |
+
"band": {
|
| 6 |
+
"1": {
|
| 7 |
+
"sr": 16000,
|
| 8 |
+
"hl": 512,
|
| 9 |
+
"n_fft": 2048,
|
| 10 |
+
"crop_start": 0,
|
| 11 |
+
"crop_stop": 1024,
|
| 12 |
+
"hpf_start": -1,
|
| 13 |
+
"res_type": "sinc_best"
|
| 14 |
+
}
|
| 15 |
+
},
|
| 16 |
+
"sr": 16000,
|
| 17 |
+
"pre_filter_start": 1023,
|
| 18 |
+
"pre_filter_stop": 1024
|
| 19 |
+
}
|
audio_separator/separator/uvr_lib_v5/vr_network/modelparams/1band_sr32000_hl512.json
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"bins": 1024,
|
| 3 |
+
"unstable_bins": 0,
|
| 4 |
+
"reduction_bins": 0,
|
| 5 |
+
"band": {
|
| 6 |
+
"1": {
|
| 7 |
+
"sr": 32000,
|
| 8 |
+
"hl": 512,
|
| 9 |
+
"n_fft": 2048,
|
| 10 |
+
"crop_start": 0,
|
| 11 |
+
"crop_stop": 1024,
|
| 12 |
+
"hpf_start": -1,
|
| 13 |
+
"res_type": "kaiser_fast"
|
| 14 |
+
}
|
| 15 |
+
},
|
| 16 |
+
"sr": 32000,
|
| 17 |
+
"pre_filter_start": 1000,
|
| 18 |
+
"pre_filter_stop": 1021
|
| 19 |
+
}
|
audio_separator/separator/uvr_lib_v5/vr_network/modelparams/1band_sr33075_hl384.json
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"bins": 1024,
|
| 3 |
+
"unstable_bins": 0,
|
| 4 |
+
"reduction_bins": 0,
|
| 5 |
+
"band": {
|
| 6 |
+
"1": {
|
| 7 |
+
"sr": 33075,
|
| 8 |
+
"hl": 384,
|
| 9 |
+
"n_fft": 2048,
|
| 10 |
+
"crop_start": 0,
|
| 11 |
+
"crop_stop": 1024,
|
| 12 |
+
"hpf_start": -1,
|
| 13 |
+
"res_type": "sinc_best"
|
| 14 |
+
}
|
| 15 |
+
},
|
| 16 |
+
"sr": 33075,
|
| 17 |
+
"pre_filter_start": 1000,
|
| 18 |
+
"pre_filter_stop": 1021
|
| 19 |
+
}
|