Spaces:
Running
Running
Upload 288 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .env +8 -0
- .gitattributes +1 -0
- .gitignore +49 -0
- Applio_(Mangio_RVC_Fork).ipynb +169 -0
- Dockerfile +29 -0
- Fixes/local_fixes.py +136 -0
- Fixes/tensor-launch.py +15 -0
- LICENSE +59 -0
- LazyImport.py +13 -0
- MDX-Net_Colab.ipynb +524 -0
- MDXNet.py +272 -0
- Makefile +63 -0
- README.md +222 -12
- assets/hubert/.gitignore +2 -0
- assets/pretrained/.gitignore +2 -0
- assets/pretrained_v2/.gitignore +2 -0
- assets/rmvpe/.gitignore +2 -0
- assets/uvr5_weights/.gitignore +2 -0
- assets/weights/.gitignore +2 -0
- audioEffects.py +37 -0
- audios/.gitignore +0 -0
- colab_for_mdx.py +71 -0
- configs/32k.json +50 -0
- configs/32k_v2.json +50 -0
- configs/40k.json +50 -0
- configs/48k.json +50 -0
- configs/48k_v2.json +50 -0
- configs/config.json +15 -0
- configs/config.py +265 -0
- configs/v1/32k.json +46 -0
- configs/v1/40k.json +46 -0
- configs/v1/48k.json +46 -0
- configs/v2/32k.json +46 -0
- configs/v2/48k.json +46 -0
- csvdb/formanting.csv +0 -0
- csvdb/stop.csv +0 -0
- demucs/__init__.py +7 -0
- demucs/__main__.py +317 -0
- demucs/audio.py +172 -0
- demucs/augment.py +106 -0
- demucs/compressed.py +115 -0
- demucs/model.py +202 -0
- demucs/parser.py +244 -0
- demucs/pretrained.py +107 -0
- demucs/raw.py +173 -0
- demucs/repitch.py +96 -0
- demucs/separate.py +185 -0
- demucs/tasnet.py +452 -0
- demucs/test.py +109 -0
- demucs/train.py +127 -0
.env
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
OPENBLAS_NUM_THREADS = 1
|
2 |
+
no_proxy = localhost, 127.0.0.1, ::1
|
3 |
+
|
4 |
+
# You can change the location of the model, etc. by changing here
|
5 |
+
weight_root = weights
|
6 |
+
weight_uvr5_root = uvr5_weights
|
7 |
+
index_root = logs
|
8 |
+
rmvpe_root = assets/rmvpe
|
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
stftpitchshift filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
.DS_Store
|
2 |
+
__pycache__
|
3 |
+
/TEMP
|
4 |
+
/DATASETS
|
5 |
+
/RUNTIME
|
6 |
+
*.pyd
|
7 |
+
hubert_base.pt
|
8 |
+
.venv
|
9 |
+
alexforkINSTALL.bat
|
10 |
+
Changelog_CN.md
|
11 |
+
Changelog_EN.md
|
12 |
+
Changelog_KO.md
|
13 |
+
difdep.py
|
14 |
+
EasierGUI.py
|
15 |
+
envfilescheck.bat
|
16 |
+
export_onnx.py
|
17 |
+
.vscode/
|
18 |
+
export_onnx_old.py
|
19 |
+
ffmpeg.exe
|
20 |
+
ffprobe.exe
|
21 |
+
Fixes/Launch_Tensorboard.bat
|
22 |
+
Fixes/LOCAL_CREPE_FIX.bat
|
23 |
+
Fixes/local_fixes.py
|
24 |
+
Fixes/tensor-launch.py
|
25 |
+
gui.py
|
26 |
+
infer-web — backup.py
|
27 |
+
infer-webbackup.py
|
28 |
+
install_easy_dependencies.py
|
29 |
+
install_easyGUI.bat
|
30 |
+
installstft.bat
|
31 |
+
Launch_Tensorboard.bat
|
32 |
+
listdepend.bat
|
33 |
+
LOCAL_CREPE_FIX.bat
|
34 |
+
local_fixes.py
|
35 |
+
oldinfer.py
|
36 |
+
onnx_inference_demo.py
|
37 |
+
Praat.exe
|
38 |
+
requirementsNEW.txt
|
39 |
+
rmvpe.pt
|
40 |
+
rmvpe.onnx
|
41 |
+
run_easiergui.bat
|
42 |
+
tensor-launch.py
|
43 |
+
values1.json
|
44 |
+
使用需遵守的协议-LICENSE.txt
|
45 |
+
!logs/
|
46 |
+
|
47 |
+
logs/*
|
48 |
+
logs/mute/0_gt_wavs/mute40k.spec.pt
|
49 |
+
!logs/mute/
|
Applio_(Mangio_RVC_Fork).ipynb
ADDED
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": null,
|
6 |
+
"metadata": {
|
7 |
+
"cellView": "form",
|
8 |
+
"id": "izLwNF_8T1TK"
|
9 |
+
},
|
10 |
+
"outputs": [],
|
11 |
+
"source": [
|
12 |
+
"#@title <font color='#06ae56'>**🍏 Applio (Mangio-RVC-Fork)**</font>\n",
|
13 |
+
"import time\n",
|
14 |
+
"import os\n",
|
15 |
+
"import subprocess\n",
|
16 |
+
"import shutil\n",
|
17 |
+
"import threading\n",
|
18 |
+
"import base64\n",
|
19 |
+
"import threading\n",
|
20 |
+
"import time\n",
|
21 |
+
"from IPython.display import HTML, clear_output\n",
|
22 |
+
"\n",
|
23 |
+
"nosv_name1 = base64.b64decode(('ZXh0ZXJuYWxj').encode('ascii')).decode('ascii')\n",
|
24 |
+
"nosv_name2 = base64.b64decode(('b2xhYmNvZGU=').encode('ascii')).decode('ascii')\n",
|
25 |
+
"guebui = base64.b64decode(('V2U=').encode('ascii')).decode('ascii')\n",
|
26 |
+
"guebui2 = base64.b64decode(('YlVJ').encode('ascii')).decode('ascii')\n",
|
27 |
+
"pbestm = base64.b64decode(('cm12cGU=').encode('ascii')).decode('ascii')\n",
|
28 |
+
"tryre = base64.b64decode(('UmV0cmlldmFs').encode('ascii')).decode('ascii')\n",
|
29 |
+
"\n",
|
30 |
+
"xdsame = '/content/'+ tryre +'-based-Voice-Conversion-' + guebui + guebui2 +'/'\n",
|
31 |
+
"\n",
|
32 |
+
"collapsible_section = \"\"\"\n",
|
33 |
+
"<br>\n",
|
34 |
+
"<br>\n",
|
35 |
+
"<details style=\"border: 1px solid #ddd; border-radius: 5px; padding: 10px; margin-bottom: 10px;\">\n",
|
36 |
+
" <summary open style=\"font-weight: bold; cursor: pointer;\">🚀 Click to learn more about Applio</summary>\n",
|
37 |
+
" <div style=\"margin-left: 20px;\">\n",
|
38 |
+
" <ul>\n",
|
39 |
+
" <li><a href=\"https://github.com/Mangio621/Mangio-RVC-Fork\" style=\"color: #06ae56;\">Mangio-RVC-Fork</a> - Source of inspiration and base for this improved code, special thanks to the developers.</li>\n",
|
40 |
+
" <li><a href=\"https://github.com/Anjok07/ultimatevocalremovergui\" style=\"color: #06ae56;\">UltimateVocalRemover</a> - Used for voice and instrument separation.</li>\n",
|
41 |
+
" <li>Vidal, Blaise & Aitron - Contributors to the Applio version.</li>\n",
|
42 |
+
" <li>kalomaze - Creator of external scripts that help the functioning of Applio.</li>\n",
|
43 |
+
" </ul>\n",
|
44 |
+
" <p style=\"color: #fff;\">Join and contribute to the project on <a href=\"https://github.com/IAHispano/Applio-RVC-Fork\" style=\"color: #06ae56;\">our GitHub repository</a>.</p>\n",
|
45 |
+
" </div>\n",
|
46 |
+
"</details>\n",
|
47 |
+
"<br>\n",
|
48 |
+
"<button style=\"font-weight: bold; cursor: pointer; background-color: #06ae56; color: white; border: 1px solid #fff; border-radius: 4px; padding: 10px 20px; text-decoration: none;\" onclick=\"window.open('https://discord.gg/IAHispano', '_blank')\">🍏 Join our support Discord server (IA Hispano)</button>\n",
|
49 |
+
"<br>\n",
|
50 |
+
"<br>\n",
|
51 |
+
"\"\"\"\n",
|
52 |
+
"#@markdown **Settings:**\n",
|
53 |
+
"ForceUpdateDependencies = True\n",
|
54 |
+
"ForceNoMountDrive = False\n",
|
55 |
+
"#@markdown Restore your backup from Google Drive.\n",
|
56 |
+
"LoadBackupDrive = False #@param{type:\"boolean\"}\n",
|
57 |
+
"#@markdown Make regular backups of your model's training.\n",
|
58 |
+
"AutoBackups = True #@param{type:\"boolean\"}\n",
|
59 |
+
"if not os.path.exists(xdsame):\n",
|
60 |
+
" current_path = os.getcwd()\n",
|
61 |
+
" shutil.rmtree('/content/')\n",
|
62 |
+
" os.makedirs('/content/', exist_ok=True)\n",
|
63 |
+
"\n",
|
64 |
+
" os.chdir(current_path)\n",
|
65 |
+
" !git clone https://github.com/IAHispano/$nosv_name1$nosv_name2 /content/$tryre-based-Voice-Conversion-$guebui$guebui2/utils\n",
|
66 |
+
" clear_output()\n",
|
67 |
+
"\n",
|
68 |
+
" os.chdir(xdsame)\n",
|
69 |
+
" from utils.dependency import *\n",
|
70 |
+
" from utils.clonerepo_experimental import *\n",
|
71 |
+
" os.chdir(\"..\")\n",
|
72 |
+
"\n",
|
73 |
+
"\n",
|
74 |
+
"\n",
|
75 |
+
" setup_environment(ForceUpdateDependencies, ForceNoMountDrive)\n",
|
76 |
+
" clone_repository(True)\n",
|
77 |
+
"\n",
|
78 |
+
" !wget https://huggingface.co/lj1995/VoiceConversion$guebui$guebui2/resolve/main/rmvpe.pt -P /content/Retrieval-based-Voice-Conversion-$guebui$guebui2/\n",
|
79 |
+
" clear_output()\n",
|
80 |
+
"\n",
|
81 |
+
"base_path = \"/content/Retrieval-based-Voice-Conversion-$guebui$guebui2/\"\n",
|
82 |
+
"clear_output()\n",
|
83 |
+
"\n",
|
84 |
+
"\n",
|
85 |
+
"\n",
|
86 |
+
"from utils import backups\n",
|
87 |
+
"\n",
|
88 |
+
"LOGS_FOLDER = xdsame + '/logs'\n",
|
89 |
+
"if not os.path.exists(LOGS_FOLDER):\n",
|
90 |
+
" os.makedirs(LOGS_FOLDER)\n",
|
91 |
+
" clear_output()\n",
|
92 |
+
"\n",
|
93 |
+
"WEIGHTS_FOLDER = xdsame + '/logs' + '/weights'\n",
|
94 |
+
"if not os.path.exists(WEIGHTS_FOLDER):\n",
|
95 |
+
" os.makedirs(WEIGHTS_FOLDER)\n",
|
96 |
+
" clear_output()\n",
|
97 |
+
"\n",
|
98 |
+
"others_FOLDER = xdsame + '/audio-others'\n",
|
99 |
+
"if not os.path.exists(others_FOLDER):\n",
|
100 |
+
" os.makedirs(others_FOLDER)\n",
|
101 |
+
" clear_output()\n",
|
102 |
+
"\n",
|
103 |
+
"audio_outputs_FOLDER = xdsame + '/audio-outputs'\n",
|
104 |
+
"if not os.path.exists(audio_outputs_FOLDER):\n",
|
105 |
+
" os.makedirs(audio_outputs_FOLDER)\n",
|
106 |
+
" clear_output()\n",
|
107 |
+
"\n",
|
108 |
+
"if LoadBackupDrive:\n",
|
109 |
+
" backups.import_google_drive_backup()\n",
|
110 |
+
" clear_output()\n",
|
111 |
+
"\n",
|
112 |
+
"#@markdown Choose the language in which you want the interface to be available.\n",
|
113 |
+
"i18n_path = xdsame + 'i18n.py'\n",
|
114 |
+
"i18n_new_path = xdsame + 'utils/i18n.py'\n",
|
115 |
+
"try:\n",
|
116 |
+
" if os.path.exists(i18n_path) and os.path.exists(i18n_new_path):\n",
|
117 |
+
" shutil.move(i18n_new_path, i18n_path)\n",
|
118 |
+
"\n",
|
119 |
+
" SelectedLanguage = \"en_US\" #@param [\"es_ES\", \"en_US\", \"zh_CN\", \"ar_AR\", \"id_ID\", \"pt_PT\", \"ru_RU\", \"ur_UR\", \"tr_TR\", \"it_IT\", \"de_DE\"]\n",
|
120 |
+
" new_language_line = ' language = \"' + SelectedLanguage + '\"\\n'\n",
|
121 |
+
"#@markdown <a href=\"https://discord.gg/iahispano\"><font>If you need more help, feel free to join our official Discord server!</font></a>\n",
|
122 |
+
" with open(i18n_path, 'r') as file:\n",
|
123 |
+
" lines = file.readlines()\n",
|
124 |
+
"\n",
|
125 |
+
" with open(i18n_path, 'w') as file:\n",
|
126 |
+
" for index, line in enumerate(lines):\n",
|
127 |
+
" if index == 14:\n",
|
128 |
+
" file.write(new_language_line)\n",
|
129 |
+
" else:\n",
|
130 |
+
" file.write(line)\n",
|
131 |
+
"\n",
|
132 |
+
"except FileNotFoundError:\n",
|
133 |
+
" print(\"Translation couldn't be applied successfully. Please restart the environment and run the cell again.\")\n",
|
134 |
+
"\n",
|
135 |
+
"def start_web_server():\n",
|
136 |
+
" %cd /content/$tryre-based-Voice-Conversion-$guebui$guebui2\n",
|
137 |
+
" %load_ext tensorboard\n",
|
138 |
+
" clear_output()\n",
|
139 |
+
" %tensorboard --logdir /content/$tryre-based-Voice-Conversion-$guebui$guebui2/logs\n",
|
140 |
+
" !mkdir -p /content/$tryre-based-Voice-Conversion-$guebui$guebui2/audios\n",
|
141 |
+
" display(HTML(collapsible_section))\n",
|
142 |
+
" !python3 infer-web.py --colab --pycmd python3\n",
|
143 |
+
"\n",
|
144 |
+
"if AutoBackups:\n",
|
145 |
+
" web_server_thread = threading.Thread(target=start_web_server)\n",
|
146 |
+
" web_server_thread.start()\n",
|
147 |
+
" backups.backup_files()\n",
|
148 |
+
"\n",
|
149 |
+
"else:\n",
|
150 |
+
" start_web_server()"
|
151 |
+
]
|
152 |
+
}
|
153 |
+
],
|
154 |
+
"metadata": {
|
155 |
+
"accelerator": "GPU",
|
156 |
+
"colab": {
|
157 |
+
"provenance": []
|
158 |
+
},
|
159 |
+
"kernelspec": {
|
160 |
+
"display_name": "Python 3",
|
161 |
+
"name": "python3"
|
162 |
+
},
|
163 |
+
"language_info": {
|
164 |
+
"name": "python"
|
165 |
+
}
|
166 |
+
},
|
167 |
+
"nbformat": 4,
|
168 |
+
"nbformat_minor": 0
|
169 |
+
}
|
Dockerfile
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# syntax=docker/dockerfile:1
|
2 |
+
|
3 |
+
FROM python:3.10-bullseye
|
4 |
+
|
5 |
+
EXPOSE 7865
|
6 |
+
|
7 |
+
WORKDIR /app
|
8 |
+
|
9 |
+
COPY . .
|
10 |
+
|
11 |
+
RUN apt update && apt install -y -qq ffmpeg aria2 && apt clean
|
12 |
+
|
13 |
+
RUN pip3 install --no-cache-dir -r requirements.txt
|
14 |
+
|
15 |
+
RUN aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/pretrained_v2/D40k.pth -d assets/pretrained_v2/ -o D40k.pth
|
16 |
+
RUN aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/pretrained_v2/G40k.pth -d assets/pretrained_v2/ -o G40k.pth
|
17 |
+
RUN aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/pretrained_v2/f0D40k.pth -d assets/pretrained_v2/ -o f0D40k.pth
|
18 |
+
RUN aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/pretrained_v2/f0G40k.pth -d assets/pretrained_v2/ -o f0G40k.pth
|
19 |
+
|
20 |
+
RUN aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/uvr5_weights/HP2-人声vocals+非人声instrumentals.pth -d assets/uvr5_weights/ -o HP2-人声vocals+非人声instrumentals.pth
|
21 |
+
RUN aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/uvr5_weights/HP5-主旋律人声vocals+其他instrumentals.pth -d assets/uvr5_weights/ -o HP5-主旋律人声vocals+其他instrumentals.pth
|
22 |
+
|
23 |
+
RUN aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/hubert_base.pt -d assets/hubert -o hubert_base.pt
|
24 |
+
|
25 |
+
RUN aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/rmvpe.pt -d assets/hubert -o rmvpe.pt
|
26 |
+
|
27 |
+
VOLUME [ "/app/weights", "/app/opt" ]
|
28 |
+
|
29 |
+
CMD ["python3", "infer-web.py"]
|
Fixes/local_fixes.py
ADDED
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import time
|
4 |
+
import shutil
|
5 |
+
import requests
|
6 |
+
import zipfile
|
7 |
+
|
8 |
+
def insert_new_line(file_name, line_to_find, text_to_insert):
|
9 |
+
lines = []
|
10 |
+
with open(file_name, 'r', encoding='utf-8') as read_obj:
|
11 |
+
lines = read_obj.readlines()
|
12 |
+
already_exists = False
|
13 |
+
with open(file_name + '.tmp', 'w', encoding='utf-8') as write_obj:
|
14 |
+
for i in range(len(lines)):
|
15 |
+
write_obj.write(lines[i])
|
16 |
+
if lines[i].strip() == line_to_find:
|
17 |
+
# If next line exists and starts with sys.path.append, skip
|
18 |
+
if i+1 < len(lines) and lines[i+1].strip().startswith("sys.path.append"):
|
19 |
+
print('It was already fixed! Skip adding a line...')
|
20 |
+
already_exists = True
|
21 |
+
break
|
22 |
+
else:
|
23 |
+
write_obj.write(text_to_insert + '\n')
|
24 |
+
# If no existing sys.path.append line was found, replace the original file
|
25 |
+
if not already_exists:
|
26 |
+
os.replace(file_name + '.tmp', file_name)
|
27 |
+
return True
|
28 |
+
else:
|
29 |
+
# If existing line was found, delete temporary file
|
30 |
+
os.remove(file_name + '.tmp')
|
31 |
+
return False
|
32 |
+
|
33 |
+
def replace_in_file(file_name, old_text, new_text):
|
34 |
+
with open(file_name, 'r', encoding='utf-8') as file:
|
35 |
+
file_contents = file.read()
|
36 |
+
|
37 |
+
if old_text in file_contents:
|
38 |
+
file_contents = file_contents.replace(old_text, new_text)
|
39 |
+
with open(file_name, 'w', encoding='utf-8') as file:
|
40 |
+
file.write(file_contents)
|
41 |
+
return True
|
42 |
+
|
43 |
+
return False
|
44 |
+
|
45 |
+
if __name__ == "__main__":
|
46 |
+
current_path = os.getcwd()
|
47 |
+
file_name = os.path.join(current_path, "infer", "modules", "train", "extract", "extract_f0_print.py")
|
48 |
+
line_to_find = 'import numpy as np, logging'
|
49 |
+
text_to_insert = "sys.path.append(r'" + current_path + "')"
|
50 |
+
|
51 |
+
|
52 |
+
success_1 = insert_new_line(file_name, line_to_find, text_to_insert)
|
53 |
+
if success_1:
|
54 |
+
print('The first operation was successful!')
|
55 |
+
else:
|
56 |
+
print('He skipped the first operation because it was already fixed!')
|
57 |
+
|
58 |
+
file_name = 'infer-web.py'
|
59 |
+
old_text = 'with gr.Blocks(theme=gr.themes.Soft()) as app:'
|
60 |
+
new_text = 'with gr.Blocks() as app:'
|
61 |
+
|
62 |
+
success_2 = replace_in_file(file_name, old_text, new_text)
|
63 |
+
if success_2:
|
64 |
+
print('The second operation was successful!')
|
65 |
+
else:
|
66 |
+
print('The second operation was omitted because it was already fixed!')
|
67 |
+
|
68 |
+
print('Local corrections successful! You should now be able to infer and train locally in Applio RVC Fork.')
|
69 |
+
|
70 |
+
time.sleep(5)
|
71 |
+
|
72 |
+
def find_torchcrepe_directory(directory):
|
73 |
+
"""
|
74 |
+
Recursively searches for the topmost folder named 'torchcrepe' within a directory.
|
75 |
+
Returns the path of the directory found or None if none is found.
|
76 |
+
"""
|
77 |
+
for root, dirs, files in os.walk(directory):
|
78 |
+
if 'torchcrepe' in dirs:
|
79 |
+
return os.path.join(root, 'torchcrepe')
|
80 |
+
return None
|
81 |
+
|
82 |
+
def download_and_extract_torchcrepe():
|
83 |
+
url = 'https://github.com/maxrmorrison/torchcrepe/archive/refs/heads/master.zip'
|
84 |
+
temp_dir = 'temp_torchcrepe'
|
85 |
+
destination_dir = os.getcwd()
|
86 |
+
|
87 |
+
try:
|
88 |
+
torchcrepe_dir_path = os.path.join(destination_dir, 'torchcrepe')
|
89 |
+
|
90 |
+
if os.path.exists(torchcrepe_dir_path):
|
91 |
+
print("Skipping the torchcrepe download. The folder already exists.")
|
92 |
+
return
|
93 |
+
|
94 |
+
# Download the file
|
95 |
+
print("Starting torchcrepe download...")
|
96 |
+
response = requests.get(url)
|
97 |
+
|
98 |
+
# Raise an error if the GET request was unsuccessful
|
99 |
+
response.raise_for_status()
|
100 |
+
print("Download completed.")
|
101 |
+
|
102 |
+
# Save the downloaded file
|
103 |
+
zip_file_path = os.path.join(temp_dir, 'master.zip')
|
104 |
+
os.makedirs(temp_dir, exist_ok=True)
|
105 |
+
with open(zip_file_path, 'wb') as file:
|
106 |
+
file.write(response.content)
|
107 |
+
print(f"Zip file saved to {zip_file_path}")
|
108 |
+
|
109 |
+
# Extract the zip file
|
110 |
+
print("Extracting content...")
|
111 |
+
with zipfile.ZipFile(zip_file_path, 'r') as zip_file:
|
112 |
+
zip_file.extractall(temp_dir)
|
113 |
+
print("Extraction completed.")
|
114 |
+
|
115 |
+
# Locate the torchcrepe folder and move it to the destination directory
|
116 |
+
torchcrepe_dir = find_torchcrepe_directory(temp_dir)
|
117 |
+
if torchcrepe_dir:
|
118 |
+
shutil.move(torchcrepe_dir, destination_dir)
|
119 |
+
print(f"Moved the torchcrepe directory to {destination_dir}!")
|
120 |
+
else:
|
121 |
+
print("The torchcrepe directory could not be located.")
|
122 |
+
|
123 |
+
except Exception as e:
|
124 |
+
print("Torchcrepe not successfully downloaded", e)
|
125 |
+
|
126 |
+
# Clean up temporary directory
|
127 |
+
if os.path.exists(temp_dir):
|
128 |
+
shutil.rmtree(temp_dir)
|
129 |
+
|
130 |
+
# Run the function
|
131 |
+
download_and_extract_torchcrepe()
|
132 |
+
|
133 |
+
temp_dir = 'temp_torchcrepe'
|
134 |
+
|
135 |
+
if os.path.exists(temp_dir):
|
136 |
+
shutil.rmtree(temp_dir)
|
Fixes/tensor-launch.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import threading
|
2 |
+
import time
|
3 |
+
from tensorboard import program
|
4 |
+
import os
|
5 |
+
|
6 |
+
log_path = "logs"
|
7 |
+
|
8 |
+
if __name__ == "__main__":
|
9 |
+
tb = program.TensorBoard()
|
10 |
+
tb.configure(argv=[None, '--logdir', log_path])
|
11 |
+
url = tb.launch()
|
12 |
+
print(f'Tensorboard can be accessed at: {url}')
|
13 |
+
|
14 |
+
while True:
|
15 |
+
time.sleep(600) # Keep the main thread running
|
LICENSE
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2023 liujing04
|
4 |
+
Copyright (c) 2023 源文雨
|
5 |
+
|
6 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
7 |
+
of this software and associated documentation files (the "Software"), to deal
|
8 |
+
in the Software without restriction, including without limitation the rights
|
9 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
10 |
+
copies of the Software, and to permit persons to whom the Software is
|
11 |
+
furnished to do so, subject to the following conditions:
|
12 |
+
|
13 |
+
The above copyright notice and this permission notice shall be included in all
|
14 |
+
copies or substantial portions of the Software.
|
15 |
+
|
16 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
17 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
18 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
19 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
20 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
21 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
22 |
+
SOFTWARE.
|
23 |
+
|
24 |
+
The licenses for related libraries are as follows:
|
25 |
+
|
26 |
+
ContentVec
|
27 |
+
https://github.com/auspicious3000/contentvec/blob/main/LICENSE
|
28 |
+
MIT License
|
29 |
+
|
30 |
+
VITS
|
31 |
+
https://github.com/jaywalnut310/vits/blob/main/LICENSE
|
32 |
+
MIT License
|
33 |
+
|
34 |
+
HIFIGAN
|
35 |
+
https://github.com/jik876/hifi-gan/blob/master/LICENSE
|
36 |
+
MIT License
|
37 |
+
|
38 |
+
gradio
|
39 |
+
https://github.com/gradio-app/gradio/blob/main/LICENSE
|
40 |
+
Apache License 2.0
|
41 |
+
|
42 |
+
ffmpeg
|
43 |
+
https://github.com/FFmpeg/FFmpeg/blob/master/COPYING.LGPLv3
|
44 |
+
https://github.com/BtbN/FFmpeg-Builds/releases/download/autobuild-2021-02-28-12-32/ffmpeg-n4.3.2-160-gfbb9368226-win64-lgpl-4.3.zip
|
45 |
+
LPGLv3 License
|
46 |
+
MIT License
|
47 |
+
|
48 |
+
ultimatevocalremovergui
|
49 |
+
https://github.com/Anjok07/ultimatevocalremovergui/blob/master/LICENSE
|
50 |
+
https://github.com/yang123qwe/vocal_separation_by_uvr5
|
51 |
+
MIT License
|
52 |
+
|
53 |
+
audio-slicer
|
54 |
+
https://github.com/openvpi/audio-slicer/blob/main/LICENSE
|
55 |
+
MIT License
|
56 |
+
|
57 |
+
PySimpleGUI
|
58 |
+
https://github.com/PySimpleGUI/PySimpleGUI/blob/master/license.txt
|
59 |
+
LPGLv3 License
|
LazyImport.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from importlib.util import find_spec, LazyLoader, module_from_spec
|
2 |
+
from sys import modules
|
3 |
+
|
4 |
+
def lazyload(name):
|
5 |
+
if name in modules:
|
6 |
+
return modules[name]
|
7 |
+
else:
|
8 |
+
spec = find_spec(name)
|
9 |
+
loader = LazyLoader(spec.loader)
|
10 |
+
module = module_from_spec(spec)
|
11 |
+
modules[name] = module
|
12 |
+
loader.exec_module(module)
|
13 |
+
return module
|
MDX-Net_Colab.ipynb
ADDED
@@ -0,0 +1,524 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"metadata": {
|
6 |
+
"id": "wX9xzLur4tus"
|
7 |
+
},
|
8 |
+
"source": [
|
9 |
+
"# MDX-Net Colab\n",
|
10 |
+
"<div style=\"display:flex; align-items:center; font-size: 16px;\">\n",
|
11 |
+
" <img src=\"https://github.githubassets.com/pinned-octocat.svg\" alt=\"icon1\" style=\"margin-right:10px; height: 20px;\" width=\"1.5%\">\n",
|
12 |
+
" <span>Trained models provided in this notebook are from <a href=\"https://github.com/Anjok07\">UVR-GUI</a>.</span>\n",
|
13 |
+
"</div>\n",
|
14 |
+
"<div style=\"display:flex; align-items:center; font-size: 16px;\">\n",
|
15 |
+
" <img src=\"https://github.com/Anjok07/ultimatevocalremovergui/raw/master/gui_data/img/GUI-Icon.ico\" alt=\"icon2\" style=\"margin-right:10px; height: 20px;margin-top:10px\" width=\"1.5%\">\n",
|
16 |
+
" <span>OFFICIAL UVR GITHUB PAGE: <a href=\"https://github.com/Anjok07/ultimatevocalremovergui\">here</a>.</span>\n",
|
17 |
+
"</div>\n",
|
18 |
+
"<div style=\"display:flex; align-items:center; font-size: 16px;\">\n",
|
19 |
+
" <img src=\"https://avatars.githubusercontent.com/u/24620594\" alt=\"icon3\" style=\"margin-right:10px; height: 20px;\" width=\"1.5%\">\n",
|
20 |
+
" <span>OFFICIAL CLI Version: <a href=\"https://github.com/tsurumeso/vocal-remover\">here</a>.</span>\n",
|
21 |
+
"</div>\n",
|
22 |
+
"<div style=\"display:flex; align-items:center; font-size: 16px;\">\n",
|
23 |
+
" <img src=\"https://icons.getbootstrap.com/assets/icons/discord.svg\" alt=\"icon4\" style=\"margin-right:10px; height: 20px;\" width=\"1.5%\">\n",
|
24 |
+
" <span>Join our <a href=\"https://cutt.ly/0TcDjmo\">Discord server</a>!</span>\n",
|
25 |
+
"</div>\n",
|
26 |
+
"<sup><br>Ultimate Vocal Remover (unofficial)</sup>\n",
|
27 |
+
"<sup><br>MDX-Net by <a href=\"https://github.com/kuielab\">kuielab</a> and adapted for Colaboratory by <a href=\"https://www.youtube.com/channel/UC0NiSV1jLMH-9E09wiDVFYw\">AudioHacker</a>.</sup>\n",
|
28 |
+
"\n",
|
29 |
+
"<sup><br>Your support means a lot to me. If you enjoy my work, please consider buying me a ko-fi:<br></sup>\n",
|
30 |
+
"[![ko-fi](https://ko-fi.com/img/githubbutton_sm.svg)](https://ko-fi.com/X8X6M8FR0)"
|
31 |
+
]
|
32 |
+
},
|
33 |
+
{
|
34 |
+
"cell_type": "code",
|
35 |
+
"execution_count": null,
|
36 |
+
"metadata": {
|
37 |
+
"id": "3J69RV7G8ocb",
|
38 |
+
"cellView": "form"
|
39 |
+
},
|
40 |
+
"outputs": [],
|
41 |
+
"source": [
|
42 |
+
"import json\n",
|
43 |
+
"import os\n",
|
44 |
+
"import os.path\n",
|
45 |
+
"import gc\n",
|
46 |
+
"import psutil\n",
|
47 |
+
"import requests\n",
|
48 |
+
"import subprocess\n",
|
49 |
+
"import glob\n",
|
50 |
+
"import time\n",
|
51 |
+
"import logging\n",
|
52 |
+
"import sys\n",
|
53 |
+
"from bs4 import BeautifulSoup\n",
|
54 |
+
"from google.colab import drive, files, output\n",
|
55 |
+
"from IPython.display import Audio, display\n",
|
56 |
+
"\n",
|
57 |
+
"if \"first_cell_ran\" in locals():\n",
|
58 |
+
" print(\"You've ran this cell for this session. No need to run it again.\\nif you think something went wrong or you want to change mounting path, restart the runtime.\")\n",
|
59 |
+
"else:\n",
|
60 |
+
" print('Setting up... please wait around 1-2 minute(s).')\n",
|
61 |
+
"\n",
|
62 |
+
" branch = \"https://github.com/NaJeongMo/Colab-for-MDX_B\"\n",
|
63 |
+
"\n",
|
64 |
+
" model_params = \"https://raw.githubusercontent.com/TRvlvr/application_data/main/mdx_model_data/model_data.json\"\n",
|
65 |
+
" _Models = \"https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/\"\n",
|
66 |
+
" # _models = \"https://pastebin.com/raw/jBzYB8vz\"\n",
|
67 |
+
" _models = \"https://raw.githubusercontent.com/TRvlvr/application_data/main/filelists/download_checks.json\"\n",
|
68 |
+
" stem_naming = \"https://pastebin.com/raw/mpH4hRcF\"\n",
|
69 |
+
" arl_check_endpoint = 'https://dz.doubledouble.top/check' # param: arl?=<>\n",
|
70 |
+
"\n",
|
71 |
+
" file_folder = \"Colab-for-MDX_B\"\n",
|
72 |
+
"\n",
|
73 |
+
" model_ids = requests.get(_models).json()\n",
|
74 |
+
" model_ids = model_ids[\"mdx_download_list\"].values()\n",
|
75 |
+
"\n",
|
76 |
+
" model_params = requests.get(model_params).json()\n",
|
77 |
+
" stem_naming = requests.get(stem_naming).json()\n",
|
78 |
+
"\n",
|
79 |
+
" os.makedirs(\"tmp_models\", exist_ok=True)\n",
|
80 |
+
"\n",
|
81 |
+
" # @markdown If you don't wish to mount google drive, uncheck this box.\n",
|
82 |
+
" MountDrive = True # @param{type:\"boolean\"}\n",
|
83 |
+
" # @markdown The path for the drive to be mounted: Please be cautious when modifying this as it can cause issues if not done properly.\n",
|
84 |
+
" mounting_path = \"/content/drive/MyDrive\" # @param [\"snippets:\",\"/content/drive/MyDrive\",\"/content/drive/Shareddrives/<your shared drive name>\", \"/content/drive/Shareddrives/Shared Drive\"]{allow-input: true}\n",
|
85 |
+
" # @markdown Force update and disregard local changes: discards all local modifications in your repository, effectively replacing all files with the versions from the original commit.\n",
|
86 |
+
" force_update = False # @param{type:\"boolean\"}\n",
|
87 |
+
" # @markdown Auto Update (does not discard your changes)\n",
|
88 |
+
" auto_update = True # @param{type:\"boolean\"}\n",
|
89 |
+
"\n",
|
90 |
+
"\n",
|
91 |
+
" reqs_apt = [] # !sudo apt-get install\n",
|
92 |
+
" reqs_pip = [\"librosa>=0.6.3,<0.9\", \"onnxruntime_gpu\", \"deemix\", \"yt_dlp\"] # pip3 install\n",
|
93 |
+
"\n",
|
94 |
+
" class hide_opt: # hide outputs\n",
|
95 |
+
" def __enter__(self):\n",
|
96 |
+
" self._original_stdout = sys.stdout\n",
|
97 |
+
" sys.stdout = open(os.devnull, \"w\")\n",
|
98 |
+
"\n",
|
99 |
+
" def __exit__(self, exc_type, exc_val, exc_tb):\n",
|
100 |
+
" sys.stdout.close()\n",
|
101 |
+
" sys.stdout = self._original_stdout\n",
|
102 |
+
"\n",
|
103 |
+
" def get_size(bytes, suffix=\"B\"): # read ram\n",
|
104 |
+
" global svmem\n",
|
105 |
+
" factor = 1024\n",
|
106 |
+
" for unit in [\"\", \"K\", \"M\", \"G\", \"T\", \"P\"]:\n",
|
107 |
+
" if bytes < factor:\n",
|
108 |
+
" return f\"{bytes:.2f}{unit}{suffix}\"\n",
|
109 |
+
" bytes /= factor\n",
|
110 |
+
" svmem = psutil.virtual_memory()\n",
|
111 |
+
"\n",
|
112 |
+
"\n",
|
113 |
+
" print('installing requirements...',end=' ')\n",
|
114 |
+
" with hide_opt():\n",
|
115 |
+
" for x in reqs_apt:\n",
|
116 |
+
" subprocess.run([\"sudo\", \"apt-get\", \"install\", x])\n",
|
117 |
+
" for x in reqs_pip:\n",
|
118 |
+
" subprocess.run([\"python3\", \"-m\", \"pip\", \"install\", x])\n",
|
119 |
+
" print('done')\n",
|
120 |
+
"\n",
|
121 |
+
" def install_or_mount_drive():\n",
|
122 |
+
" print(\n",
|
123 |
+
" \"Please log in to your account by following the prompts in the pop-up tab.\\nThis step is necessary to install the files to your Google Drive.\\nIf you have any concerns about the safety of this notebook, you can choose not to mount your drive by unchecking the \\\"MountDrive\\\" checkbox.\"\n",
|
124 |
+
" )\n",
|
125 |
+
" drive.mount(\"/content/drive\", force_remount=True)\n",
|
126 |
+
" os.chdir(mounting_path)\n",
|
127 |
+
" # check if previous installation is done\n",
|
128 |
+
" if os.path.exists(os.path.join(mounting_path, file_folder)):\n",
|
129 |
+
" # update checking\n",
|
130 |
+
" os.chdir(file_folder)\n",
|
131 |
+
"\n",
|
132 |
+
" if force_update:\n",
|
133 |
+
" print('Force updating...')\n",
|
134 |
+
"\n",
|
135 |
+
" commands = [\n",
|
136 |
+
" [\"git\", \"pull\"],\n",
|
137 |
+
" [\"git\", \"checkout\", \"--\", \".\"],\n",
|
138 |
+
" ]\n",
|
139 |
+
"\n",
|
140 |
+
" for cmd in commands:\n",
|
141 |
+
" subprocess.run(cmd)\n",
|
142 |
+
"\n",
|
143 |
+
" elif auto_update:\n",
|
144 |
+
" print('Checking for updates...')\n",
|
145 |
+
" commands = [\n",
|
146 |
+
" [\"git\", \"pull\"],\n",
|
147 |
+
" ]\n",
|
148 |
+
"\n",
|
149 |
+
" for cmd in commands:\n",
|
150 |
+
" subprocess.run(cmd)\n",
|
151 |
+
" else:\n",
|
152 |
+
" subprocess.run([\"git\", \"clone\", \"https://github.com/NaJeongMo/Colab-for-MDX_B.git\"])\n",
|
153 |
+
" os.chdir(file_folder)\n",
|
154 |
+
"\n",
|
155 |
+
" def use_uvr_without_saving():\n",
|
156 |
+
" global mounting_path\n",
|
157 |
+
" print(\"Notice: files won't be saved to personal drive.\")\n",
|
158 |
+
" print(f\"Downloading {file_folder}...\", end=\" \")\n",
|
159 |
+
" mounting_path = \"/content\"\n",
|
160 |
+
" with hide_opt():\n",
|
161 |
+
" os.chdir(mounting_path)\n",
|
162 |
+
" subprocess.run([\"git\", \"clone\", \"https://github.com/NaJeongMo/Colab-for-MDX_B.git\"])\n",
|
163 |
+
" os.chdir(file_folder)\n",
|
164 |
+
"\n",
|
165 |
+
" if MountDrive:\n",
|
166 |
+
" install_or_mount_drive()\n",
|
167 |
+
" else:\n",
|
168 |
+
" use_uvr_without_saving()\n",
|
169 |
+
" print(\"done!\")\n",
|
170 |
+
" if not os.path.exists(\"tracks\"):\n",
|
171 |
+
" os.mkdir(\"tracks\")\n",
|
172 |
+
"\n",
|
173 |
+
" print('Importing required libraries...',end=' ')\n",
|
174 |
+
"\n",
|
175 |
+
" import os\n",
|
176 |
+
" import mdx\n",
|
177 |
+
" import librosa\n",
|
178 |
+
" import torch\n",
|
179 |
+
" import soundfile as sf\n",
|
180 |
+
" import numpy as np\n",
|
181 |
+
" import yt_dlp\n",
|
182 |
+
"\n",
|
183 |
+
" from deezer import Deezer\n",
|
184 |
+
" from deezer import TrackFormats\n",
|
185 |
+
" import deemix\n",
|
186 |
+
" from deemix.settings import load as loadSettings\n",
|
187 |
+
" from deemix.downloader import Downloader\n",
|
188 |
+
" from deemix import generateDownloadObject\n",
|
189 |
+
"\n",
|
190 |
+
" logger = logging.getLogger(\"yt_dlp\")\n",
|
191 |
+
" logger.setLevel(logging.ERROR)\n",
|
192 |
+
"\n",
|
193 |
+
" def id_to_ptm(mkey):\n",
|
194 |
+
" if mkey in model_ids:\n",
|
195 |
+
" mpath = f\"/content/tmp_models/{mkey}\"\n",
|
196 |
+
" if not os.path.exists(f'/content/tmp_models/{mkey}'):\n",
|
197 |
+
" print('Downloading model...',end=' ')\n",
|
198 |
+
" subprocess.run(\n",
|
199 |
+
" [\"wget\", _Models+mkey, \"-O\", mpath]\n",
|
200 |
+
" )\n",
|
201 |
+
" print(f'saved to {mpath}')\n",
|
202 |
+
" # get_ipython().system(f'gdown {model_id} -O /content/tmp_models/{mkey}')\n",
|
203 |
+
" return mpath\n",
|
204 |
+
" else:\n",
|
205 |
+
" return mpath\n",
|
206 |
+
" else:\n",
|
207 |
+
" mpath = f'models/{mkey}'\n",
|
208 |
+
" return mpath\n",
|
209 |
+
"\n",
|
210 |
+
" def prepare_mdx(custom_param=False, dim_f=None, dim_t=None, n_fft=None, stem_name=None, compensation=None):\n",
|
211 |
+
" device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')\n",
|
212 |
+
" if custom_param:\n",
|
213 |
+
" assert not (dim_f is None or dim_t is None or n_fft is None or compensation is None), 'Custom parameter selected, but incomplete parameters are provided.'\n",
|
214 |
+
" mdx_model = mdx.MDX_Model(\n",
|
215 |
+
" device,\n",
|
216 |
+
" dim_f = dim_f,\n",
|
217 |
+
" dim_t = dim_t,\n",
|
218 |
+
" n_fft = n_fft,\n",
|
219 |
+
" stem_name=stem_name,\n",
|
220 |
+
" compensation=compensation\n",
|
221 |
+
" )\n",
|
222 |
+
" else:\n",
|
223 |
+
" model_hash = mdx.MDX.get_hash(onnx)\n",
|
224 |
+
" if model_hash in model_params:\n",
|
225 |
+
" mp = model_params.get(model_hash)\n",
|
226 |
+
" mdx_model = mdx.MDX_Model(\n",
|
227 |
+
" device,\n",
|
228 |
+
" dim_f = mp[\"mdx_dim_f_set\"],\n",
|
229 |
+
" dim_t = 2**mp[\"mdx_dim_t_set\"],\n",
|
230 |
+
" n_fft = mp[\"mdx_n_fft_scale_set\"],\n",
|
231 |
+
" stem_name=mp[\"primary_stem\"],\n",
|
232 |
+
" compensation=compensation if not custom_param and compensation is not None else mp[\"compensate\"]\n",
|
233 |
+
" )\n",
|
234 |
+
" return mdx_model\n",
|
235 |
+
"\n",
|
236 |
+
" def run_mdx(onnx, mdx_model,filename,diff=False,suffix=None,diff_suffix=None, denoise=False, m_threads=1):\n",
|
237 |
+
" mdx_sess = mdx.MDX(onnx,mdx_model)\n",
|
238 |
+
" print(f\"Processing: {filename}\")\n",
|
239 |
+
" wave, sr = librosa.load(filename,mono=False, sr=44100)\n",
|
240 |
+
" # normalizing input wave gives better output\n",
|
241 |
+
" peak = max(np.max(wave), abs(np.min(wave)))\n",
|
242 |
+
" wave /= peak\n",
|
243 |
+
" if denoise:\n",
|
244 |
+
" wave_processed = -(mdx_sess.process_wave(-wave, m_threads)) + (mdx_sess.process_wave(wave, m_threads))\n",
|
245 |
+
" wave_processed *= 0.5\n",
|
246 |
+
" else:\n",
|
247 |
+
" wave_processed = mdx_sess.process_wave(wave, m_threads)\n",
|
248 |
+
" # return to previous peak\n",
|
249 |
+
" wave_processed *= peak\n",
|
250 |
+
"\n",
|
251 |
+
" stem_name = mdx_model.stem_name if suffix is None else suffix # use suffix if provided\n",
|
252 |
+
" save_path = f\"{os.path.basename(os.path.splitext(filename)[0])}_{stem_name}.wav\"\n",
|
253 |
+
" save_path = os.path.join(\n",
|
254 |
+
" 'separated',\n",
|
255 |
+
" save_path\n",
|
256 |
+
" )\n",
|
257 |
+
" sf.write(\n",
|
258 |
+
" save_path,\n",
|
259 |
+
" wave_processed.T,\n",
|
260 |
+
" sr\n",
|
261 |
+
" )\n",
|
262 |
+
"\n",
|
263 |
+
" print(f'done, saved to: {save_path}')\n",
|
264 |
+
"\n",
|
265 |
+
" if diff:\n",
|
266 |
+
" diff_stem_name = stem_naming.get(stem_name) if diff_suffix is None else diff_suffix # use suffix if provided\n",
|
267 |
+
" stem_name = f\"{stem_name}_diff\" if diff_stem_name is None else diff_stem_name\n",
|
268 |
+
" save_path = f\"{os.path.basename(os.path.splitext(filename)[0])}_{stem_name}.wav\"\n",
|
269 |
+
" save_path = os.path.join(\n",
|
270 |
+
" 'separated',\n",
|
271 |
+
" save_path\n",
|
272 |
+
" )\n",
|
273 |
+
" sf.write(\n",
|
274 |
+
" save_path,\n",
|
275 |
+
" (-wave_processed.T*mdx_model.compensation)+wave.T,\n",
|
276 |
+
" sr\n",
|
277 |
+
" )\n",
|
278 |
+
" print(f'invert done, saved to: {save_path}')\n",
|
279 |
+
" del mdx_sess, wave_processed, wave\n",
|
280 |
+
" gc.collect()\n",
|
281 |
+
"\n",
|
282 |
+
" def is_valid_url(url):\n",
|
283 |
+
" import re\n",
|
284 |
+
" regex = re.compile(\n",
|
285 |
+
" r'^https?://'\n",
|
286 |
+
" r'(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\\.)+[A-Z]{2,6}\\.?|'\n",
|
287 |
+
" r'localhost|'\n",
|
288 |
+
" r'\\d{1,3}\\.\\d{1,3}\\.\\d{1,3}\\.\\d{1,3})'\n",
|
289 |
+
" r'(?::\\d+)?'\n",
|
290 |
+
" r'(?:/?|[/?]\\S+)$', re.IGNORECASE)\n",
|
291 |
+
" return url is not None and regex.search(url)\n",
|
292 |
+
"\n",
|
293 |
+
" def download_deezer(link, arl, fmt='FLAC'):\n",
|
294 |
+
" match fmt:\n",
|
295 |
+
" case 'FLAC':\n",
|
296 |
+
" bitrate = TrackFormats.FLAC\n",
|
297 |
+
" case 'MP3_320':\n",
|
298 |
+
" bitrate = TrackFormats.MP3_320\n",
|
299 |
+
" case 'MP3_128':\n",
|
300 |
+
" bitrate = TrackFormats.MP3_128\n",
|
301 |
+
" case _:\n",
|
302 |
+
" bitrate = TrackFormats.MP3_128\n",
|
303 |
+
"\n",
|
304 |
+
" dz = Deezer()\n",
|
305 |
+
" settings = loadSettings('dz_config')\n",
|
306 |
+
" settings['downloadLocation'] = './tracks'\n",
|
307 |
+
" if not dz.login_via_arl(arl.strip()):\n",
|
308 |
+
" raise Exception('Error while logging in with provided ARL.')\n",
|
309 |
+
" downloadObject = generateDownloadObject(dz, link, bitrate)\n",
|
310 |
+
" print(f'Downloading {downloadObject.type}: \"{downloadObject.title}\" by {downloadObject.artist}...',end=' ',flush=True)\n",
|
311 |
+
" Downloader(dz, downloadObject, settings).start()\n",
|
312 |
+
" print(f'done.')\n",
|
313 |
+
"\n",
|
314 |
+
" path_to_audio = []\n",
|
315 |
+
" for file in downloadObject.files:\n",
|
316 |
+
" path_to_audio.append(file[\"path\"])\n",
|
317 |
+
"\n",
|
318 |
+
" return path_to_audio\n",
|
319 |
+
"\n",
|
320 |
+
" def download_link(url):\n",
|
321 |
+
" ydl_opts = {\n",
|
322 |
+
" 'format': 'bestvideo+bestaudio/best',\n",
|
323 |
+
" 'outtmpl': '%(title)s.%(ext)s',\n",
|
324 |
+
" 'nocheckcertificate': True,\n",
|
325 |
+
" 'ignoreerrors': True,\n",
|
326 |
+
" 'no_warnings': True,\n",
|
327 |
+
" 'extractaudio': True,\n",
|
328 |
+
" }\n",
|
329 |
+
" with yt_dlp.YoutubeDL(ydl_opts) as ydl:\n",
|
330 |
+
" result = ydl.extract_info(url, download=True)\n",
|
331 |
+
" download_path = ydl.prepare_filename(result)\n",
|
332 |
+
" return download_path\n",
|
333 |
+
"\n",
|
334 |
+
" print('finished setting up!')\n",
|
335 |
+
" first_cell_ran = True"
|
336 |
+
]
|
337 |
+
},
|
338 |
+
{
|
339 |
+
"cell_type": "code",
|
340 |
+
"execution_count": null,
|
341 |
+
"metadata": {
|
342 |
+
"id": "4hd1TzEGCiRo",
|
343 |
+
"cellView": "form"
|
344 |
+
},
|
345 |
+
"outputs": [],
|
346 |
+
"source": [
|
347 |
+
"if 'first_cell_ran' in locals():\n",
|
348 |
+
" os.chdir(mounting_path + '/' + file_folder + '/')\n",
|
349 |
+
" #parameter markdowns-----------------\n",
|
350 |
+
" #@markdown ### Input files\n",
|
351 |
+
" #@markdown track filename: Upload your songs to the \"tracks\" folder. You may provide multiple links/files by spliting them with ;\n",
|
352 |
+
" filename = \"https://deezer.com/album/281108671\" #@param {type:\"string\"}\n",
|
353 |
+
" #@markdown onnx model (if you have your own model, upload it in models folder)\n",
|
354 |
+
" onnx = \"UVR-MDX-NET-Inst_HQ_3.onnx\" #@param [\"Kim_Inst.onnx\", \"Kim_Vocal_1.onnx\", \"Kim_Vocal_2.onnx\", \"kuielab_a_bass.onnx\", \"kuielab_a_drums.onnx\", \"kuielab_a_other.onnx\", \"kuielab_a_vocals.onnx\", \"kuielab_b_bass.onnx\", \"kuielab_b_drums.onnx\", \"kuielab_b_other.onnx\", \"kuielab_b_vocals.onnx\", \"Reverb_HQ_By_FoxJoy.onnx\", \"UVR-MDX-NET-Inst_1.onnx\", \"UVR-MDX-NET-Inst_2.onnx\", \"UVR-MDX-NET-Inst_3.onnx\", \"UVR-MDX-NET-Inst_HQ_1.onnx\", \"UVR-MDX-NET-Inst_HQ_2.onnx\", \"UVR-MDX-NET-Inst_Main.onnx\", \"UVR_MDXNET_1_9703.onnx\", \"UVR_MDXNET_2_9682.onnx\", \"UVR_MDXNET_3_9662.onnx\", \"UVR_MDXNET_9482.onnx\", \"UVR_MDXNET_KARA.onnx\", \"UVR_MDXNET_KARA_2.onnx\", \"UVR_MDXNET_Main.onnx\", \"UVR-MDX-NET-Inst_HQ_3.onnx\", \"UVR-MDX-NET-Voc_FT.onnx\"]{allow-input: true}\n",
|
355 |
+
" #@markdown process all: processes all tracks inside tracks/ folder instead. (filename will be ignored!)\n",
|
356 |
+
" process_all = False # @param{type:\"boolean\"}\n",
|
357 |
+
"\n",
|
358 |
+
"\n",
|
359 |
+
" #@markdown ### Settings\n",
|
360 |
+
" #@markdown invert: get difference between input and output (e.g get Instrumental out of Vocals)\n",
|
361 |
+
" invert = True # @param{type:\"boolean\"}\n",
|
362 |
+
" #@markdown denoise: get rid of MDX noise. (This processes input track twice)\n",
|
363 |
+
" denoise = True # @param{type:\"boolean\"}\n",
|
364 |
+
" #@markdown m_threads: like batch size, processes input wave in n threads. (beneficial for CPU)\n",
|
365 |
+
" m_threads = 2 #@param {type:\"slider\", min:1, max:8, step:1}\n",
|
366 |
+
"\n",
|
367 |
+
" #@markdown ### Custom model parameters (Only use this if you're using new/unofficial/custom models)\n",
|
368 |
+
" #@markdown Use custom model parameters. (Default: unchecked, or auto)\n",
|
369 |
+
" use_custom_parameter = False # @param{type:\"boolean\"}\n",
|
370 |
+
" #@markdown Output file suffix (usually the stem name e.g Vocals)\n",
|
371 |
+
" suffix = \"Vocals_custom\" #@param [\"Vocals\", \"Drums\", \"Bass\", \"Other\"]{allow-input: true}\n",
|
372 |
+
" suffix_invert = \"Instrumental_custom\" #@param [\"Instrumental\", \"Drumless\", \"Bassless\", \"Instruments\"]{allow-input: true}\n",
|
373 |
+
" #@markdown Model parameters\n",
|
374 |
+
" dim_f = 3072 #@param {type: \"integer\"}\n",
|
375 |
+
" dim_t = 256 #@param {type: \"integer\"}\n",
|
376 |
+
" n_fft = 6144 #@param {type: \"integer\"}\n",
|
377 |
+
" #@markdown use custom compensation: only if you have your own compensation value for your model. this still apply even if you don't have use_custom_parameter checked (Default: unchecked, or auto)\n",
|
378 |
+
" use_custom_compensation = False # @param{type:\"boolean\"}\n",
|
379 |
+
" compensation = 1.000 #@param {type: \"number\"}\n",
|
380 |
+
"\n",
|
381 |
+
" #@markdown ### Extras\n",
|
382 |
+
" #@markdown Deezer arl: paste your ARL here for deezer tracks directly!\n",
|
383 |
+
" arl = \"\" #@param {type:\"string\"}\n",
|
384 |
+
" #@markdown Track format: select track quality/format\n",
|
385 |
+
" track_format = \"FLAC\" #@param [\"FLAC\",\"MP3_320\",\"MP3_128\"]\n",
|
386 |
+
" #@markdown Print settings being used in the run\n",
|
387 |
+
" print_settings = True # @param{type:\"boolean\"}\n",
|
388 |
+
"\n",
|
389 |
+
"\n",
|
390 |
+
"\n",
|
391 |
+
" onnx = id_to_ptm(onnx)\n",
|
392 |
+
" compensation = compensation if use_custom_compensation or use_custom_parameter else None\n",
|
393 |
+
" mdx_model = prepare_mdx(use_custom_parameter, dim_f, dim_t, n_fft, compensation=compensation)\n",
|
394 |
+
"\n",
|
395 |
+
" filename_split = filename.split(';')\n",
|
396 |
+
"\n",
|
397 |
+
" usable_files = []\n",
|
398 |
+
"\n",
|
399 |
+
" if not process_all:\n",
|
400 |
+
" for fn in filename_split:\n",
|
401 |
+
" fn = fn.strip()\n",
|
402 |
+
" if is_valid_url(fn):\n",
|
403 |
+
" dm, ltype, lid = deemix.parseLink(fn)\n",
|
404 |
+
" if ltype and lid:\n",
|
405 |
+
" usable_files += download_deezer(fn, arl, track_format)\n",
|
406 |
+
" else:\n",
|
407 |
+
" print('downloading link...',end=' ')\n",
|
408 |
+
" usable_files+=[download_link(fn)]\n",
|
409 |
+
" print('done')\n",
|
410 |
+
" else:\n",
|
411 |
+
" usable_files.append(os.path.join('tracks',fn))\n",
|
412 |
+
" else:\n",
|
413 |
+
" for fn in glob.glob('tracks/*'):\n",
|
414 |
+
" usable_files.append(fn)\n",
|
415 |
+
" for filename in usable_files:\n",
|
416 |
+
" suffix_naming = suffix if use_custom_parameter else None\n",
|
417 |
+
" diff_suffix_naming = suffix_invert if use_custom_parameter else None\n",
|
418 |
+
" run_mdx(onnx, mdx_model, filename, diff=invert,suffix=suffix_naming,diff_suffix=diff_suffix_naming,denoise=denoise)\n",
|
419 |
+
"\n",
|
420 |
+
" if print_settings:\n",
|
421 |
+
" print()\n",
|
422 |
+
" print('[MDX-Net_Colab settings used]')\n",
|
423 |
+
" print(f'Model used: {onnx}')\n",
|
424 |
+
" print(f'Model MD5: {mdx.MDX.get_hash(onnx)}')\n",
|
425 |
+
" print(f'Using de-noise: {denoise}')\n",
|
426 |
+
" print(f'Model parameters:')\n",
|
427 |
+
" print(f' -dim_f: {mdx_model.dim_f}')\n",
|
428 |
+
" print(f' -dim_t: {mdx_model.dim_t}')\n",
|
429 |
+
" print(f' -n_fft: {mdx_model.n_fft}')\n",
|
430 |
+
" print(f' -compensation: {mdx_model.compensation}')\n",
|
431 |
+
" print()\n",
|
432 |
+
" print('[Input file]')\n",
|
433 |
+
" print('filename(s): ')\n",
|
434 |
+
" for filename in usable_files:\n",
|
435 |
+
" print(f' -{filename}')\n",
|
436 |
+
"\n",
|
437 |
+
" del mdx_model"
|
438 |
+
]
|
439 |
+
},
|
440 |
+
{
|
441 |
+
"cell_type": "markdown",
|
442 |
+
"source": [
|
443 |
+
"# Guide\n",
|
444 |
+
"\n",
|
445 |
+
"This tutorial guide will walk you through the steps to use the features of this Colab notebook.\n",
|
446 |
+
"\n",
|
447 |
+
"## Mount Drive\n",
|
448 |
+
"\n",
|
449 |
+
"To mount your Google Drive, follow these steps:\n",
|
450 |
+
"\n",
|
451 |
+
"1. Check the box next to \"MountDrive\" if you want to mount Google Drive.\n",
|
452 |
+
"2. Modify the \"mounting_path\" if you want to specify a different path for the drive to be mounted. **Note:** Be cautious when modifying this path as it can cause issues if not done properly.\n",
|
453 |
+
"3. Check the box next to \"Force update and disregard local changes\" if you want to discard all local modifications in your repository and replace the files with the versions from the original commit.\n",
|
454 |
+
"4. Check the box next to \"Auto Update\" if you want to automatically update without discarding your changes. Leave it unchecked if you want to manually update.\n",
|
455 |
+
"\n",
|
456 |
+
"## Input Files\n",
|
457 |
+
"\n",
|
458 |
+
"To upload your songs, follow these steps:\n",
|
459 |
+
"\n",
|
460 |
+
"1. Specify the \"track filename\" for your songs. You can provide multiple links or files by separating them with a semicolon (;).\n",
|
461 |
+
"2. Upload your songs to the \"tracks\" folder.\n",
|
462 |
+
"\n",
|
463 |
+
"## ONNX Model\n",
|
464 |
+
"\n",
|
465 |
+
"If you have your own ONNX model, follow these steps:\n",
|
466 |
+
"\n",
|
467 |
+
"1. Upload your model to the \"models\" folder.\n",
|
468 |
+
"2. Specify the \"onnx\" filename for your model.\n",
|
469 |
+
"\n",
|
470 |
+
"## Processing\n",
|
471 |
+
"\n",
|
472 |
+
"To process your tracks, follow these steps:\n",
|
473 |
+
"\n",
|
474 |
+
"1. If you want to process all tracks inside the \"tracks\" folder, check the box next to \"process_all\" and ignore the \"filename\" field.\n",
|
475 |
+
"2. Specify any additional settings you want:\n",
|
476 |
+
" - Check the box next to \"invert\" to get the difference between input and output (e.g., get Instrumental out of Vocals).\n",
|
477 |
+
" - Check the box next to \"denoise\" to get rid of MDX noise. This processes the input track twice.\n",
|
478 |
+
" - Specify custom model parameters only if you're using new/unofficial/custom models. Use the \"use_custom_parameter\" checkbox to enable this feature.\n",
|
479 |
+
" - Specify the output file suffix, which is usually the stem name (e.g., Vocals). Use the \"suffix\" field to specify the suffix for normal processing and the \"suffix_invert\" field for inverted processing.\n",
|
480 |
+
"\n",
|
481 |
+
"## Model Parameters\n",
|
482 |
+
"\n",
|
483 |
+
"Specify the following custom model parameters if applicable:\n",
|
484 |
+
"\n",
|
485 |
+
"- \"dim_f\": The value for the `dim_f` parameter.\n",
|
486 |
+
"- \"dim_t\": The value for the `dim_t` parameter.\n",
|
487 |
+
"- \"n_fft\": The value for the `n_fft` parameter.\n",
|
488 |
+
"- Check the box next to \"use_custom_compensation\" if you have your own compensation value for your model. Specify the compensation value in the \"compensation\" field.\n",
|
489 |
+
"\n",
|
490 |
+
"## Extras\n",
|
491 |
+
"\n",
|
492 |
+
"If you're working with Deezer tracks, paste your ARL (Authentication Request Library) in the \"arl\" field to directly access the tracks.\n",
|
493 |
+
"\n",
|
494 |
+
"Specify the \"Track format\" by selecting the desired quality/format for the track.\n",
|
495 |
+
"\n",
|
496 |
+
"To print the settings being used in the run, check the box next to \"print_settings\".\n",
|
497 |
+
"\n",
|
498 |
+
"That's it! You're now ready to use this Colab notebook. Enjoy!\n",
|
499 |
+
"\n",
|
500 |
+
"## For more detailed guide, proceed to this <a href=\"https://docs.google.com/document/d/17fjNvJzj8ZGSer7c7OFe_CNfUKbAxEh_OBv94ZdRG5c\">link</a>.\n",
|
501 |
+
"credits: (discord) deton24"
|
502 |
+
],
|
503 |
+
"metadata": {
|
504 |
+
"id": "tMVwX5RhZSRP"
|
505 |
+
}
|
506 |
+
}
|
507 |
+
],
|
508 |
+
"metadata": {
|
509 |
+
"accelerator": "GPU",
|
510 |
+
"colab": {
|
511 |
+
"gpuType": "T4",
|
512 |
+
"provenance": []
|
513 |
+
},
|
514 |
+
"kernelspec": {
|
515 |
+
"display_name": "Python 3",
|
516 |
+
"name": "python3"
|
517 |
+
},
|
518 |
+
"language_info": {
|
519 |
+
"name": "python"
|
520 |
+
}
|
521 |
+
},
|
522 |
+
"nbformat": 4,
|
523 |
+
"nbformat_minor": 0
|
524 |
+
}
|
MDXNet.py
ADDED
@@ -0,0 +1,272 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import soundfile as sf
|
2 |
+
import torch, pdb, os, warnings, librosa
|
3 |
+
import numpy as np
|
4 |
+
import onnxruntime as ort
|
5 |
+
from tqdm import tqdm
|
6 |
+
import torch
|
7 |
+
|
8 |
+
dim_c = 4
|
9 |
+
|
10 |
+
|
11 |
+
class Conv_TDF_net_trim:
|
12 |
+
def __init__(
|
13 |
+
self, device, model_name, target_name, L, dim_f, dim_t, n_fft, hop=1024
|
14 |
+
):
|
15 |
+
super(Conv_TDF_net_trim, self).__init__()
|
16 |
+
|
17 |
+
self.dim_f = dim_f
|
18 |
+
self.dim_t = 2**dim_t
|
19 |
+
self.n_fft = n_fft
|
20 |
+
self.hop = hop
|
21 |
+
self.n_bins = self.n_fft // 2 + 1
|
22 |
+
self.chunk_size = hop * (self.dim_t - 1)
|
23 |
+
self.window = torch.hann_window(window_length=self.n_fft, periodic=True).to(
|
24 |
+
device
|
25 |
+
)
|
26 |
+
self.target_name = target_name
|
27 |
+
self.blender = "blender" in model_name
|
28 |
+
|
29 |
+
out_c = dim_c * 4 if target_name == "*" else dim_c
|
30 |
+
self.freq_pad = torch.zeros(
|
31 |
+
[1, out_c, self.n_bins - self.dim_f, self.dim_t]
|
32 |
+
).to(device)
|
33 |
+
|
34 |
+
self.n = L // 2
|
35 |
+
|
36 |
+
def stft(self, x):
|
37 |
+
x = x.reshape([-1, self.chunk_size])
|
38 |
+
x = torch.stft(
|
39 |
+
x,
|
40 |
+
n_fft=self.n_fft,
|
41 |
+
hop_length=self.hop,
|
42 |
+
window=self.window,
|
43 |
+
center=True,
|
44 |
+
return_complex=True,
|
45 |
+
)
|
46 |
+
x = torch.view_as_real(x)
|
47 |
+
x = x.permute([0, 3, 1, 2])
|
48 |
+
x = x.reshape([-1, 2, 2, self.n_bins, self.dim_t]).reshape(
|
49 |
+
[-1, dim_c, self.n_bins, self.dim_t]
|
50 |
+
)
|
51 |
+
return x[:, :, : self.dim_f]
|
52 |
+
|
53 |
+
def istft(self, x, freq_pad=None):
|
54 |
+
freq_pad = (
|
55 |
+
self.freq_pad.repeat([x.shape[0], 1, 1, 1])
|
56 |
+
if freq_pad is None
|
57 |
+
else freq_pad
|
58 |
+
)
|
59 |
+
x = torch.cat([x, freq_pad], -2)
|
60 |
+
c = 4 * 2 if self.target_name == "*" else 2
|
61 |
+
x = x.reshape([-1, c, 2, self.n_bins, self.dim_t]).reshape(
|
62 |
+
[-1, 2, self.n_bins, self.dim_t]
|
63 |
+
)
|
64 |
+
x = x.permute([0, 2, 3, 1])
|
65 |
+
x = x.contiguous()
|
66 |
+
x = torch.view_as_complex(x)
|
67 |
+
x = torch.istft(
|
68 |
+
x, n_fft=self.n_fft, hop_length=self.hop, window=self.window, center=True
|
69 |
+
)
|
70 |
+
return x.reshape([-1, c, self.chunk_size])
|
71 |
+
|
72 |
+
|
73 |
+
def get_models(device, dim_f, dim_t, n_fft):
|
74 |
+
return Conv_TDF_net_trim(
|
75 |
+
device=device,
|
76 |
+
model_name="Conv-TDF",
|
77 |
+
target_name="vocals",
|
78 |
+
L=11,
|
79 |
+
dim_f=dim_f,
|
80 |
+
dim_t=dim_t,
|
81 |
+
n_fft=n_fft,
|
82 |
+
)
|
83 |
+
|
84 |
+
|
85 |
+
warnings.filterwarnings("ignore")
|
86 |
+
cpu = torch.device("cpu")
|
87 |
+
if torch.cuda.is_available():
|
88 |
+
device = torch.device("cuda:0")
|
89 |
+
elif torch.backends.mps.is_available():
|
90 |
+
device = torch.device("mps")
|
91 |
+
else:
|
92 |
+
device = torch.device("cpu")
|
93 |
+
|
94 |
+
|
95 |
+
class Predictor:
|
96 |
+
def __init__(self, args):
|
97 |
+
self.args = args
|
98 |
+
self.model_ = get_models(
|
99 |
+
device=cpu, dim_f=args.dim_f, dim_t=args.dim_t, n_fft=args.n_fft
|
100 |
+
)
|
101 |
+
self.model = ort.InferenceSession(
|
102 |
+
os.path.join(args.onnx, self.model_.target_name + ".onnx"),
|
103 |
+
providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
|
104 |
+
)
|
105 |
+
print("onnx load done")
|
106 |
+
|
107 |
+
def demix(self, mix):
|
108 |
+
samples = mix.shape[-1]
|
109 |
+
margin = self.args.margin
|
110 |
+
chunk_size = self.args.chunks * 44100
|
111 |
+
assert not margin == 0, "margin cannot be zero!"
|
112 |
+
if margin > chunk_size:
|
113 |
+
margin = chunk_size
|
114 |
+
|
115 |
+
segmented_mix = {}
|
116 |
+
|
117 |
+
if self.args.chunks == 0 or samples < chunk_size:
|
118 |
+
chunk_size = samples
|
119 |
+
|
120 |
+
counter = -1
|
121 |
+
for skip in range(0, samples, chunk_size):
|
122 |
+
counter += 1
|
123 |
+
|
124 |
+
s_margin = 0 if counter == 0 else margin
|
125 |
+
end = min(skip + chunk_size + margin, samples)
|
126 |
+
|
127 |
+
start = skip - s_margin
|
128 |
+
|
129 |
+
segmented_mix[skip] = mix[:, start:end].copy()
|
130 |
+
if end == samples:
|
131 |
+
break
|
132 |
+
|
133 |
+
sources = self.demix_base(segmented_mix, margin_size=margin)
|
134 |
+
"""
|
135 |
+
mix:(2,big_sample)
|
136 |
+
segmented_mix:offset->(2,small_sample)
|
137 |
+
sources:(1,2,big_sample)
|
138 |
+
"""
|
139 |
+
return sources
|
140 |
+
|
141 |
+
def demix_base(self, mixes, margin_size):
|
142 |
+
chunked_sources = []
|
143 |
+
progress_bar = tqdm(total=len(mixes))
|
144 |
+
progress_bar.set_description("Processing")
|
145 |
+
for mix in mixes:
|
146 |
+
cmix = mixes[mix]
|
147 |
+
sources = []
|
148 |
+
n_sample = cmix.shape[1]
|
149 |
+
model = self.model_
|
150 |
+
trim = model.n_fft // 2
|
151 |
+
gen_size = model.chunk_size - 2 * trim
|
152 |
+
pad = gen_size - n_sample % gen_size
|
153 |
+
mix_p = np.concatenate(
|
154 |
+
(np.zeros((2, trim)), cmix, np.zeros((2, pad)), np.zeros((2, trim))), 1
|
155 |
+
)
|
156 |
+
mix_waves = []
|
157 |
+
i = 0
|
158 |
+
while i < n_sample + pad:
|
159 |
+
waves = np.array(mix_p[:, i : i + model.chunk_size])
|
160 |
+
mix_waves.append(waves)
|
161 |
+
i += gen_size
|
162 |
+
mix_waves = torch.tensor(mix_waves, dtype=torch.float32).to(cpu)
|
163 |
+
with torch.no_grad():
|
164 |
+
_ort = self.model
|
165 |
+
spek = model.stft(mix_waves)
|
166 |
+
if self.args.denoise:
|
167 |
+
spec_pred = (
|
168 |
+
-_ort.run(None, {"input": -spek.cpu().numpy()})[0] * 0.5
|
169 |
+
+ _ort.run(None, {"input": spek.cpu().numpy()})[0] * 0.5
|
170 |
+
)
|
171 |
+
tar_waves = model.istft(torch.tensor(spec_pred))
|
172 |
+
else:
|
173 |
+
tar_waves = model.istft(
|
174 |
+
torch.tensor(_ort.run(None, {"input": spek.cpu().numpy()})[0])
|
175 |
+
)
|
176 |
+
tar_signal = (
|
177 |
+
tar_waves[:, :, trim:-trim]
|
178 |
+
.transpose(0, 1)
|
179 |
+
.reshape(2, -1)
|
180 |
+
.numpy()[:, :-pad]
|
181 |
+
)
|
182 |
+
|
183 |
+
start = 0 if mix == 0 else margin_size
|
184 |
+
end = None if mix == list(mixes.keys())[::-1][0] else -margin_size
|
185 |
+
if margin_size == 0:
|
186 |
+
end = None
|
187 |
+
sources.append(tar_signal[:, start:end])
|
188 |
+
|
189 |
+
progress_bar.update(1)
|
190 |
+
|
191 |
+
chunked_sources.append(sources)
|
192 |
+
_sources = np.concatenate(chunked_sources, axis=-1)
|
193 |
+
# del self.model
|
194 |
+
progress_bar.close()
|
195 |
+
return _sources
|
196 |
+
|
197 |
+
def prediction(self, m, vocal_root, others_root, format):
|
198 |
+
os.makedirs(vocal_root, exist_ok=True)
|
199 |
+
os.makedirs(others_root, exist_ok=True)
|
200 |
+
basename = os.path.basename(m)
|
201 |
+
mix, rate = librosa.load(m, mono=False, sr=44100)
|
202 |
+
if mix.ndim == 1:
|
203 |
+
mix = np.asfortranarray([mix, mix])
|
204 |
+
mix = mix.T
|
205 |
+
sources = self.demix(mix.T)
|
206 |
+
opt = sources[0].T
|
207 |
+
if format in ["wav", "flac"]:
|
208 |
+
sf.write(
|
209 |
+
"%s/%s_main_vocal.%s" % (vocal_root, basename, format), mix - opt, rate
|
210 |
+
)
|
211 |
+
sf.write("%s/%s_others.%s" % (others_root, basename, format), opt, rate)
|
212 |
+
else:
|
213 |
+
path_vocal = "%s/%s_main_vocal.wav" % (vocal_root, basename)
|
214 |
+
path_other = "%s/%s_others.wav" % (others_root, basename)
|
215 |
+
sf.write(path_vocal, mix - opt, rate)
|
216 |
+
sf.write(path_other, opt, rate)
|
217 |
+
if os.path.exists(path_vocal):
|
218 |
+
os.system(
|
219 |
+
"ffmpeg -i %s -vn %s -q:a 2 -y"
|
220 |
+
% (path_vocal, path_vocal[:-4] + ".%s" % format)
|
221 |
+
)
|
222 |
+
if os.path.exists(path_other):
|
223 |
+
os.system(
|
224 |
+
"ffmpeg -i %s -vn %s -q:a 2 -y"
|
225 |
+
% (path_other, path_other[:-4] + ".%s" % format)
|
226 |
+
)
|
227 |
+
|
228 |
+
|
229 |
+
class MDXNetDereverb:
|
230 |
+
def __init__(self, chunks):
|
231 |
+
self.onnx = "uvr5_weights/onnx_dereverb_By_FoxJoy"
|
232 |
+
self.shifts = 10 #'Predict with randomised equivariant stabilisation'
|
233 |
+
self.mixing = "min_mag" # ['default','min_mag','max_mag']
|
234 |
+
self.chunks = chunks
|
235 |
+
self.margin = 44100
|
236 |
+
self.dim_t = 9
|
237 |
+
self.dim_f = 3072
|
238 |
+
self.n_fft = 6144
|
239 |
+
self.denoise = True
|
240 |
+
self.pred = Predictor(self)
|
241 |
+
|
242 |
+
def _path_audio_(self, input, vocal_root, others_root, format):
|
243 |
+
self.pred.prediction(input, vocal_root, others_root, format)
|
244 |
+
|
245 |
+
|
246 |
+
if __name__ == "__main__":
|
247 |
+
dereverb = MDXNetDereverb(15)
|
248 |
+
from time import time as ttime
|
249 |
+
|
250 |
+
t0 = ttime()
|
251 |
+
dereverb._path_audio_(
|
252 |
+
"雪雪伴奏对消HP5.wav",
|
253 |
+
"vocal",
|
254 |
+
"others",
|
255 |
+
)
|
256 |
+
t1 = ttime()
|
257 |
+
print(t1 - t0)
|
258 |
+
|
259 |
+
|
260 |
+
"""
|
261 |
+
|
262 |
+
runtime\python.exe MDXNet.py
|
263 |
+
|
264 |
+
6G:
|
265 |
+
15/9:0.8G->6.8G
|
266 |
+
14:0.8G->6.5G
|
267 |
+
25:炸
|
268 |
+
|
269 |
+
half15:0.7G->6.6G,22.69s
|
270 |
+
fp32-15:0.7G->6.6G,20.85s
|
271 |
+
|
272 |
+
"""
|
Makefile
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
.PHONY:
|
2 |
+
.ONESHELL:
|
3 |
+
|
4 |
+
help: ## Show this help and exit
|
5 |
+
@grep -hE '^[A-Za-z0-9_ \-]*?:.*##.*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'
|
6 |
+
|
7 |
+
install: ## Install dependencies (Do everytime you start up a paperspace machine)
|
8 |
+
apt-get -y install build-essential python3-dev ffmpeg
|
9 |
+
pip install --upgrade setuptools wheel
|
10 |
+
pip install --upgrade pip
|
11 |
+
pip install faiss-gpu fairseq gradio ffmpeg ffmpeg-python praat-parselmouth pyworld numpy==1.23.5 numba==0.56.4 librosa==0.9.1
|
12 |
+
pip install -r requirements.txt
|
13 |
+
pip install --upgrade lxml
|
14 |
+
apt-get update
|
15 |
+
apt -y install -qq aria2
|
16 |
+
|
17 |
+
basev1: ## Download version 1 pre-trained models (Do only once after cloning the fork)
|
18 |
+
mkdir -p pretrained uvr5_weights
|
19 |
+
git pull
|
20 |
+
aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/pretrained/D32k.pth -d pretrained -o D32k.pth
|
21 |
+
aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/pretrained/D40k.pth -d pretrained -o D40k.pth
|
22 |
+
aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/pretrained/D48k.pth -d pretrained -o D48k.pth
|
23 |
+
aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/pretrained/G32k.pth -d pretrained -o G32k.pth
|
24 |
+
aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/pretrained/G40k.pth -d pretrained -o G40k.pth
|
25 |
+
aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/pretrained/G48k.pth -d pretrained -o G48k.pth
|
26 |
+
aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/pretrained/f0D32k.pth -d pretrained -o f0D32k.pth
|
27 |
+
aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/pretrained/f0D40k.pth -d pretrained -o f0D40k.pth
|
28 |
+
aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/pretrained/f0D48k.pth -d pretrained -o f0D48k.pth
|
29 |
+
aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/pretrained/f0G32k.pth -d pretrained -o f0G32k.pth
|
30 |
+
aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/pretrained/f0G40k.pth -d pretrained -o f0G40k.pth
|
31 |
+
aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/pretrained/f0G48k.pth -d pretrained -o f0G48k.pth
|
32 |
+
aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/uvr5_weights/HP2-人声vocals+非人声instrumentals.pth -d uvr5_weights -o HP2-人声vocals+非人声instrumentals.pth
|
33 |
+
aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/uvr5_weights/HP5-主旋律人声vocals+其他instrumentals.pth -d uvr5_weights -o HP5-主旋律人声vocals+其他instrumentals.pth
|
34 |
+
aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/hubert_base.pt -d ./ -o hubert_base.pt
|
35 |
+
|
36 |
+
basev2: ## Download version 2 pre-trained models (Do only once after cloning the fork)
|
37 |
+
mkdir -p pretrained_v2 uvr5_weights
|
38 |
+
git pull
|
39 |
+
aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/pretrained_v2/D32k.pth -d pretrained_v2 -o D32k.pth
|
40 |
+
aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/pretrained_v2/D40k.pth -d pretrained_v2 -o D40k.pth
|
41 |
+
aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/pretrained_v2/D48k.pth -d pretrained_v2 -o D48k.pth
|
42 |
+
aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/pretrained_v2/G32k.pth -d pretrained_v2 -o G32k.pth
|
43 |
+
aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/pretrained_v2/G40k.pth -d pretrained_v2 -o G40k.pth
|
44 |
+
aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/pretrained_v2/G48k.pth -d pretrained_v2 -o G48k.pth
|
45 |
+
aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/pretrained_v2/f0D32k.pth -d pretrained_v2 -o f0D32k.pth
|
46 |
+
aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/pretrained_v2/f0D40k.pth -d pretrained_v2 -o f0D40k.pth
|
47 |
+
aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/pretrained_v2/f0D48k.pth -d pretrained_v2 -o f0D48k.pth
|
48 |
+
aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/pretrained_v2/f0G32k.pth -d pretrained_v2 -o f0G32k.pth
|
49 |
+
aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/pretrained_v2/f0G40k.pth -d pretrained_v2 -o f0G40k.pth
|
50 |
+
aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/pretrained_v2/f0G48k.pth -d pretrained_v2 -o f0G48k.pth
|
51 |
+
aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/uvr5_weights/HP2-人声vocals+非人声instrumentals.pth -d uvr5_weights -o HP2-人声vocals+非人声instrumentals.pth
|
52 |
+
aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/uvr5_weights/HP5-主旋律人声vocals+其他instrumentals.pth -d uvr5_weights -o HP5-主旋律人声vocals+其他instrumentals.pth
|
53 |
+
aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/hubert_base.pt -d ./ -o hubert_base.pt
|
54 |
+
|
55 |
+
run-ui: ## Run the python GUI
|
56 |
+
python infer-web.py --paperspace --pycmd python
|
57 |
+
|
58 |
+
run-cli: ## Run the python CLI
|
59 |
+
python infer-web.py --pycmd python --is_cli
|
60 |
+
|
61 |
+
tensorboard: ## Start the tensorboard (Run on separate terminal)
|
62 |
+
echo https://tensorboard-$$(hostname).clg07azjl.paperspacegradient.com
|
63 |
+
tensorboard --logdir logs --bind_all
|
README.md
CHANGED
@@ -1,12 +1,222 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# 🍏 Applio-RVC-Fork
|
2 |
+
Applio is a user-friendly fork of Mangio-RVC-Fork/RVC, designed to provide an intuitive interface, especially for newcomers.
|
3 |
+
|
4 |
+
## 📎 Links
|
5 |
+
[![Discord](https://img.shields.io/badge/SUPPORT_DISCORD-37a779?style=for-the-badge)](https://discord.gg/IAHispano)
|
6 |
+
[![Google Colab](https://img.shields.io/badge/GOOGLE_COLAB-37a779?style=for-the-badge)](https://colab.research.google.com/drive/157pUQep6txJOYModYFqvz_5OJajeh7Ii)
|
7 |
+
|
8 |
+
## 📚 Table of Contents
|
9 |
+
1. [Improvements of Applio Over RVC](#-improvements-of-applio-over-rvc)
|
10 |
+
2. [Additional Features of This Repository](#️-additional-features-of-this-repository)
|
11 |
+
3. [Planned Features for Future Development](#️-planned-features-for-future-development)
|
12 |
+
4. [Installation](#-installation)
|
13 |
+
5. [Running the Web GUI (Inference & Train)](#-running-the-web-gui-inference--train)
|
14 |
+
6. [Running the CLI (Inference & Train)](#-running-the-cli-inference--train)
|
15 |
+
7. [Credits](#credits)
|
16 |
+
8. [Thanks to all RVC and Mangio contributors](#thanks-to-all-rvc-and-mangio-contributors)
|
17 |
+
|
18 |
+
|
19 |
+
## 🎯 Improvements of Applio Over RVC
|
20 |
+
### f0 Inference Algorithm Overhaul
|
21 |
+
- Applio features a comprehensive overhaul of the f0 inference algorithm, including:
|
22 |
+
- Addition of the pyworld dio f0 method.
|
23 |
+
- Alternative method for calculating crepe f0.
|
24 |
+
- Introduction of the torchcrepe crepe-tiny model.
|
25 |
+
- Customizable crepe_hop_length for the crepe algorithm via both the web GUI and CLI.
|
26 |
+
|
27 |
+
### f0 Crepe Pitch Extraction for Training
|
28 |
+
- Works on paperspace machines but not local MacOS/Windows machines (Potential memory leak).
|
29 |
+
|
30 |
+
### Paperspace Integration
|
31 |
+
- Applio seamlessly integrates with Paperspace, providing the following features:
|
32 |
+
- Paperspace argument on infer-web.py (--paperspace) for sharing a Gradio link.
|
33 |
+
- A dedicated make file tailored for Paperspace users.
|
34 |
+
|
35 |
+
### Access to Tensorboard
|
36 |
+
- Applio grants easy access to Tensorboard via a Makefile and a Python script.
|
37 |
+
|
38 |
+
### CLI Functionality
|
39 |
+
- Applio introduces command-line interface (CLI) functionality, with the addition of the --is_cli flag in infer-web.py for CLI system usage.
|
40 |
+
|
41 |
+
### f0 Hybrid Estimation Method
|
42 |
+
- Applio offers a novel f0 hybrid estimation method by calculating nanmedian for a specified array of f0 methods, ensuring the best results from multiple methods (CLI exclusive).
|
43 |
+
- This hybrid estimation method is also available for f0 feature extraction during training.
|
44 |
+
|
45 |
+
### UI Changes
|
46 |
+
#### Inference:
|
47 |
+
- A complete interface redesign enhances user experience, with notable features such as:
|
48 |
+
- Audio recording directly from the interface.
|
49 |
+
- Convenient drop-down menus for audio and .index file selection.
|
50 |
+
- An advanced settings section with new features like autotune and formant shifting.
|
51 |
+
|
52 |
+
#### Training:
|
53 |
+
- Improved training features include:
|
54 |
+
- A total epoch slider now limited to 10,000.
|
55 |
+
- Increased save frequency limit to 100.
|
56 |
+
- Default recommended options for smoother setup.
|
57 |
+
- Better adaptation to high-resolution screens.
|
58 |
+
- A drop-down menu for dataset selection.
|
59 |
+
- Enhanced saving system options, including Save all files, Save G and D files, and Save model for inference.
|
60 |
+
|
61 |
+
#### UVR:
|
62 |
+
- Applio ensures compatibility with all VR/MDX models for an extended range of possibilities.
|
63 |
+
|
64 |
+
#### TTS (Text-to-Speech, New):
|
65 |
+
- Introducing a new Text-to-Speech (TTS) feature using RVC models.
|
66 |
+
- Support for multiple languages and Edge-tts/Bark-tts.
|
67 |
+
|
68 |
+
#### Resources (New):
|
69 |
+
- Users can now upload models, backups, datasets, and audios from various storage services like Drive, Huggingface, Discord, and more.
|
70 |
+
- Download audios from YouTube with the ability to automatically separate instrumental and vocals, offering advanced options and UVR support.
|
71 |
+
|
72 |
+
#### Extra (New):
|
73 |
+
- Combine instrumental and vocals with ease, including independent volume control for each track and the option to add effects like reverb, compressor, and noise gate.
|
74 |
+
- Significant improvements in the processing interface, allowing tasks such as merging models, modifying information, obtaining information, or extracting models effortlessly.
|
75 |
+
|
76 |
+
## ⚙️ Additional Features of This Repository
|
77 |
+
|
78 |
+
In addition to the aforementioned improvements, this repository offers the following features:
|
79 |
+
|
80 |
+
### Enhanced Tone Leakage Reduction
|
81 |
+
- Implements tone leakage reduction by replacing source features with training-set features using top1 retrieval. This helps in achieving cleaner audio results.
|
82 |
+
|
83 |
+
### Efficient Training
|
84 |
+
- Provides a seamless and speedy training experience, even on relatively modest graphics cards. The system is optimized for efficient resource utilization.
|
85 |
+
|
86 |
+
### Data Efficiency
|
87 |
+
- Supports training with a small dataset, yielding commendable results, especially with audio clips of at least 10 minutes of low-noise speech.
|
88 |
+
|
89 |
+
## 🛠️ Planned Features for Future Development
|
90 |
+
As part of the ongoing development of this fork, the following features are planned to be added:
|
91 |
+
|
92 |
+
- Incorporating an inference batcher script based on user feedback. This enhancement will allow for processing 30-second audio samples at a time, improving output quality and preventing memory errors during inference.
|
93 |
+
- Implementing an automatic removal mechanism for old generations to optimize storage space usage. This feature ensures that the repository remains efficient and organized over time.
|
94 |
+
- Streamlining the training process for Paperspace machines to further improve efficiency and resource utilization during training tasks.
|
95 |
+
|
96 |
+
## Compatibility
|
97 |
+
- AMD/Intel graphics cards acceleration supported.
|
98 |
+
- Intel ARC graphics cards acceleration with IPEX supported.
|
99 |
+
|
100 |
+
## ✨ Installation
|
101 |
+
|
102 |
+
### Automatic installation (Windows):
|
103 |
+
To quickly and effortlessly install Applio along with all the necessary models and configurations on Windows, you can use the [install_Applio.bat](https://github.com/IAHispano/Applio-RVC-Fork/releases) script available in the releases section.
|
104 |
+
|
105 |
+
### Manual installation (Windows/MacOS):
|
106 |
+
**Note for MacOS Users**: When using `faiss 1.7.2` under MacOS, you may encounter a Segmentation Fault: 11 error. To resolve this issue, install `faiss-cpu 1.7.0` using the following command if you're installing it manually with pip:
|
107 |
+
```bash
|
108 |
+
pip install faiss-cpu==1.7.0
|
109 |
+
```
|
110 |
+
Additionally, you can install Swig on MacOS using brew:
|
111 |
+
```bash
|
112 |
+
brew install swig
|
113 |
+
```
|
114 |
+
|
115 |
+
Install requirements:
|
116 |
+
*Using pip (Python 3.9.8 is stable with this fork)*
|
117 |
+
```bash
|
118 |
+
pip install -r requirements.txt
|
119 |
+
```
|
120 |
+
|
121 |
+
### Manual installation (Paperspace):
|
122 |
+
```bash
|
123 |
+
cd Applio-RVC-Fork
|
124 |
+
make install # Do this everytime you start your paperspace machine
|
125 |
+
```
|
126 |
+
### You can also use pip to install them:
|
127 |
+
```bash
|
128 |
+
|
129 |
+
for Nvidia graphics cards
|
130 |
+
pip install -r requirements.txt
|
131 |
+
|
132 |
+
for AMD/Intel graphics cards:
|
133 |
+
pip install -r requirements-dml.txt
|
134 |
+
|
135 |
+
for Intel ARC graphics cards on Linux / WSL using Python 3.10:
|
136 |
+
pip install -r requirements-ipex.txt
|
137 |
+
|
138 |
+
```
|
139 |
+
|
140 |
+
## 🪄 Running the Web GUI (Inference & Train)
|
141 |
+
*Use --paperspace or --colab if on cloud system.*
|
142 |
+
```bash
|
143 |
+
python infer-web.py --pycmd python --port 3000
|
144 |
+
```
|
145 |
+
|
146 |
+
## 💻 Running the CLI (Inference & Train)
|
147 |
+
```bash
|
148 |
+
python infer-web.py --pycmd python --is_cli
|
149 |
+
```
|
150 |
+
|
151 |
+
```bash
|
152 |
+
Mangio-RVC-Fork v2 CLI App!
|
153 |
+
|
154 |
+
Welcome to the CLI version of RVC. Please read the documentation on https://github.com/Mangio621/Mangio-RVC-Fork (README.MD) to understand how to use this app.
|
155 |
+
|
156 |
+
You are currently in 'HOME':
|
157 |
+
go home : Takes you back to home with a navigation list.
|
158 |
+
go infer : Takes you to inference command execution.
|
159 |
+
|
160 |
+
go pre-process : Takes you to training step.1) pre-process command execution.
|
161 |
+
go extract-feature : Takes you to training step.2) extract-feature command execution.
|
162 |
+
go train : Takes you to training step.3) being or continue training command execution.
|
163 |
+
go train-feature : Takes you to the train feature index command execution.
|
164 |
+
|
165 |
+
go extract-model : Takes you to the extract small model command execution.
|
166 |
+
|
167 |
+
HOME:
|
168 |
+
```
|
169 |
+
|
170 |
+
Typing 'go infer' for example will take you to the infer page where you can then enter in your arguments that you wish to use for that specific page. For example typing 'go infer' will take you here:
|
171 |
+
|
172 |
+
```bash
|
173 |
+
HOME: go infer
|
174 |
+
You are currently in 'INFER':
|
175 |
+
arg 1) model name with .pth in ./weights: mi-test.pth
|
176 |
+
arg 2) source audio path: myFolder\MySource.wav
|
177 |
+
arg 3) output file name to be placed in './audio-outputs': MyTest.wav
|
178 |
+
arg 4) feature index file path: logs/mi-test/added_IVF3042_Flat_nprobe_1.index
|
179 |
+
arg 5) speaker id: 0
|
180 |
+
arg 6) transposition: 0
|
181 |
+
arg 7) f0 method: harvest (pm, harvest, crepe, crepe-tiny)
|
182 |
+
arg 8) crepe hop length: 160
|
183 |
+
arg 9) harvest median filter radius: 3 (0-7)
|
184 |
+
arg 10) post resample rate: 0
|
185 |
+
arg 11) mix volume envelope: 1
|
186 |
+
arg 12) feature index ratio: 0.78 (0-1)
|
187 |
+
arg 13) Voiceless Consonant Protection (Less Artifact): 0.33 (Smaller number = more protection. 0.50 means Dont Use.)
|
188 |
+
|
189 |
+
Example: mi-test.pth saudio/Sidney.wav myTest.wav logs/mi-test/added_index.index 0 -2 harvest 160 3 0 1 0.95 0.33
|
190 |
+
|
191 |
+
INFER: <INSERT ARGUMENTS HERE OR COPY AND PASTE THE EXAMPLE>
|
192 |
+
```
|
193 |
+
## 🏆 Credits
|
194 |
+
Applio owes its existence to the collaborative efforts of various repositories, including Mangio-RVC-Fork, and all the other credited contributors. Without their contributions, Applio would not have been possible. Therefore, we kindly request that if you appreciate the work we've accomplished, you consider exploring the projects mentioned in our credits.
|
195 |
+
|
196 |
+
Our goal is not to supplant RVC or Mangio; rather, we aim to provide a contemporary and up-to-date alternative for the entire community.
|
197 |
+
|
198 |
+
+ [Retrieval-based-Voice-Conversion-WebUI](Retrieval-based-Voice-Conversion-WebUI)
|
199 |
+
+ [Mangio-RVC-Fork](https://github.com/Mangio621/Mangio-RVC-Fork)
|
200 |
+
+ [RVG_tts](https://github.com/Foxify52/RVG_tts)
|
201 |
+
+ [ContentVec](https://github.com/auspicious3000/contentvec/)
|
202 |
+
+ [VITS](https://github.com/jaywalnut310/vits)
|
203 |
+
+ [HIFIGAN](https://github.com/jik876/hifi-gan)
|
204 |
+
+ [Gradio](https://github.com/gradio-app/gradio)
|
205 |
+
+ [FFmpeg](https://github.com/FFmpeg/FFmpeg)
|
206 |
+
+ [Ultimate Vocal Remover](https://github.com/Anjok07/ultimatevocalremovergui)
|
207 |
+
+ [audio-slicer](https://github.com/openvpi/audio-slicer)
|
208 |
+
+ [Vocal pitch extraction:RMVPE](https://github.com/Dream-High/RMVPE)
|
209 |
+
|
210 |
+
|
211 |
+
## 🙏 Thanks to all RVC, Mangio and Applio contributors
|
212 |
+
<a href="https://github.com/liujing04/Retrieval-based-Voice-Conversion-WebUI/graphs/contributors" target="_blank">
|
213 |
+
<img src="https://contrib.rocks/image?repo=liujing04/Retrieval-based-Voice-Conversion-WebUI" />
|
214 |
+
</a>
|
215 |
+
|
216 |
+
<a href="https://github.com/Mangio621/Mangio-RVC-Fork/graphs/contributors" target="_blank">
|
217 |
+
<img src="https://contrib.rocks/image?repo=Mangio621/Mangio-RVC-Fork" />
|
218 |
+
</a>
|
219 |
+
|
220 |
+
<a href="https://github.com/IAHispano/Applio-RVC-Fork/graphs/contributors" target="_blank">
|
221 |
+
<img src="https://contrib.rocks/image?repo=IAHispano/Applio-RVC-Fork" />
|
222 |
+
</a>
|
assets/hubert/.gitignore
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
*
|
2 |
+
!.gitignore
|
assets/pretrained/.gitignore
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
*
|
2 |
+
!.gitignore
|
assets/pretrained_v2/.gitignore
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
*
|
2 |
+
!.gitignore
|
assets/rmvpe/.gitignore
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
*
|
2 |
+
!.gitignore
|
assets/uvr5_weights/.gitignore
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
*
|
2 |
+
!.gitignore
|
assets/weights/.gitignore
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
*
|
2 |
+
!.gitignore
|
audioEffects.py
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pedalboard import Pedalboard, Compressor, Reverb, NoiseGate
|
2 |
+
from pedalboard.io import AudioFile
|
3 |
+
import sys
|
4 |
+
import os
|
5 |
+
now_dir = os.getcwd()
|
6 |
+
sys.path.append(now_dir)
|
7 |
+
from i18n import I18nAuto
|
8 |
+
i18n = I18nAuto()
|
9 |
+
from pydub import AudioSegment
|
10 |
+
import numpy as np
|
11 |
+
import soundfile as sf
|
12 |
+
from pydub.playback import play
|
13 |
+
|
14 |
+
def process_audio(input_path, output_path, reverb_enabled, compressor_enabled, noise_gate_enabled, ):
|
15 |
+
print(reverb_enabled)
|
16 |
+
print(compressor_enabled)
|
17 |
+
print(noise_gate_enabled)
|
18 |
+
effects = []
|
19 |
+
if reverb_enabled:
|
20 |
+
effects.append(Reverb(room_size=0.01))
|
21 |
+
if compressor_enabled:
|
22 |
+
effects.append(Compressor(threshold_db=-10, ratio=25))
|
23 |
+
if noise_gate_enabled:
|
24 |
+
effects.append(NoiseGate(threshold_db=-16, ratio=1.5, release_ms=250))
|
25 |
+
|
26 |
+
board = Pedalboard(effects)
|
27 |
+
|
28 |
+
with AudioFile(input_path) as f:
|
29 |
+
with AudioFile(output_path, 'w', f.samplerate, f.num_channels) as o:
|
30 |
+
while f.tell() < f.frames:
|
31 |
+
chunk = f.read(f.samplerate)
|
32 |
+
effected = board(chunk, f.samplerate, reset=False)
|
33 |
+
o.write(effected)
|
34 |
+
|
35 |
+
result = i18n("Processed audio saved at: ") + output_path
|
36 |
+
print(result)
|
37 |
+
return output_path
|
audios/.gitignore
ADDED
File without changes
|
colab_for_mdx.py
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
import gc
|
4 |
+
import psutil
|
5 |
+
import requests
|
6 |
+
import subprocess
|
7 |
+
import time
|
8 |
+
import logging
|
9 |
+
import sys
|
10 |
+
import shutil
|
11 |
+
now_dir = os.getcwd()
|
12 |
+
sys.path.append(now_dir)
|
13 |
+
first_cell_executed = False
|
14 |
+
file_folder = "Colab-for-MDX_B"
|
15 |
+
def first_cell_ran():
|
16 |
+
global first_cell_executed
|
17 |
+
if first_cell_executed:
|
18 |
+
#print("The 'first_cell_ran' function has already been executed.")
|
19 |
+
return
|
20 |
+
|
21 |
+
|
22 |
+
|
23 |
+
first_cell_executed = True
|
24 |
+
os.makedirs("tmp_models", exist_ok=True)
|
25 |
+
|
26 |
+
|
27 |
+
|
28 |
+
class hide_opt: # hide outputs
|
29 |
+
def __enter__(self):
|
30 |
+
self._original_stdout = sys.stdout
|
31 |
+
sys.stdout = open(os.devnull, "w")
|
32 |
+
|
33 |
+
def __exit__(self, exc_type, exc_val, exc_tb):
|
34 |
+
sys.stdout.close()
|
35 |
+
sys.stdout = self._original_stdout
|
36 |
+
|
37 |
+
def get_size(bytes, suffix="B"): # read ram
|
38 |
+
global svmem
|
39 |
+
factor = 1024
|
40 |
+
for unit in ["", "K", "M", "G", "T", "P"]:
|
41 |
+
if bytes < factor:
|
42 |
+
return f"{bytes:.2f}{unit}{suffix}"
|
43 |
+
bytes /= factor
|
44 |
+
svmem = psutil.virtual_memory()
|
45 |
+
|
46 |
+
|
47 |
+
def use_uvr_without_saving():
|
48 |
+
print("Notice: files won't be saved to personal drive.")
|
49 |
+
print(f"Downloading {file_folder}...", end=" ")
|
50 |
+
with hide_opt():
|
51 |
+
#os.chdir(mounting_path)
|
52 |
+
items_to_move = ["demucs", "diffq","julius","model","separated","tracks","mdx.py","MDX-Net_Colab.ipynb"]
|
53 |
+
subprocess.run(["git", "clone", "https://github.com/NaJeongMo/Colab-for-MDX_B.git"])
|
54 |
+
for item_name in items_to_move:
|
55 |
+
item_path = os.path.join(file_folder, item_name)
|
56 |
+
if os.path.exists(item_path):
|
57 |
+
if os.path.isfile(item_path):
|
58 |
+
shutil.move(item_path, now_dir)
|
59 |
+
elif os.path.isdir(item_path):
|
60 |
+
shutil.move(item_path, now_dir)
|
61 |
+
try:
|
62 |
+
shutil.rmtree(file_folder)
|
63 |
+
except PermissionError:
|
64 |
+
print(f"No se pudo eliminar la carpeta {file_folder}. Puede estar relacionada con Git.")
|
65 |
+
|
66 |
+
|
67 |
+
use_uvr_without_saving()
|
68 |
+
print("done!")
|
69 |
+
if not os.path.exists("tracks"):
|
70 |
+
os.mkdir("tracks")
|
71 |
+
first_cell_ran()
|
configs/32k.json
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"train": {
|
3 |
+
"log_interval": 200,
|
4 |
+
"seed": 1234,
|
5 |
+
"epochs": 20000,
|
6 |
+
"learning_rate": 1e-4,
|
7 |
+
"betas": [0.8, 0.99],
|
8 |
+
"eps": 1e-9,
|
9 |
+
"batch_size": 4,
|
10 |
+
"fp16_run": false,
|
11 |
+
"lr_decay": 0.999875,
|
12 |
+
"segment_size": 12800,
|
13 |
+
"init_lr_ratio": 1,
|
14 |
+
"warmup_epochs": 0,
|
15 |
+
"c_mel": 45,
|
16 |
+
"c_kl": 1.0
|
17 |
+
},
|
18 |
+
"data": {
|
19 |
+
"max_wav_value": 32768.0,
|
20 |
+
"sampling_rate": 32000,
|
21 |
+
"filter_length": 1024,
|
22 |
+
"hop_length": 320,
|
23 |
+
"win_length": 1024,
|
24 |
+
"n_mel_channels": 80,
|
25 |
+
"mel_fmin": 0.0,
|
26 |
+
"mel_fmax": null
|
27 |
+
},
|
28 |
+
"model": {
|
29 |
+
"inter_channels": 192,
|
30 |
+
"hidden_channels": 192,
|
31 |
+
"filter_channels": 768,
|
32 |
+
"n_heads": 2,
|
33 |
+
"n_layers": 6,
|
34 |
+
"kernel_size": 3,
|
35 |
+
"p_dropout": 0,
|
36 |
+
"resblock": "1",
|
37 |
+
"resblock_kernel_sizes": [3, 7, 11],
|
38 |
+
"resblock_dilation_sizes": [
|
39 |
+
[1, 3, 5],
|
40 |
+
[1, 3, 5],
|
41 |
+
[1, 3, 5]
|
42 |
+
],
|
43 |
+
"upsample_rates": [10, 4, 2, 2, 2],
|
44 |
+
"upsample_initial_channel": 512,
|
45 |
+
"upsample_kernel_sizes": [16, 16, 4, 4, 4],
|
46 |
+
"use_spectral_norm": false,
|
47 |
+
"gin_channels": 256,
|
48 |
+
"spk_embed_dim": 109
|
49 |
+
}
|
50 |
+
}
|
configs/32k_v2.json
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"train": {
|
3 |
+
"log_interval": 200,
|
4 |
+
"seed": 1234,
|
5 |
+
"epochs": 20000,
|
6 |
+
"learning_rate": 1e-4,
|
7 |
+
"betas": [0.8, 0.99],
|
8 |
+
"eps": 1e-9,
|
9 |
+
"batch_size": 4,
|
10 |
+
"fp16_run": true,
|
11 |
+
"lr_decay": 0.999875,
|
12 |
+
"segment_size": 12800,
|
13 |
+
"init_lr_ratio": 1,
|
14 |
+
"warmup_epochs": 0,
|
15 |
+
"c_mel": 45,
|
16 |
+
"c_kl": 1.0
|
17 |
+
},
|
18 |
+
"data": {
|
19 |
+
"max_wav_value": 32768.0,
|
20 |
+
"sampling_rate": 32000,
|
21 |
+
"filter_length": 1024,
|
22 |
+
"hop_length": 320,
|
23 |
+
"win_length": 1024,
|
24 |
+
"n_mel_channels": 80,
|
25 |
+
"mel_fmin": 0.0,
|
26 |
+
"mel_fmax": null
|
27 |
+
},
|
28 |
+
"model": {
|
29 |
+
"inter_channels": 192,
|
30 |
+
"hidden_channels": 192,
|
31 |
+
"filter_channels": 768,
|
32 |
+
"n_heads": 2,
|
33 |
+
"n_layers": 6,
|
34 |
+
"kernel_size": 3,
|
35 |
+
"p_dropout": 0,
|
36 |
+
"resblock": "1",
|
37 |
+
"resblock_kernel_sizes": [3, 7, 11],
|
38 |
+
"resblock_dilation_sizes": [
|
39 |
+
[1, 3, 5],
|
40 |
+
[1, 3, 5],
|
41 |
+
[1, 3, 5]
|
42 |
+
],
|
43 |
+
"upsample_rates": [10, 8, 2, 2],
|
44 |
+
"upsample_initial_channel": 512,
|
45 |
+
"upsample_kernel_sizes": [20, 16, 4, 4],
|
46 |
+
"use_spectral_norm": false,
|
47 |
+
"gin_channels": 256,
|
48 |
+
"spk_embed_dim": 109
|
49 |
+
}
|
50 |
+
}
|
configs/40k.json
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"train": {
|
3 |
+
"log_interval": 200,
|
4 |
+
"seed": 1234,
|
5 |
+
"epochs": 20000,
|
6 |
+
"learning_rate": 1e-4,
|
7 |
+
"betas": [0.8, 0.99],
|
8 |
+
"eps": 1e-9,
|
9 |
+
"batch_size": 4,
|
10 |
+
"fp16_run": false,
|
11 |
+
"lr_decay": 0.999875,
|
12 |
+
"segment_size": 12800,
|
13 |
+
"init_lr_ratio": 1,
|
14 |
+
"warmup_epochs": 0,
|
15 |
+
"c_mel": 45,
|
16 |
+
"c_kl": 1.0
|
17 |
+
},
|
18 |
+
"data": {
|
19 |
+
"max_wav_value": 32768.0,
|
20 |
+
"sampling_rate": 40000,
|
21 |
+
"filter_length": 2048,
|
22 |
+
"hop_length": 400,
|
23 |
+
"win_length": 2048,
|
24 |
+
"n_mel_channels": 125,
|
25 |
+
"mel_fmin": 0.0,
|
26 |
+
"mel_fmax": null
|
27 |
+
},
|
28 |
+
"model": {
|
29 |
+
"inter_channels": 192,
|
30 |
+
"hidden_channels": 192,
|
31 |
+
"filter_channels": 768,
|
32 |
+
"n_heads": 2,
|
33 |
+
"n_layers": 6,
|
34 |
+
"kernel_size": 3,
|
35 |
+
"p_dropout": 0,
|
36 |
+
"resblock": "1",
|
37 |
+
"resblock_kernel_sizes": [3, 7, 11],
|
38 |
+
"resblock_dilation_sizes": [
|
39 |
+
[1, 3, 5],
|
40 |
+
[1, 3, 5],
|
41 |
+
[1, 3, 5]
|
42 |
+
],
|
43 |
+
"upsample_rates": [10, 10, 2, 2],
|
44 |
+
"upsample_initial_channel": 512,
|
45 |
+
"upsample_kernel_sizes": [16, 16, 4, 4],
|
46 |
+
"use_spectral_norm": false,
|
47 |
+
"gin_channels": 256,
|
48 |
+
"spk_embed_dim": 109
|
49 |
+
}
|
50 |
+
}
|
configs/48k.json
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"train": {
|
3 |
+
"log_interval": 200,
|
4 |
+
"seed": 1234,
|
5 |
+
"epochs": 20000,
|
6 |
+
"learning_rate": 1e-4,
|
7 |
+
"betas": [0.8, 0.99],
|
8 |
+
"eps": 1e-9,
|
9 |
+
"batch_size": 4,
|
10 |
+
"fp16_run": false,
|
11 |
+
"lr_decay": 0.999875,
|
12 |
+
"segment_size": 11520,
|
13 |
+
"init_lr_ratio": 1,
|
14 |
+
"warmup_epochs": 0,
|
15 |
+
"c_mel": 45,
|
16 |
+
"c_kl": 1.0
|
17 |
+
},
|
18 |
+
"data": {
|
19 |
+
"max_wav_value": 32768.0,
|
20 |
+
"sampling_rate": 48000,
|
21 |
+
"filter_length": 2048,
|
22 |
+
"hop_length": 480,
|
23 |
+
"win_length": 2048,
|
24 |
+
"n_mel_channels": 128,
|
25 |
+
"mel_fmin": 0.0,
|
26 |
+
"mel_fmax": null
|
27 |
+
},
|
28 |
+
"model": {
|
29 |
+
"inter_channels": 192,
|
30 |
+
"hidden_channels": 192,
|
31 |
+
"filter_channels": 768,
|
32 |
+
"n_heads": 2,
|
33 |
+
"n_layers": 6,
|
34 |
+
"kernel_size": 3,
|
35 |
+
"p_dropout": 0,
|
36 |
+
"resblock": "1",
|
37 |
+
"resblock_kernel_sizes": [3, 7, 11],
|
38 |
+
"resblock_dilation_sizes": [
|
39 |
+
[1, 3, 5],
|
40 |
+
[1, 3, 5],
|
41 |
+
[1, 3, 5]
|
42 |
+
],
|
43 |
+
"upsample_rates": [10, 6, 2, 2, 2],
|
44 |
+
"upsample_initial_channel": 512,
|
45 |
+
"upsample_kernel_sizes": [16, 16, 4, 4, 4],
|
46 |
+
"use_spectral_norm": false,
|
47 |
+
"gin_channels": 256,
|
48 |
+
"spk_embed_dim": 109
|
49 |
+
}
|
50 |
+
}
|
configs/48k_v2.json
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"train": {
|
3 |
+
"log_interval": 200,
|
4 |
+
"seed": 1234,
|
5 |
+
"epochs": 20000,
|
6 |
+
"learning_rate": 1e-4,
|
7 |
+
"betas": [0.8, 0.99],
|
8 |
+
"eps": 1e-9,
|
9 |
+
"batch_size": 4,
|
10 |
+
"fp16_run": true,
|
11 |
+
"lr_decay": 0.999875,
|
12 |
+
"segment_size": 17280,
|
13 |
+
"init_lr_ratio": 1,
|
14 |
+
"warmup_epochs": 0,
|
15 |
+
"c_mel": 45,
|
16 |
+
"c_kl": 1.0
|
17 |
+
},
|
18 |
+
"data": {
|
19 |
+
"max_wav_value": 32768.0,
|
20 |
+
"sampling_rate": 48000,
|
21 |
+
"filter_length": 2048,
|
22 |
+
"hop_length": 480,
|
23 |
+
"win_length": 2048,
|
24 |
+
"n_mel_channels": 128,
|
25 |
+
"mel_fmin": 0.0,
|
26 |
+
"mel_fmax": null
|
27 |
+
},
|
28 |
+
"model": {
|
29 |
+
"inter_channels": 192,
|
30 |
+
"hidden_channels": 192,
|
31 |
+
"filter_channels": 768,
|
32 |
+
"n_heads": 2,
|
33 |
+
"n_layers": 6,
|
34 |
+
"kernel_size": 3,
|
35 |
+
"p_dropout": 0,
|
36 |
+
"resblock": "1",
|
37 |
+
"resblock_kernel_sizes": [3, 7, 11],
|
38 |
+
"resblock_dilation_sizes": [
|
39 |
+
[1, 3, 5],
|
40 |
+
[1, 3, 5],
|
41 |
+
[1, 3, 5]
|
42 |
+
],
|
43 |
+
"upsample_rates": [12, 10, 2, 2],
|
44 |
+
"upsample_initial_channel": 512,
|
45 |
+
"upsample_kernel_sizes": [24, 20, 4, 4],
|
46 |
+
"use_spectral_norm": false,
|
47 |
+
"gin_channels": 256,
|
48 |
+
"spk_embed_dim": 109
|
49 |
+
}
|
50 |
+
}
|
configs/config.json
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"pth_path": "assets/weights/kikiV1.pth",
|
3 |
+
"index_path": "logs/kikiV1.index",
|
4 |
+
"sg_input_device": "VoiceMeeter Output (VB-Audio Vo (MME)",
|
5 |
+
"sg_output_device": "VoiceMeeter Aux Input (VB-Audio (MME)",
|
6 |
+
"threhold": -45.0,
|
7 |
+
"pitch": 12.0,
|
8 |
+
"index_rate": 0.0,
|
9 |
+
"rms_mix_rate": 0.0,
|
10 |
+
"block_time": 0.25,
|
11 |
+
"crossfade_length": 0.04,
|
12 |
+
"extra_time": 2.0,
|
13 |
+
"n_cpu": 6.0,
|
14 |
+
"f0method": "rmvpe"
|
15 |
+
}
|
configs/config.py
ADDED
@@ -0,0 +1,265 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
import sys
|
4 |
+
import json
|
5 |
+
from multiprocessing import cpu_count
|
6 |
+
|
7 |
+
import torch
|
8 |
+
|
9 |
+
try:
|
10 |
+
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
|
11 |
+
if torch.xpu.is_available():
|
12 |
+
from infer.modules.ipex import ipex_init
|
13 |
+
ipex_init()
|
14 |
+
except Exception:
|
15 |
+
pass
|
16 |
+
|
17 |
+
import logging
|
18 |
+
|
19 |
+
logger = logging.getLogger(__name__)
|
20 |
+
|
21 |
+
|
22 |
+
version_config_list = [
|
23 |
+
"v1/32k.json",
|
24 |
+
"v1/40k.json",
|
25 |
+
"v1/48k.json",
|
26 |
+
"v2/48k.json",
|
27 |
+
"v2/32k.json",
|
28 |
+
]
|
29 |
+
|
30 |
+
|
31 |
+
def singleton_variable(func):
|
32 |
+
def wrapper(*args, **kwargs):
|
33 |
+
if not wrapper.instance:
|
34 |
+
wrapper.instance = func(*args, **kwargs)
|
35 |
+
return wrapper.instance
|
36 |
+
|
37 |
+
wrapper.instance = None
|
38 |
+
return wrapper
|
39 |
+
|
40 |
+
|
41 |
+
@singleton_variable
|
42 |
+
class Config:
|
43 |
+
def __init__(self):
|
44 |
+
self.device = "cuda:0"
|
45 |
+
self.is_half = True
|
46 |
+
self.n_cpu = 0
|
47 |
+
self.gpu_name = None
|
48 |
+
self.json_config = self.load_config_json()
|
49 |
+
self.gpu_mem = None
|
50 |
+
(
|
51 |
+
self.python_cmd,
|
52 |
+
self.listen_port,
|
53 |
+
self.iscolab,
|
54 |
+
self.noparallel,
|
55 |
+
self.noautoopen,
|
56 |
+
self.paperspace,
|
57 |
+
self.is_cli,
|
58 |
+
self.grtheme,
|
59 |
+
self.dml,
|
60 |
+
) = self.arg_parse()
|
61 |
+
self.instead = ""
|
62 |
+
self.x_pad, self.x_query, self.x_center, self.x_max = self.device_config()
|
63 |
+
|
64 |
+
@staticmethod
|
65 |
+
def load_config_json() -> dict:
|
66 |
+
d = {}
|
67 |
+
for config_file in version_config_list:
|
68 |
+
with open(f"configs/{config_file}", "r") as f:
|
69 |
+
d[config_file] = json.load(f)
|
70 |
+
return d
|
71 |
+
|
72 |
+
@staticmethod
|
73 |
+
def arg_parse() -> tuple:
|
74 |
+
exe = sys.executable or "python"
|
75 |
+
parser = argparse.ArgumentParser()
|
76 |
+
parser.add_argument("--port", type=int, default=7865, help="Listen port")
|
77 |
+
parser.add_argument("--pycmd", type=str, default=exe, help="Python command")
|
78 |
+
parser.add_argument("--colab", action="store_true", help="Launch in colab")
|
79 |
+
parser.add_argument(
|
80 |
+
"--noparallel", action="store_true", help="Disable parallel processing"
|
81 |
+
)
|
82 |
+
parser.add_argument(
|
83 |
+
"--noautoopen",
|
84 |
+
action="store_true",
|
85 |
+
help="Do not open in browser automatically",
|
86 |
+
)
|
87 |
+
parser.add_argument(
|
88 |
+
"--paperspace",
|
89 |
+
action="store_true",
|
90 |
+
help="Note that this argument just shares a gradio link for the web UI. Thus can be used on other non-local CLI systems.",
|
91 |
+
)
|
92 |
+
parser.add_argument(
|
93 |
+
"--is_cli",
|
94 |
+
action="store_true",
|
95 |
+
help="Use the CLI instead of setting up a gradio UI. This flag will launch an RVC text interface where you can execute functions from infer-web.py!",
|
96 |
+
)
|
97 |
+
|
98 |
+
parser.add_argument(
|
99 |
+
"-t",
|
100 |
+
"--theme",
|
101 |
+
help = "Theme for Gradio. Format - `JohnSmith9982/small_and_pretty` (no backticks)",
|
102 |
+
default = "JohnSmith9982/small_and_pretty",
|
103 |
+
type = str
|
104 |
+
)
|
105 |
+
|
106 |
+
parser.add_argument(
|
107 |
+
"--dml",
|
108 |
+
action="store_true",
|
109 |
+
help="Use DirectML backend instead of CUDA."
|
110 |
+
)
|
111 |
+
|
112 |
+
cmd_opts = parser.parse_args()
|
113 |
+
|
114 |
+
cmd_opts.port = cmd_opts.port if 0 <= cmd_opts.port <= 65535 else 7865
|
115 |
+
|
116 |
+
return (
|
117 |
+
cmd_opts.pycmd,
|
118 |
+
cmd_opts.port,
|
119 |
+
cmd_opts.colab,
|
120 |
+
cmd_opts.noparallel,
|
121 |
+
cmd_opts.noautoopen,
|
122 |
+
cmd_opts.paperspace,
|
123 |
+
cmd_opts.is_cli,
|
124 |
+
cmd_opts.theme,
|
125 |
+
cmd_opts.dml,
|
126 |
+
)
|
127 |
+
|
128 |
+
# has_mps is only available in nightly pytorch (for now) and MasOS 12.3+.
|
129 |
+
# check `getattr` and try it for compatibility
|
130 |
+
@staticmethod
|
131 |
+
def has_mps() -> bool:
|
132 |
+
if not torch.backends.mps.is_available():
|
133 |
+
return False
|
134 |
+
try:
|
135 |
+
torch.zeros(1).to(torch.device("mps"))
|
136 |
+
return True
|
137 |
+
except Exception:
|
138 |
+
return False
|
139 |
+
|
140 |
+
@staticmethod
|
141 |
+
def has_xpu() -> bool:
|
142 |
+
if hasattr(torch, "xpu") and torch.xpu.is_available():
|
143 |
+
return True
|
144 |
+
else:
|
145 |
+
return False
|
146 |
+
|
147 |
+
def use_fp32_config(self):
|
148 |
+
for config_file in version_config_list:
|
149 |
+
self.json_config[config_file]["train"]["fp16_run"] = False
|
150 |
+
|
151 |
+
def device_config(self) -> tuple:
|
152 |
+
if torch.cuda.is_available():
|
153 |
+
if self.has_xpu():
|
154 |
+
self.device = self.instead = "xpu:0"
|
155 |
+
self.is_half = True
|
156 |
+
i_device = int(self.device.split(":")[-1])
|
157 |
+
self.gpu_name = torch.cuda.get_device_name(i_device)
|
158 |
+
if (
|
159 |
+
("16" in self.gpu_name and "V100" not in self.gpu_name.upper())
|
160 |
+
or "P40" in self.gpu_name.upper()
|
161 |
+
or "P10" in self.gpu_name.upper()
|
162 |
+
or "1060" in self.gpu_name
|
163 |
+
or "1070" in self.gpu_name
|
164 |
+
or "1080" in self.gpu_name
|
165 |
+
):
|
166 |
+
logger.info("Found GPU %s, force to fp32", self.gpu_name)
|
167 |
+
self.is_half = False
|
168 |
+
self.use_fp32_config()
|
169 |
+
else:
|
170 |
+
logger.info("Found GPU %s", self.gpu_name)
|
171 |
+
self.gpu_mem = int(
|
172 |
+
torch.cuda.get_device_properties(i_device).total_memory
|
173 |
+
/ 1024
|
174 |
+
/ 1024
|
175 |
+
/ 1024
|
176 |
+
+ 0.4
|
177 |
+
)
|
178 |
+
if self.gpu_mem <= 4:
|
179 |
+
with open("infer/modules/train/preprocess.py", "r") as f:
|
180 |
+
strr = f.read().replace("3.7", "3.0")
|
181 |
+
with open("infer/modules/train/preprocess.py", "w") as f:
|
182 |
+
f.write(strr)
|
183 |
+
elif self.has_mps():
|
184 |
+
logger.info("No supported Nvidia GPU found")
|
185 |
+
self.device = self.instead = "mps"
|
186 |
+
self.is_half = False
|
187 |
+
self.use_fp32_config()
|
188 |
+
else:
|
189 |
+
logger.info("No supported Nvidia GPU found")
|
190 |
+
self.device = self.instead = "cpu"
|
191 |
+
self.is_half = False
|
192 |
+
self.use_fp32_config()
|
193 |
+
|
194 |
+
if self.n_cpu == 0:
|
195 |
+
self.n_cpu = cpu_count()
|
196 |
+
|
197 |
+
if self.is_half:
|
198 |
+
# 6G显存配置
|
199 |
+
x_pad = 3
|
200 |
+
x_query = 10
|
201 |
+
x_center = 60
|
202 |
+
x_max = 65
|
203 |
+
else:
|
204 |
+
# 5G显存配置
|
205 |
+
x_pad = 1
|
206 |
+
x_query = 6
|
207 |
+
x_center = 38
|
208 |
+
x_max = 41
|
209 |
+
|
210 |
+
if self.gpu_mem is not None and self.gpu_mem <= 4:
|
211 |
+
x_pad = 1
|
212 |
+
x_query = 5
|
213 |
+
x_center = 30
|
214 |
+
x_max = 32
|
215 |
+
if self.dml:
|
216 |
+
logger.info("Use DirectML instead")
|
217 |
+
if (
|
218 |
+
os.path.exists(
|
219 |
+
"runtime\Lib\site-packages\onnxruntime\capi\DirectML.dll"
|
220 |
+
)
|
221 |
+
== False
|
222 |
+
):
|
223 |
+
try:
|
224 |
+
os.rename(
|
225 |
+
"runtime\Lib\site-packages\onnxruntime",
|
226 |
+
"runtime\Lib\site-packages\onnxruntime-cuda",
|
227 |
+
)
|
228 |
+
except:
|
229 |
+
pass
|
230 |
+
try:
|
231 |
+
os.rename(
|
232 |
+
"runtime\Lib\site-packages\onnxruntime-dml",
|
233 |
+
"runtime\Lib\site-packages\onnxruntime",
|
234 |
+
)
|
235 |
+
except:
|
236 |
+
pass
|
237 |
+
# if self.device != "cpu":
|
238 |
+
import torch_directml
|
239 |
+
|
240 |
+
self.device = torch_directml.device(torch_directml.default_device())
|
241 |
+
self.is_half = False
|
242 |
+
else:
|
243 |
+
if self.instead:
|
244 |
+
logger.info(f"Use {self.instead} instead")
|
245 |
+
if (
|
246 |
+
os.path.exists(
|
247 |
+
"runtime\Lib\site-packages\onnxruntime\capi\onnxruntime_providers_cuda.dll"
|
248 |
+
)
|
249 |
+
== False
|
250 |
+
):
|
251 |
+
try:
|
252 |
+
os.rename(
|
253 |
+
"runtime\Lib\site-packages\onnxruntime",
|
254 |
+
"runtime\Lib\site-packages\onnxruntime-dml",
|
255 |
+
)
|
256 |
+
except:
|
257 |
+
pass
|
258 |
+
try:
|
259 |
+
os.rename(
|
260 |
+
"runtime\Lib\site-packages\onnxruntime-cuda",
|
261 |
+
"runtime\Lib\site-packages\onnxruntime",
|
262 |
+
)
|
263 |
+
except:
|
264 |
+
pass
|
265 |
+
return x_pad, x_query, x_center, x_max
|
configs/v1/32k.json
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"train": {
|
3 |
+
"log_interval": 200,
|
4 |
+
"seed": 1234,
|
5 |
+
"epochs": 20000,
|
6 |
+
"learning_rate": 1e-4,
|
7 |
+
"betas": [0.8, 0.99],
|
8 |
+
"eps": 1e-9,
|
9 |
+
"batch_size": 4,
|
10 |
+
"fp16_run": true,
|
11 |
+
"lr_decay": 0.999875,
|
12 |
+
"segment_size": 12800,
|
13 |
+
"init_lr_ratio": 1,
|
14 |
+
"warmup_epochs": 0,
|
15 |
+
"c_mel": 45,
|
16 |
+
"c_kl": 1.0
|
17 |
+
},
|
18 |
+
"data": {
|
19 |
+
"max_wav_value": 32768.0,
|
20 |
+
"sampling_rate": 32000,
|
21 |
+
"filter_length": 1024,
|
22 |
+
"hop_length": 320,
|
23 |
+
"win_length": 1024,
|
24 |
+
"n_mel_channels": 80,
|
25 |
+
"mel_fmin": 0.0,
|
26 |
+
"mel_fmax": null
|
27 |
+
},
|
28 |
+
"model": {
|
29 |
+
"inter_channels": 192,
|
30 |
+
"hidden_channels": 192,
|
31 |
+
"filter_channels": 768,
|
32 |
+
"n_heads": 2,
|
33 |
+
"n_layers": 6,
|
34 |
+
"kernel_size": 3,
|
35 |
+
"p_dropout": 0,
|
36 |
+
"resblock": "1",
|
37 |
+
"resblock_kernel_sizes": [3,7,11],
|
38 |
+
"resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
|
39 |
+
"upsample_rates": [10,4,2,2,2],
|
40 |
+
"upsample_initial_channel": 512,
|
41 |
+
"upsample_kernel_sizes": [16,16,4,4,4],
|
42 |
+
"use_spectral_norm": false,
|
43 |
+
"gin_channels": 256,
|
44 |
+
"spk_embed_dim": 109
|
45 |
+
}
|
46 |
+
}
|
configs/v1/40k.json
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"train": {
|
3 |
+
"log_interval": 200,
|
4 |
+
"seed": 1234,
|
5 |
+
"epochs": 20000,
|
6 |
+
"learning_rate": 1e-4,
|
7 |
+
"betas": [0.8, 0.99],
|
8 |
+
"eps": 1e-9,
|
9 |
+
"batch_size": 4,
|
10 |
+
"fp16_run": true,
|
11 |
+
"lr_decay": 0.999875,
|
12 |
+
"segment_size": 12800,
|
13 |
+
"init_lr_ratio": 1,
|
14 |
+
"warmup_epochs": 0,
|
15 |
+
"c_mel": 45,
|
16 |
+
"c_kl": 1.0
|
17 |
+
},
|
18 |
+
"data": {
|
19 |
+
"max_wav_value": 32768.0,
|
20 |
+
"sampling_rate": 40000,
|
21 |
+
"filter_length": 2048,
|
22 |
+
"hop_length": 400,
|
23 |
+
"win_length": 2048,
|
24 |
+
"n_mel_channels": 125,
|
25 |
+
"mel_fmin": 0.0,
|
26 |
+
"mel_fmax": null
|
27 |
+
},
|
28 |
+
"model": {
|
29 |
+
"inter_channels": 192,
|
30 |
+
"hidden_channels": 192,
|
31 |
+
"filter_channels": 768,
|
32 |
+
"n_heads": 2,
|
33 |
+
"n_layers": 6,
|
34 |
+
"kernel_size": 3,
|
35 |
+
"p_dropout": 0,
|
36 |
+
"resblock": "1",
|
37 |
+
"resblock_kernel_sizes": [3,7,11],
|
38 |
+
"resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
|
39 |
+
"upsample_rates": [10,10,2,2],
|
40 |
+
"upsample_initial_channel": 512,
|
41 |
+
"upsample_kernel_sizes": [16,16,4,4],
|
42 |
+
"use_spectral_norm": false,
|
43 |
+
"gin_channels": 256,
|
44 |
+
"spk_embed_dim": 109
|
45 |
+
}
|
46 |
+
}
|
configs/v1/48k.json
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"train": {
|
3 |
+
"log_interval": 200,
|
4 |
+
"seed": 1234,
|
5 |
+
"epochs": 20000,
|
6 |
+
"learning_rate": 1e-4,
|
7 |
+
"betas": [0.8, 0.99],
|
8 |
+
"eps": 1e-9,
|
9 |
+
"batch_size": 4,
|
10 |
+
"fp16_run": true,
|
11 |
+
"lr_decay": 0.999875,
|
12 |
+
"segment_size": 11520,
|
13 |
+
"init_lr_ratio": 1,
|
14 |
+
"warmup_epochs": 0,
|
15 |
+
"c_mel": 45,
|
16 |
+
"c_kl": 1.0
|
17 |
+
},
|
18 |
+
"data": {
|
19 |
+
"max_wav_value": 32768.0,
|
20 |
+
"sampling_rate": 48000,
|
21 |
+
"filter_length": 2048,
|
22 |
+
"hop_length": 480,
|
23 |
+
"win_length": 2048,
|
24 |
+
"n_mel_channels": 128,
|
25 |
+
"mel_fmin": 0.0,
|
26 |
+
"mel_fmax": null
|
27 |
+
},
|
28 |
+
"model": {
|
29 |
+
"inter_channels": 192,
|
30 |
+
"hidden_channels": 192,
|
31 |
+
"filter_channels": 768,
|
32 |
+
"n_heads": 2,
|
33 |
+
"n_layers": 6,
|
34 |
+
"kernel_size": 3,
|
35 |
+
"p_dropout": 0,
|
36 |
+
"resblock": "1",
|
37 |
+
"resblock_kernel_sizes": [3,7,11],
|
38 |
+
"resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
|
39 |
+
"upsample_rates": [10,6,2,2,2],
|
40 |
+
"upsample_initial_channel": 512,
|
41 |
+
"upsample_kernel_sizes": [16,16,4,4,4],
|
42 |
+
"use_spectral_norm": false,
|
43 |
+
"gin_channels": 256,
|
44 |
+
"spk_embed_dim": 109
|
45 |
+
}
|
46 |
+
}
|
configs/v2/32k.json
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"train": {
|
3 |
+
"log_interval": 200,
|
4 |
+
"seed": 1234,
|
5 |
+
"epochs": 20000,
|
6 |
+
"learning_rate": 1e-4,
|
7 |
+
"betas": [0.8, 0.99],
|
8 |
+
"eps": 1e-9,
|
9 |
+
"batch_size": 4,
|
10 |
+
"fp16_run": true,
|
11 |
+
"lr_decay": 0.999875,
|
12 |
+
"segment_size": 12800,
|
13 |
+
"init_lr_ratio": 1,
|
14 |
+
"warmup_epochs": 0,
|
15 |
+
"c_mel": 45,
|
16 |
+
"c_kl": 1.0
|
17 |
+
},
|
18 |
+
"data": {
|
19 |
+
"max_wav_value": 32768.0,
|
20 |
+
"sampling_rate": 32000,
|
21 |
+
"filter_length": 1024,
|
22 |
+
"hop_length": 320,
|
23 |
+
"win_length": 1024,
|
24 |
+
"n_mel_channels": 80,
|
25 |
+
"mel_fmin": 0.0,
|
26 |
+
"mel_fmax": null
|
27 |
+
},
|
28 |
+
"model": {
|
29 |
+
"inter_channels": 192,
|
30 |
+
"hidden_channels": 192,
|
31 |
+
"filter_channels": 768,
|
32 |
+
"n_heads": 2,
|
33 |
+
"n_layers": 6,
|
34 |
+
"kernel_size": 3,
|
35 |
+
"p_dropout": 0,
|
36 |
+
"resblock": "1",
|
37 |
+
"resblock_kernel_sizes": [3,7,11],
|
38 |
+
"resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
|
39 |
+
"upsample_rates": [10,8,2,2],
|
40 |
+
"upsample_initial_channel": 512,
|
41 |
+
"upsample_kernel_sizes": [20,16,4,4],
|
42 |
+
"use_spectral_norm": false,
|
43 |
+
"gin_channels": 256,
|
44 |
+
"spk_embed_dim": 109
|
45 |
+
}
|
46 |
+
}
|
configs/v2/48k.json
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"train": {
|
3 |
+
"log_interval": 200,
|
4 |
+
"seed": 1234,
|
5 |
+
"epochs": 20000,
|
6 |
+
"learning_rate": 1e-4,
|
7 |
+
"betas": [0.8, 0.99],
|
8 |
+
"eps": 1e-9,
|
9 |
+
"batch_size": 4,
|
10 |
+
"fp16_run": true,
|
11 |
+
"lr_decay": 0.999875,
|
12 |
+
"segment_size": 17280,
|
13 |
+
"init_lr_ratio": 1,
|
14 |
+
"warmup_epochs": 0,
|
15 |
+
"c_mel": 45,
|
16 |
+
"c_kl": 1.0
|
17 |
+
},
|
18 |
+
"data": {
|
19 |
+
"max_wav_value": 32768.0,
|
20 |
+
"sampling_rate": 48000,
|
21 |
+
"filter_length": 2048,
|
22 |
+
"hop_length": 480,
|
23 |
+
"win_length": 2048,
|
24 |
+
"n_mel_channels": 128,
|
25 |
+
"mel_fmin": 0.0,
|
26 |
+
"mel_fmax": null
|
27 |
+
},
|
28 |
+
"model": {
|
29 |
+
"inter_channels": 192,
|
30 |
+
"hidden_channels": 192,
|
31 |
+
"filter_channels": 768,
|
32 |
+
"n_heads": 2,
|
33 |
+
"n_layers": 6,
|
34 |
+
"kernel_size": 3,
|
35 |
+
"p_dropout": 0,
|
36 |
+
"resblock": "1",
|
37 |
+
"resblock_kernel_sizes": [3,7,11],
|
38 |
+
"resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
|
39 |
+
"upsample_rates": [12,10,2,2],
|
40 |
+
"upsample_initial_channel": 512,
|
41 |
+
"upsample_kernel_sizes": [24,20,4,4],
|
42 |
+
"use_spectral_norm": false,
|
43 |
+
"gin_channels": 256,
|
44 |
+
"spk_embed_dim": 109
|
45 |
+
}
|
46 |
+
}
|
csvdb/formanting.csv
ADDED
File without changes
|
csvdb/stop.csv
ADDED
File without changes
|
demucs/__init__.py
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
__version__ = "2.0.3"
|
demucs/__main__.py
ADDED
@@ -0,0 +1,317 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import json
|
8 |
+
import math
|
9 |
+
import os
|
10 |
+
import sys
|
11 |
+
import time
|
12 |
+
from dataclasses import dataclass, field
|
13 |
+
|
14 |
+
import torch as th
|
15 |
+
from torch import distributed, nn
|
16 |
+
from torch.nn.parallel.distributed import DistributedDataParallel
|
17 |
+
|
18 |
+
from .augment import FlipChannels, FlipSign, Remix, Scale, Shift
|
19 |
+
from .compressed import get_compressed_datasets
|
20 |
+
from .model import Demucs
|
21 |
+
from .parser import get_name, get_parser
|
22 |
+
from .raw import Rawset
|
23 |
+
from .repitch import RepitchedWrapper
|
24 |
+
from .pretrained import load_pretrained, SOURCES
|
25 |
+
from .tasnet import ConvTasNet
|
26 |
+
from .test import evaluate
|
27 |
+
from .train import train_model, validate_model
|
28 |
+
from .utils import (human_seconds, load_model, save_model, get_state,
|
29 |
+
save_state, sizeof_fmt, get_quantizer)
|
30 |
+
from .wav import get_wav_datasets, get_musdb_wav_datasets
|
31 |
+
|
32 |
+
|
33 |
+
@dataclass
|
34 |
+
class SavedState:
|
35 |
+
metrics: list = field(default_factory=list)
|
36 |
+
last_state: dict = None
|
37 |
+
best_state: dict = None
|
38 |
+
optimizer: dict = None
|
39 |
+
|
40 |
+
|
41 |
+
def main():
|
42 |
+
parser = get_parser()
|
43 |
+
args = parser.parse_args()
|
44 |
+
name = get_name(parser, args)
|
45 |
+
print(f"Experiment {name}")
|
46 |
+
|
47 |
+
if args.musdb is None and args.rank == 0:
|
48 |
+
print(
|
49 |
+
"You must provide the path to the MusDB dataset with the --musdb flag. "
|
50 |
+
"To download the MusDB dataset, see https://sigsep.github.io/datasets/musdb.html.",
|
51 |
+
file=sys.stderr)
|
52 |
+
sys.exit(1)
|
53 |
+
|
54 |
+
eval_folder = args.evals / name
|
55 |
+
eval_folder.mkdir(exist_ok=True, parents=True)
|
56 |
+
args.logs.mkdir(exist_ok=True)
|
57 |
+
metrics_path = args.logs / f"{name}.json"
|
58 |
+
eval_folder.mkdir(exist_ok=True, parents=True)
|
59 |
+
args.checkpoints.mkdir(exist_ok=True, parents=True)
|
60 |
+
args.models.mkdir(exist_ok=True, parents=True)
|
61 |
+
|
62 |
+
if args.device is None:
|
63 |
+
device = "cpu"
|
64 |
+
if th.cuda.is_available():
|
65 |
+
device = "cuda"
|
66 |
+
else:
|
67 |
+
device = args.device
|
68 |
+
|
69 |
+
th.manual_seed(args.seed)
|
70 |
+
# Prevents too many threads to be started when running `museval` as it can be quite
|
71 |
+
# inefficient on NUMA architectures.
|
72 |
+
os.environ["OMP_NUM_THREADS"] = "1"
|
73 |
+
os.environ["MKL_NUM_THREADS"] = "1"
|
74 |
+
|
75 |
+
if args.world_size > 1:
|
76 |
+
if device != "cuda" and args.rank == 0:
|
77 |
+
print("Error: distributed training is only available with cuda device", file=sys.stderr)
|
78 |
+
sys.exit(1)
|
79 |
+
th.cuda.set_device(args.rank % th.cuda.device_count())
|
80 |
+
distributed.init_process_group(backend="nccl",
|
81 |
+
init_method="tcp://" + args.master,
|
82 |
+
rank=args.rank,
|
83 |
+
world_size=args.world_size)
|
84 |
+
|
85 |
+
checkpoint = args.checkpoints / f"{name}.th"
|
86 |
+
checkpoint_tmp = args.checkpoints / f"{name}.th.tmp"
|
87 |
+
if args.restart and checkpoint.exists() and args.rank == 0:
|
88 |
+
checkpoint.unlink()
|
89 |
+
|
90 |
+
if args.test or args.test_pretrained:
|
91 |
+
args.epochs = 1
|
92 |
+
args.repeat = 0
|
93 |
+
if args.test:
|
94 |
+
model = load_model(args.models / args.test)
|
95 |
+
else:
|
96 |
+
model = load_pretrained(args.test_pretrained)
|
97 |
+
elif args.tasnet:
|
98 |
+
model = ConvTasNet(audio_channels=args.audio_channels,
|
99 |
+
samplerate=args.samplerate, X=args.X,
|
100 |
+
segment_length=4 * args.samples,
|
101 |
+
sources=SOURCES)
|
102 |
+
else:
|
103 |
+
model = Demucs(
|
104 |
+
audio_channels=args.audio_channels,
|
105 |
+
channels=args.channels,
|
106 |
+
context=args.context,
|
107 |
+
depth=args.depth,
|
108 |
+
glu=args.glu,
|
109 |
+
growth=args.growth,
|
110 |
+
kernel_size=args.kernel_size,
|
111 |
+
lstm_layers=args.lstm_layers,
|
112 |
+
rescale=args.rescale,
|
113 |
+
rewrite=args.rewrite,
|
114 |
+
stride=args.conv_stride,
|
115 |
+
resample=args.resample,
|
116 |
+
normalize=args.normalize,
|
117 |
+
samplerate=args.samplerate,
|
118 |
+
segment_length=4 * args.samples,
|
119 |
+
sources=SOURCES,
|
120 |
+
)
|
121 |
+
model.to(device)
|
122 |
+
if args.init:
|
123 |
+
model.load_state_dict(load_pretrained(args.init).state_dict())
|
124 |
+
|
125 |
+
if args.show:
|
126 |
+
print(model)
|
127 |
+
size = sizeof_fmt(4 * sum(p.numel() for p in model.parameters()))
|
128 |
+
print(f"Model size {size}")
|
129 |
+
return
|
130 |
+
|
131 |
+
try:
|
132 |
+
saved = th.load(checkpoint, map_location='cpu')
|
133 |
+
except IOError:
|
134 |
+
saved = SavedState()
|
135 |
+
|
136 |
+
optimizer = th.optim.Adam(model.parameters(), lr=args.lr)
|
137 |
+
|
138 |
+
quantizer = None
|
139 |
+
quantizer = get_quantizer(model, args, optimizer)
|
140 |
+
|
141 |
+
if saved.last_state is not None:
|
142 |
+
model.load_state_dict(saved.last_state, strict=False)
|
143 |
+
if saved.optimizer is not None:
|
144 |
+
optimizer.load_state_dict(saved.optimizer)
|
145 |
+
|
146 |
+
model_name = f"{name}.th"
|
147 |
+
if args.save_model:
|
148 |
+
if args.rank == 0:
|
149 |
+
model.to("cpu")
|
150 |
+
model.load_state_dict(saved.best_state)
|
151 |
+
save_model(model, quantizer, args, args.models / model_name)
|
152 |
+
return
|
153 |
+
elif args.save_state:
|
154 |
+
model_name = f"{args.save_state}.th"
|
155 |
+
if args.rank == 0:
|
156 |
+
model.to("cpu")
|
157 |
+
model.load_state_dict(saved.best_state)
|
158 |
+
state = get_state(model, quantizer)
|
159 |
+
save_state(state, args.models / model_name)
|
160 |
+
return
|
161 |
+
|
162 |
+
if args.rank == 0:
|
163 |
+
done = args.logs / f"{name}.done"
|
164 |
+
if done.exists():
|
165 |
+
done.unlink()
|
166 |
+
|
167 |
+
augment = [Shift(args.data_stride)]
|
168 |
+
if args.augment:
|
169 |
+
augment += [FlipSign(), FlipChannels(), Scale(),
|
170 |
+
Remix(group_size=args.remix_group_size)]
|
171 |
+
augment = nn.Sequential(*augment).to(device)
|
172 |
+
print("Agumentation pipeline:", augment)
|
173 |
+
|
174 |
+
if args.mse:
|
175 |
+
criterion = nn.MSELoss()
|
176 |
+
else:
|
177 |
+
criterion = nn.L1Loss()
|
178 |
+
|
179 |
+
# Setting number of samples so that all convolution windows are full.
|
180 |
+
# Prevents hard to debug mistake with the prediction being shifted compared
|
181 |
+
# to the input mixture.
|
182 |
+
samples = model.valid_length(args.samples)
|
183 |
+
print(f"Number of training samples adjusted to {samples}")
|
184 |
+
samples = samples + args.data_stride
|
185 |
+
if args.repitch:
|
186 |
+
# We need a bit more audio samples, to account for potential
|
187 |
+
# tempo change.
|
188 |
+
samples = math.ceil(samples / (1 - 0.01 * args.max_tempo))
|
189 |
+
|
190 |
+
args.metadata.mkdir(exist_ok=True, parents=True)
|
191 |
+
if args.raw:
|
192 |
+
train_set = Rawset(args.raw / "train",
|
193 |
+
samples=samples,
|
194 |
+
channels=args.audio_channels,
|
195 |
+
streams=range(1, len(model.sources) + 1),
|
196 |
+
stride=args.data_stride)
|
197 |
+
|
198 |
+
valid_set = Rawset(args.raw / "valid", channels=args.audio_channels)
|
199 |
+
elif args.wav:
|
200 |
+
train_set, valid_set = get_wav_datasets(args, samples, model.sources)
|
201 |
+
elif args.is_wav:
|
202 |
+
train_set, valid_set = get_musdb_wav_datasets(args, samples, model.sources)
|
203 |
+
else:
|
204 |
+
train_set, valid_set = get_compressed_datasets(args, samples)
|
205 |
+
|
206 |
+
if args.repitch:
|
207 |
+
train_set = RepitchedWrapper(
|
208 |
+
train_set,
|
209 |
+
proba=args.repitch,
|
210 |
+
max_tempo=args.max_tempo)
|
211 |
+
|
212 |
+
best_loss = float("inf")
|
213 |
+
for epoch, metrics in enumerate(saved.metrics):
|
214 |
+
print(f"Epoch {epoch:03d}: "
|
215 |
+
f"train={metrics['train']:.8f} "
|
216 |
+
f"valid={metrics['valid']:.8f} "
|
217 |
+
f"best={metrics['best']:.4f} "
|
218 |
+
f"ms={metrics.get('true_model_size', 0):.2f}MB "
|
219 |
+
f"cms={metrics.get('compressed_model_size', 0):.2f}MB "
|
220 |
+
f"duration={human_seconds(metrics['duration'])}")
|
221 |
+
best_loss = metrics['best']
|
222 |
+
|
223 |
+
if args.world_size > 1:
|
224 |
+
dmodel = DistributedDataParallel(model,
|
225 |
+
device_ids=[th.cuda.current_device()],
|
226 |
+
output_device=th.cuda.current_device())
|
227 |
+
else:
|
228 |
+
dmodel = model
|
229 |
+
|
230 |
+
for epoch in range(len(saved.metrics), args.epochs):
|
231 |
+
begin = time.time()
|
232 |
+
model.train()
|
233 |
+
train_loss, model_size = train_model(
|
234 |
+
epoch, train_set, dmodel, criterion, optimizer, augment,
|
235 |
+
quantizer=quantizer,
|
236 |
+
batch_size=args.batch_size,
|
237 |
+
device=device,
|
238 |
+
repeat=args.repeat,
|
239 |
+
seed=args.seed,
|
240 |
+
diffq=args.diffq,
|
241 |
+
workers=args.workers,
|
242 |
+
world_size=args.world_size)
|
243 |
+
model.eval()
|
244 |
+
valid_loss = validate_model(
|
245 |
+
epoch, valid_set, model, criterion,
|
246 |
+
device=device,
|
247 |
+
rank=args.rank,
|
248 |
+
split=args.split_valid,
|
249 |
+
overlap=args.overlap,
|
250 |
+
world_size=args.world_size)
|
251 |
+
|
252 |
+
ms = 0
|
253 |
+
cms = 0
|
254 |
+
if quantizer and args.rank == 0:
|
255 |
+
ms = quantizer.true_model_size()
|
256 |
+
cms = quantizer.compressed_model_size(num_workers=min(40, args.world_size * 10))
|
257 |
+
|
258 |
+
duration = time.time() - begin
|
259 |
+
if valid_loss < best_loss and ms <= args.ms_target:
|
260 |
+
best_loss = valid_loss
|
261 |
+
saved.best_state = {
|
262 |
+
key: value.to("cpu").clone()
|
263 |
+
for key, value in model.state_dict().items()
|
264 |
+
}
|
265 |
+
|
266 |
+
saved.metrics.append({
|
267 |
+
"train": train_loss,
|
268 |
+
"valid": valid_loss,
|
269 |
+
"best": best_loss,
|
270 |
+
"duration": duration,
|
271 |
+
"model_size": model_size,
|
272 |
+
"true_model_size": ms,
|
273 |
+
"compressed_model_size": cms,
|
274 |
+
})
|
275 |
+
if args.rank == 0:
|
276 |
+
json.dump(saved.metrics, open(metrics_path, "w"))
|
277 |
+
|
278 |
+
saved.last_state = model.state_dict()
|
279 |
+
saved.optimizer = optimizer.state_dict()
|
280 |
+
if args.rank == 0 and not args.test:
|
281 |
+
th.save(saved, checkpoint_tmp)
|
282 |
+
checkpoint_tmp.rename(checkpoint)
|
283 |
+
|
284 |
+
print(f"Epoch {epoch:03d}: "
|
285 |
+
f"train={train_loss:.8f} valid={valid_loss:.8f} best={best_loss:.4f} ms={ms:.2f}MB "
|
286 |
+
f"cms={cms:.2f}MB "
|
287 |
+
f"duration={human_seconds(duration)}")
|
288 |
+
|
289 |
+
if args.world_size > 1:
|
290 |
+
distributed.barrier()
|
291 |
+
|
292 |
+
del dmodel
|
293 |
+
model.load_state_dict(saved.best_state)
|
294 |
+
if args.eval_cpu:
|
295 |
+
device = "cpu"
|
296 |
+
model.to(device)
|
297 |
+
model.eval()
|
298 |
+
evaluate(model, args.musdb, eval_folder,
|
299 |
+
is_wav=args.is_wav,
|
300 |
+
rank=args.rank,
|
301 |
+
world_size=args.world_size,
|
302 |
+
device=device,
|
303 |
+
save=args.save,
|
304 |
+
split=args.split_valid,
|
305 |
+
shifts=args.shifts,
|
306 |
+
overlap=args.overlap,
|
307 |
+
workers=args.eval_workers)
|
308 |
+
model.to("cpu")
|
309 |
+
if args.rank == 0:
|
310 |
+
if not (args.test or args.test_pretrained):
|
311 |
+
save_model(model, quantizer, args, args.models / model_name)
|
312 |
+
print("done")
|
313 |
+
done.write_text("done")
|
314 |
+
|
315 |
+
|
316 |
+
if __name__ == "__main__":
|
317 |
+
main()
|
demucs/audio.py
ADDED
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
import json
|
7 |
+
import subprocess as sp
|
8 |
+
from pathlib import Path
|
9 |
+
|
10 |
+
import julius
|
11 |
+
import numpy as np
|
12 |
+
import torch
|
13 |
+
|
14 |
+
from .utils import temp_filenames
|
15 |
+
|
16 |
+
|
17 |
+
def _read_info(path):
|
18 |
+
stdout_data = sp.check_output([
|
19 |
+
'ffprobe', "-loglevel", "panic",
|
20 |
+
str(path), '-print_format', 'json', '-show_format', '-show_streams'
|
21 |
+
])
|
22 |
+
return json.loads(stdout_data.decode('utf-8'))
|
23 |
+
|
24 |
+
|
25 |
+
class AudioFile:
|
26 |
+
"""
|
27 |
+
Allows to read audio from any format supported by ffmpeg, as well as resampling or
|
28 |
+
converting to mono on the fly. See :method:`read` for more details.
|
29 |
+
"""
|
30 |
+
def __init__(self, path: Path):
|
31 |
+
self.path = Path(path)
|
32 |
+
self._info = None
|
33 |
+
|
34 |
+
def __repr__(self):
|
35 |
+
features = [("path", self.path)]
|
36 |
+
features.append(("samplerate", self.samplerate()))
|
37 |
+
features.append(("channels", self.channels()))
|
38 |
+
features.append(("streams", len(self)))
|
39 |
+
features_str = ", ".join(f"{name}={value}" for name, value in features)
|
40 |
+
return f"AudioFile({features_str})"
|
41 |
+
|
42 |
+
@property
|
43 |
+
def info(self):
|
44 |
+
if self._info is None:
|
45 |
+
self._info = _read_info(self.path)
|
46 |
+
return self._info
|
47 |
+
|
48 |
+
@property
|
49 |
+
def duration(self):
|
50 |
+
return float(self.info['format']['duration'])
|
51 |
+
|
52 |
+
@property
|
53 |
+
def _audio_streams(self):
|
54 |
+
return [
|
55 |
+
index for index, stream in enumerate(self.info["streams"])
|
56 |
+
if stream["codec_type"] == "audio"
|
57 |
+
]
|
58 |
+
|
59 |
+
def __len__(self):
|
60 |
+
return len(self._audio_streams)
|
61 |
+
|
62 |
+
def channels(self, stream=0):
|
63 |
+
return int(self.info['streams'][self._audio_streams[stream]]['channels'])
|
64 |
+
|
65 |
+
def samplerate(self, stream=0):
|
66 |
+
return int(self.info['streams'][self._audio_streams[stream]]['sample_rate'])
|
67 |
+
|
68 |
+
def read(self,
|
69 |
+
seek_time=None,
|
70 |
+
duration=None,
|
71 |
+
streams=slice(None),
|
72 |
+
samplerate=None,
|
73 |
+
channels=None,
|
74 |
+
temp_folder=None):
|
75 |
+
"""
|
76 |
+
Slightly more efficient implementation than stempeg,
|
77 |
+
in particular, this will extract all stems at once
|
78 |
+
rather than having to loop over one file multiple times
|
79 |
+
for each stream.
|
80 |
+
|
81 |
+
Args:
|
82 |
+
seek_time (float): seek time in seconds or None if no seeking is needed.
|
83 |
+
duration (float): duration in seconds to extract or None to extract until the end.
|
84 |
+
streams (slice, int or list): streams to extract, can be a single int, a list or
|
85 |
+
a slice. If it is a slice or list, the output will be of size [S, C, T]
|
86 |
+
with S the number of streams, C the number of channels and T the number of samples.
|
87 |
+
If it is an int, the output will be [C, T].
|
88 |
+
samplerate (int): if provided, will resample on the fly. If None, no resampling will
|
89 |
+
be done. Original sampling rate can be obtained with :method:`samplerate`.
|
90 |
+
channels (int): if 1, will convert to mono. We do not rely on ffmpeg for that
|
91 |
+
as ffmpeg automatically scale by +3dB to conserve volume when playing on speakers.
|
92 |
+
See https://sound.stackexchange.com/a/42710.
|
93 |
+
Our definition of mono is simply the average of the two channels. Any other
|
94 |
+
value will be ignored.
|
95 |
+
temp_folder (str or Path or None): temporary folder to use for decoding.
|
96 |
+
|
97 |
+
|
98 |
+
"""
|
99 |
+
streams = np.array(range(len(self)))[streams]
|
100 |
+
single = not isinstance(streams, np.ndarray)
|
101 |
+
if single:
|
102 |
+
streams = [streams]
|
103 |
+
|
104 |
+
if duration is None:
|
105 |
+
target_size = None
|
106 |
+
query_duration = None
|
107 |
+
else:
|
108 |
+
target_size = int((samplerate or self.samplerate()) * duration)
|
109 |
+
query_duration = float((target_size + 1) / (samplerate or self.samplerate()))
|
110 |
+
|
111 |
+
with temp_filenames(len(streams)) as filenames:
|
112 |
+
command = ['ffmpeg', '-y']
|
113 |
+
command += ['-loglevel', 'panic']
|
114 |
+
if seek_time:
|
115 |
+
command += ['-ss', str(seek_time)]
|
116 |
+
command += ['-i', str(self.path)]
|
117 |
+
for stream, filename in zip(streams, filenames):
|
118 |
+
command += ['-map', f'0:{self._audio_streams[stream]}']
|
119 |
+
if query_duration is not None:
|
120 |
+
command += ['-t', str(query_duration)]
|
121 |
+
command += ['-threads', '1']
|
122 |
+
command += ['-f', 'f32le']
|
123 |
+
if samplerate is not None:
|
124 |
+
command += ['-ar', str(samplerate)]
|
125 |
+
command += [filename]
|
126 |
+
|
127 |
+
sp.run(command, check=True)
|
128 |
+
wavs = []
|
129 |
+
for filename in filenames:
|
130 |
+
wav = np.fromfile(filename, dtype=np.float32)
|
131 |
+
wav = torch.from_numpy(wav)
|
132 |
+
wav = wav.view(-1, self.channels()).t()
|
133 |
+
if channels is not None:
|
134 |
+
wav = convert_audio_channels(wav, channels)
|
135 |
+
if target_size is not None:
|
136 |
+
wav = wav[..., :target_size]
|
137 |
+
wavs.append(wav)
|
138 |
+
wav = torch.stack(wavs, dim=0)
|
139 |
+
if single:
|
140 |
+
wav = wav[0]
|
141 |
+
return wav
|
142 |
+
|
143 |
+
|
144 |
+
def convert_audio_channels(wav, channels=2):
|
145 |
+
"""Convert audio to the given number of channels."""
|
146 |
+
*shape, src_channels, length = wav.shape
|
147 |
+
if src_channels == channels:
|
148 |
+
pass
|
149 |
+
elif channels == 1:
|
150 |
+
# Case 1:
|
151 |
+
# The caller asked 1-channel audio, but the stream have multiple
|
152 |
+
# channels, downmix all channels.
|
153 |
+
wav = wav.mean(dim=-2, keepdim=True)
|
154 |
+
elif src_channels == 1:
|
155 |
+
# Case 2:
|
156 |
+
# The caller asked for multiple channels, but the input file have
|
157 |
+
# one single channel, replicate the audio over all channels.
|
158 |
+
wav = wav.expand(*shape, channels, length)
|
159 |
+
elif src_channels >= channels:
|
160 |
+
# Case 3:
|
161 |
+
# The caller asked for multiple channels, and the input file have
|
162 |
+
# more channels than requested. In that case return the first channels.
|
163 |
+
wav = wav[..., :channels, :]
|
164 |
+
else:
|
165 |
+
# Case 4: What is a reasonable choice here?
|
166 |
+
raise ValueError('The audio file has less channels than requested but is not mono.')
|
167 |
+
return wav
|
168 |
+
|
169 |
+
|
170 |
+
def convert_audio(wav, from_samplerate, to_samplerate, channels):
|
171 |
+
wav = convert_audio_channels(wav, channels)
|
172 |
+
return julius.resample_frac(wav, from_samplerate, to_samplerate)
|
demucs/augment.py
ADDED
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import random
|
8 |
+
import torch as th
|
9 |
+
from torch import nn
|
10 |
+
|
11 |
+
|
12 |
+
class Shift(nn.Module):
|
13 |
+
"""
|
14 |
+
Randomly shift audio in time by up to `shift` samples.
|
15 |
+
"""
|
16 |
+
def __init__(self, shift=8192):
|
17 |
+
super().__init__()
|
18 |
+
self.shift = shift
|
19 |
+
|
20 |
+
def forward(self, wav):
|
21 |
+
batch, sources, channels, time = wav.size()
|
22 |
+
length = time - self.shift
|
23 |
+
if self.shift > 0:
|
24 |
+
if not self.training:
|
25 |
+
wav = wav[..., :length]
|
26 |
+
else:
|
27 |
+
offsets = th.randint(self.shift, [batch, sources, 1, 1], device=wav.device)
|
28 |
+
offsets = offsets.expand(-1, -1, channels, -1)
|
29 |
+
indexes = th.arange(length, device=wav.device)
|
30 |
+
wav = wav.gather(3, indexes + offsets)
|
31 |
+
return wav
|
32 |
+
|
33 |
+
|
34 |
+
class FlipChannels(nn.Module):
|
35 |
+
"""
|
36 |
+
Flip left-right channels.
|
37 |
+
"""
|
38 |
+
def forward(self, wav):
|
39 |
+
batch, sources, channels, time = wav.size()
|
40 |
+
if self.training and wav.size(2) == 2:
|
41 |
+
left = th.randint(2, (batch, sources, 1, 1), device=wav.device)
|
42 |
+
left = left.expand(-1, -1, -1, time)
|
43 |
+
right = 1 - left
|
44 |
+
wav = th.cat([wav.gather(2, left), wav.gather(2, right)], dim=2)
|
45 |
+
return wav
|
46 |
+
|
47 |
+
|
48 |
+
class FlipSign(nn.Module):
|
49 |
+
"""
|
50 |
+
Random sign flip.
|
51 |
+
"""
|
52 |
+
def forward(self, wav):
|
53 |
+
batch, sources, channels, time = wav.size()
|
54 |
+
if self.training:
|
55 |
+
signs = th.randint(2, (batch, sources, 1, 1), device=wav.device, dtype=th.float32)
|
56 |
+
wav = wav * (2 * signs - 1)
|
57 |
+
return wav
|
58 |
+
|
59 |
+
|
60 |
+
class Remix(nn.Module):
|
61 |
+
"""
|
62 |
+
Shuffle sources to make new mixes.
|
63 |
+
"""
|
64 |
+
def __init__(self, group_size=4):
|
65 |
+
"""
|
66 |
+
Shuffle sources within one batch.
|
67 |
+
Each batch is divided into groups of size `group_size` and shuffling is done within
|
68 |
+
each group separatly. This allow to keep the same probability distribution no matter
|
69 |
+
the number of GPUs. Without this grouping, using more GPUs would lead to a higher
|
70 |
+
probability of keeping two sources from the same track together which can impact
|
71 |
+
performance.
|
72 |
+
"""
|
73 |
+
super().__init__()
|
74 |
+
self.group_size = group_size
|
75 |
+
|
76 |
+
def forward(self, wav):
|
77 |
+
batch, streams, channels, time = wav.size()
|
78 |
+
device = wav.device
|
79 |
+
|
80 |
+
if self.training:
|
81 |
+
group_size = self.group_size or batch
|
82 |
+
if batch % group_size != 0:
|
83 |
+
raise ValueError(f"Batch size {batch} must be divisible by group size {group_size}")
|
84 |
+
groups = batch // group_size
|
85 |
+
wav = wav.view(groups, group_size, streams, channels, time)
|
86 |
+
permutations = th.argsort(th.rand(groups, group_size, streams, 1, 1, device=device),
|
87 |
+
dim=1)
|
88 |
+
wav = wav.gather(1, permutations.expand(-1, -1, -1, channels, time))
|
89 |
+
wav = wav.view(batch, streams, channels, time)
|
90 |
+
return wav
|
91 |
+
|
92 |
+
|
93 |
+
class Scale(nn.Module):
|
94 |
+
def __init__(self, proba=1., min=0.25, max=1.25):
|
95 |
+
super().__init__()
|
96 |
+
self.proba = proba
|
97 |
+
self.min = min
|
98 |
+
self.max = max
|
99 |
+
|
100 |
+
def forward(self, wav):
|
101 |
+
batch, streams, channels, time = wav.size()
|
102 |
+
device = wav.device
|
103 |
+
if self.training and random.random() < self.proba:
|
104 |
+
scales = th.empty(batch, streams, 1, 1, device=device).uniform_(self.min, self.max)
|
105 |
+
wav *= scales
|
106 |
+
return wav
|
demucs/compressed.py
ADDED
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import json
|
8 |
+
from fractions import Fraction
|
9 |
+
from concurrent import futures
|
10 |
+
|
11 |
+
import musdb
|
12 |
+
from torch import distributed
|
13 |
+
|
14 |
+
from .audio import AudioFile
|
15 |
+
|
16 |
+
|
17 |
+
def get_musdb_tracks(root, *args, **kwargs):
|
18 |
+
mus = musdb.DB(root, *args, **kwargs)
|
19 |
+
return {track.name: track.path for track in mus}
|
20 |
+
|
21 |
+
|
22 |
+
class StemsSet:
|
23 |
+
def __init__(self, tracks, metadata, duration=None, stride=1,
|
24 |
+
samplerate=44100, channels=2, streams=slice(None)):
|
25 |
+
|
26 |
+
self.metadata = []
|
27 |
+
for name, path in tracks.items():
|
28 |
+
meta = dict(metadata[name])
|
29 |
+
meta["path"] = path
|
30 |
+
meta["name"] = name
|
31 |
+
self.metadata.append(meta)
|
32 |
+
if duration is not None and meta["duration"] < duration:
|
33 |
+
raise ValueError(f"Track {name} duration is too small {meta['duration']}")
|
34 |
+
self.metadata.sort(key=lambda x: x["name"])
|
35 |
+
self.duration = duration
|
36 |
+
self.stride = stride
|
37 |
+
self.channels = channels
|
38 |
+
self.samplerate = samplerate
|
39 |
+
self.streams = streams
|
40 |
+
|
41 |
+
def __len__(self):
|
42 |
+
return sum(self._examples_count(m) for m in self.metadata)
|
43 |
+
|
44 |
+
def _examples_count(self, meta):
|
45 |
+
if self.duration is None:
|
46 |
+
return 1
|
47 |
+
else:
|
48 |
+
return int((meta["duration"] - self.duration) // self.stride + 1)
|
49 |
+
|
50 |
+
def track_metadata(self, index):
|
51 |
+
for meta in self.metadata:
|
52 |
+
examples = self._examples_count(meta)
|
53 |
+
if index >= examples:
|
54 |
+
index -= examples
|
55 |
+
continue
|
56 |
+
return meta
|
57 |
+
|
58 |
+
def __getitem__(self, index):
|
59 |
+
for meta in self.metadata:
|
60 |
+
examples = self._examples_count(meta)
|
61 |
+
if index >= examples:
|
62 |
+
index -= examples
|
63 |
+
continue
|
64 |
+
streams = AudioFile(meta["path"]).read(seek_time=index * self.stride,
|
65 |
+
duration=self.duration,
|
66 |
+
channels=self.channels,
|
67 |
+
samplerate=self.samplerate,
|
68 |
+
streams=self.streams)
|
69 |
+
return (streams - meta["mean"]) / meta["std"]
|
70 |
+
|
71 |
+
|
72 |
+
def _get_track_metadata(path):
|
73 |
+
# use mono at 44kHz as reference. For any other settings data won't be perfectly
|
74 |
+
# normalized but it should be good enough.
|
75 |
+
audio = AudioFile(path)
|
76 |
+
mix = audio.read(streams=0, channels=1, samplerate=44100)
|
77 |
+
return {"duration": audio.duration, "std": mix.std().item(), "mean": mix.mean().item()}
|
78 |
+
|
79 |
+
|
80 |
+
def _build_metadata(tracks, workers=10):
|
81 |
+
pendings = []
|
82 |
+
with futures.ProcessPoolExecutor(workers) as pool:
|
83 |
+
for name, path in tracks.items():
|
84 |
+
pendings.append((name, pool.submit(_get_track_metadata, path)))
|
85 |
+
return {name: p.result() for name, p in pendings}
|
86 |
+
|
87 |
+
|
88 |
+
def _build_musdb_metadata(path, musdb, workers):
|
89 |
+
tracks = get_musdb_tracks(musdb)
|
90 |
+
metadata = _build_metadata(tracks, workers)
|
91 |
+
path.parent.mkdir(exist_ok=True, parents=True)
|
92 |
+
json.dump(metadata, open(path, "w"))
|
93 |
+
|
94 |
+
|
95 |
+
def get_compressed_datasets(args, samples):
|
96 |
+
metadata_file = args.metadata / "musdb.json"
|
97 |
+
if not metadata_file.is_file() and args.rank == 0:
|
98 |
+
_build_musdb_metadata(metadata_file, args.musdb, args.workers)
|
99 |
+
if args.world_size > 1:
|
100 |
+
distributed.barrier()
|
101 |
+
metadata = json.load(open(metadata_file))
|
102 |
+
duration = Fraction(samples, args.samplerate)
|
103 |
+
stride = Fraction(args.data_stride, args.samplerate)
|
104 |
+
train_set = StemsSet(get_musdb_tracks(args.musdb, subsets=["train"], split="train"),
|
105 |
+
metadata,
|
106 |
+
duration=duration,
|
107 |
+
stride=stride,
|
108 |
+
streams=slice(1, None),
|
109 |
+
samplerate=args.samplerate,
|
110 |
+
channels=args.audio_channels)
|
111 |
+
valid_set = StemsSet(get_musdb_tracks(args.musdb, subsets=["train"], split="valid"),
|
112 |
+
metadata,
|
113 |
+
samplerate=args.samplerate,
|
114 |
+
channels=args.audio_channels)
|
115 |
+
return train_set, valid_set
|
demucs/model.py
ADDED
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import math
|
8 |
+
|
9 |
+
import julius
|
10 |
+
from torch import nn
|
11 |
+
|
12 |
+
from .utils import capture_init, center_trim
|
13 |
+
|
14 |
+
|
15 |
+
class BLSTM(nn.Module):
|
16 |
+
def __init__(self, dim, layers=1):
|
17 |
+
super().__init__()
|
18 |
+
self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim)
|
19 |
+
self.linear = nn.Linear(2 * dim, dim)
|
20 |
+
|
21 |
+
def forward(self, x):
|
22 |
+
x = x.permute(2, 0, 1)
|
23 |
+
x = self.lstm(x)[0]
|
24 |
+
x = self.linear(x)
|
25 |
+
x = x.permute(1, 2, 0)
|
26 |
+
return x
|
27 |
+
|
28 |
+
|
29 |
+
def rescale_conv(conv, reference):
|
30 |
+
std = conv.weight.std().detach()
|
31 |
+
scale = (std / reference)**0.5
|
32 |
+
conv.weight.data /= scale
|
33 |
+
if conv.bias is not None:
|
34 |
+
conv.bias.data /= scale
|
35 |
+
|
36 |
+
|
37 |
+
def rescale_module(module, reference):
|
38 |
+
for sub in module.modules():
|
39 |
+
if isinstance(sub, (nn.Conv1d, nn.ConvTranspose1d)):
|
40 |
+
rescale_conv(sub, reference)
|
41 |
+
|
42 |
+
|
43 |
+
class Demucs(nn.Module):
|
44 |
+
@capture_init
|
45 |
+
def __init__(self,
|
46 |
+
sources,
|
47 |
+
audio_channels=2,
|
48 |
+
channels=64,
|
49 |
+
depth=6,
|
50 |
+
rewrite=True,
|
51 |
+
glu=True,
|
52 |
+
rescale=0.1,
|
53 |
+
resample=True,
|
54 |
+
kernel_size=8,
|
55 |
+
stride=4,
|
56 |
+
growth=2.,
|
57 |
+
lstm_layers=2,
|
58 |
+
context=3,
|
59 |
+
normalize=False,
|
60 |
+
samplerate=44100,
|
61 |
+
segment_length=4 * 10 * 44100):
|
62 |
+
"""
|
63 |
+
Args:
|
64 |
+
sources (list[str]): list of source names
|
65 |
+
audio_channels (int): stereo or mono
|
66 |
+
channels (int): first convolution channels
|
67 |
+
depth (int): number of encoder/decoder layers
|
68 |
+
rewrite (bool): add 1x1 convolution to each encoder layer
|
69 |
+
and a convolution to each decoder layer.
|
70 |
+
For the decoder layer, `context` gives the kernel size.
|
71 |
+
glu (bool): use glu instead of ReLU
|
72 |
+
resample_input (bool): upsample x2 the input and downsample /2 the output.
|
73 |
+
rescale (int): rescale initial weights of convolutions
|
74 |
+
to get their standard deviation closer to `rescale`
|
75 |
+
kernel_size (int): kernel size for convolutions
|
76 |
+
stride (int): stride for convolutions
|
77 |
+
growth (float): multiply (resp divide) number of channels by that
|
78 |
+
for each layer of the encoder (resp decoder)
|
79 |
+
lstm_layers (int): number of lstm layers, 0 = no lstm
|
80 |
+
context (int): kernel size of the convolution in the
|
81 |
+
decoder before the transposed convolution. If > 1,
|
82 |
+
will provide some context from neighboring time
|
83 |
+
steps.
|
84 |
+
samplerate (int): stored as meta information for easing
|
85 |
+
future evaluations of the model.
|
86 |
+
segment_length (int): stored as meta information for easing
|
87 |
+
future evaluations of the model. Length of the segments on which
|
88 |
+
the model was trained.
|
89 |
+
"""
|
90 |
+
|
91 |
+
super().__init__()
|
92 |
+
self.audio_channels = audio_channels
|
93 |
+
self.sources = sources
|
94 |
+
self.kernel_size = kernel_size
|
95 |
+
self.context = context
|
96 |
+
self.stride = stride
|
97 |
+
self.depth = depth
|
98 |
+
self.resample = resample
|
99 |
+
self.channels = channels
|
100 |
+
self.normalize = normalize
|
101 |
+
self.samplerate = samplerate
|
102 |
+
self.segment_length = segment_length
|
103 |
+
|
104 |
+
self.encoder = nn.ModuleList()
|
105 |
+
self.decoder = nn.ModuleList()
|
106 |
+
|
107 |
+
if glu:
|
108 |
+
activation = nn.GLU(dim=1)
|
109 |
+
ch_scale = 2
|
110 |
+
else:
|
111 |
+
activation = nn.ReLU()
|
112 |
+
ch_scale = 1
|
113 |
+
in_channels = audio_channels
|
114 |
+
for index in range(depth):
|
115 |
+
encode = []
|
116 |
+
encode += [nn.Conv1d(in_channels, channels, kernel_size, stride), nn.ReLU()]
|
117 |
+
if rewrite:
|
118 |
+
encode += [nn.Conv1d(channels, ch_scale * channels, 1), activation]
|
119 |
+
self.encoder.append(nn.Sequential(*encode))
|
120 |
+
|
121 |
+
decode = []
|
122 |
+
if index > 0:
|
123 |
+
out_channels = in_channels
|
124 |
+
else:
|
125 |
+
out_channels = len(self.sources) * audio_channels
|
126 |
+
if rewrite:
|
127 |
+
decode += [nn.Conv1d(channels, ch_scale * channels, context), activation]
|
128 |
+
decode += [nn.ConvTranspose1d(channels, out_channels, kernel_size, stride)]
|
129 |
+
if index > 0:
|
130 |
+
decode.append(nn.ReLU())
|
131 |
+
self.decoder.insert(0, nn.Sequential(*decode))
|
132 |
+
in_channels = channels
|
133 |
+
channels = int(growth * channels)
|
134 |
+
|
135 |
+
channels = in_channels
|
136 |
+
|
137 |
+
if lstm_layers:
|
138 |
+
self.lstm = BLSTM(channels, lstm_layers)
|
139 |
+
else:
|
140 |
+
self.lstm = None
|
141 |
+
|
142 |
+
if rescale:
|
143 |
+
rescale_module(self, reference=rescale)
|
144 |
+
|
145 |
+
def valid_length(self, length):
|
146 |
+
"""
|
147 |
+
Return the nearest valid length to use with the model so that
|
148 |
+
there is no time steps left over in a convolutions, e.g. for all
|
149 |
+
layers, size of the input - kernel_size % stride = 0.
|
150 |
+
|
151 |
+
If the mixture has a valid length, the estimated sources
|
152 |
+
will have exactly the same length when context = 1. If context > 1,
|
153 |
+
the two signals can be center trimmed to match.
|
154 |
+
|
155 |
+
For training, extracts should have a valid length.For evaluation
|
156 |
+
on full tracks we recommend passing `pad = True` to :method:`forward`.
|
157 |
+
"""
|
158 |
+
if self.resample:
|
159 |
+
length *= 2
|
160 |
+
for _ in range(self.depth):
|
161 |
+
length = math.ceil((length - self.kernel_size) / self.stride) + 1
|
162 |
+
length = max(1, length)
|
163 |
+
length += self.context - 1
|
164 |
+
for _ in range(self.depth):
|
165 |
+
length = (length - 1) * self.stride + self.kernel_size
|
166 |
+
|
167 |
+
if self.resample:
|
168 |
+
length = math.ceil(length / 2)
|
169 |
+
return int(length)
|
170 |
+
|
171 |
+
def forward(self, mix):
|
172 |
+
x = mix
|
173 |
+
|
174 |
+
if self.normalize:
|
175 |
+
mono = mix.mean(dim=1, keepdim=True)
|
176 |
+
mean = mono.mean(dim=-1, keepdim=True)
|
177 |
+
std = mono.std(dim=-1, keepdim=True)
|
178 |
+
else:
|
179 |
+
mean = 0
|
180 |
+
std = 1
|
181 |
+
|
182 |
+
x = (x - mean) / (1e-5 + std)
|
183 |
+
|
184 |
+
if self.resample:
|
185 |
+
x = julius.resample_frac(x, 1, 2)
|
186 |
+
|
187 |
+
saved = []
|
188 |
+
for encode in self.encoder:
|
189 |
+
x = encode(x)
|
190 |
+
saved.append(x)
|
191 |
+
if self.lstm:
|
192 |
+
x = self.lstm(x)
|
193 |
+
for decode in self.decoder:
|
194 |
+
skip = center_trim(saved.pop(-1), x)
|
195 |
+
x = x + skip
|
196 |
+
x = decode(x)
|
197 |
+
|
198 |
+
if self.resample:
|
199 |
+
x = julius.resample_frac(x, 2, 1)
|
200 |
+
x = x * std + mean
|
201 |
+
x = x.view(x.size(0), len(self.sources), self.audio_channels, x.size(-1))
|
202 |
+
return x
|
demucs/parser.py
ADDED
@@ -0,0 +1,244 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import argparse
|
8 |
+
import os
|
9 |
+
from pathlib import Path
|
10 |
+
|
11 |
+
|
12 |
+
def get_parser():
|
13 |
+
parser = argparse.ArgumentParser("demucs", description="Train and evaluate Demucs.")
|
14 |
+
default_raw = None
|
15 |
+
default_musdb = None
|
16 |
+
if 'DEMUCS_RAW' in os.environ:
|
17 |
+
default_raw = Path(os.environ['DEMUCS_RAW'])
|
18 |
+
if 'DEMUCS_MUSDB' in os.environ:
|
19 |
+
default_musdb = Path(os.environ['DEMUCS_MUSDB'])
|
20 |
+
parser.add_argument(
|
21 |
+
"--raw",
|
22 |
+
type=Path,
|
23 |
+
default=default_raw,
|
24 |
+
help="Path to raw audio, can be faster, see python3 -m demucs.raw to extract.")
|
25 |
+
parser.add_argument("--no_raw", action="store_const", const=None, dest="raw")
|
26 |
+
parser.add_argument("-m",
|
27 |
+
"--musdb",
|
28 |
+
type=Path,
|
29 |
+
default=default_musdb,
|
30 |
+
help="Path to musdb root")
|
31 |
+
parser.add_argument("--is_wav", action="store_true",
|
32 |
+
help="Indicate that the MusDB dataset is in wav format (i.e. MusDB-HQ).")
|
33 |
+
parser.add_argument("--metadata", type=Path, default=Path("metadata/"),
|
34 |
+
help="Folder where metadata information is stored.")
|
35 |
+
parser.add_argument("--wav", type=Path,
|
36 |
+
help="Path to a wav dataset. This should contain a 'train' and a 'valid' "
|
37 |
+
"subfolder.")
|
38 |
+
parser.add_argument("--samplerate", type=int, default=44100)
|
39 |
+
parser.add_argument("--audio_channels", type=int, default=2)
|
40 |
+
parser.add_argument("--samples",
|
41 |
+
default=44100 * 10,
|
42 |
+
type=int,
|
43 |
+
help="number of samples to feed in")
|
44 |
+
parser.add_argument("--data_stride",
|
45 |
+
default=44100,
|
46 |
+
type=int,
|
47 |
+
help="Stride for chunks, shorter = longer epochs")
|
48 |
+
parser.add_argument("-w", "--workers", default=10, type=int, help="Loader workers")
|
49 |
+
parser.add_argument("--eval_workers", default=2, type=int, help="Final evaluation workers")
|
50 |
+
parser.add_argument("-d",
|
51 |
+
"--device",
|
52 |
+
help="Device to train on, default is cuda if available else cpu")
|
53 |
+
parser.add_argument("--eval_cpu", action="store_true", help="Eval on test will be run on cpu.")
|
54 |
+
parser.add_argument("--dummy", help="Dummy parameter, useful to create a new checkpoint file")
|
55 |
+
parser.add_argument("--test", help="Just run the test pipeline + one validation. "
|
56 |
+
"This should be a filename relative to the models/ folder.")
|
57 |
+
parser.add_argument("--test_pretrained", help="Just run the test pipeline + one validation, "
|
58 |
+
"on a pretrained model. ")
|
59 |
+
|
60 |
+
parser.add_argument("--rank", default=0, type=int)
|
61 |
+
parser.add_argument("--world_size", default=1, type=int)
|
62 |
+
parser.add_argument("--master")
|
63 |
+
|
64 |
+
parser.add_argument("--checkpoints",
|
65 |
+
type=Path,
|
66 |
+
default=Path("checkpoints"),
|
67 |
+
help="Folder where to store checkpoints etc")
|
68 |
+
parser.add_argument("--evals",
|
69 |
+
type=Path,
|
70 |
+
default=Path("evals"),
|
71 |
+
help="Folder where to store evals and waveforms")
|
72 |
+
parser.add_argument("--save",
|
73 |
+
action="store_true",
|
74 |
+
help="Save estimated for the test set waveforms")
|
75 |
+
parser.add_argument("--logs",
|
76 |
+
type=Path,
|
77 |
+
default=Path("logs"),
|
78 |
+
help="Folder where to store logs")
|
79 |
+
parser.add_argument("--models",
|
80 |
+
type=Path,
|
81 |
+
default=Path("models"),
|
82 |
+
help="Folder where to store trained models")
|
83 |
+
parser.add_argument("-R",
|
84 |
+
"--restart",
|
85 |
+
action='store_true',
|
86 |
+
help='Restart training, ignoring previous run')
|
87 |
+
|
88 |
+
parser.add_argument("--seed", type=int, default=42)
|
89 |
+
parser.add_argument("-e", "--epochs", type=int, default=180, help="Number of epochs")
|
90 |
+
parser.add_argument("-r",
|
91 |
+
"--repeat",
|
92 |
+
type=int,
|
93 |
+
default=2,
|
94 |
+
help="Repeat the train set, longer epochs")
|
95 |
+
parser.add_argument("-b", "--batch_size", type=int, default=64)
|
96 |
+
parser.add_argument("--lr", type=float, default=3e-4)
|
97 |
+
parser.add_argument("--mse", action="store_true", help="Use MSE instead of L1")
|
98 |
+
parser.add_argument("--init", help="Initialize from a pre-trained model.")
|
99 |
+
|
100 |
+
# Augmentation options
|
101 |
+
parser.add_argument("--no_augment",
|
102 |
+
action="store_false",
|
103 |
+
dest="augment",
|
104 |
+
default=True,
|
105 |
+
help="No basic data augmentation.")
|
106 |
+
parser.add_argument("--repitch", type=float, default=0.2,
|
107 |
+
help="Probability to do tempo/pitch change")
|
108 |
+
parser.add_argument("--max_tempo", type=float, default=12,
|
109 |
+
help="Maximum relative tempo change in %% when using repitch.")
|
110 |
+
|
111 |
+
parser.add_argument("--remix_group_size",
|
112 |
+
type=int,
|
113 |
+
default=4,
|
114 |
+
help="Shuffle sources using group of this size. Useful to somewhat "
|
115 |
+
"replicate multi-gpu training "
|
116 |
+
"on less GPUs.")
|
117 |
+
parser.add_argument("--shifts",
|
118 |
+
type=int,
|
119 |
+
default=10,
|
120 |
+
help="Number of random shifts used for the shift trick.")
|
121 |
+
parser.add_argument("--overlap",
|
122 |
+
type=float,
|
123 |
+
default=0.25,
|
124 |
+
help="Overlap when --split_valid is passed.")
|
125 |
+
|
126 |
+
# See model.py for doc
|
127 |
+
parser.add_argument("--growth",
|
128 |
+
type=float,
|
129 |
+
default=2.,
|
130 |
+
help="Number of channels between two layers will increase by this factor")
|
131 |
+
parser.add_argument("--depth",
|
132 |
+
type=int,
|
133 |
+
default=6,
|
134 |
+
help="Number of layers for the encoder and decoder")
|
135 |
+
parser.add_argument("--lstm_layers", type=int, default=2, help="Number of layers for the LSTM")
|
136 |
+
parser.add_argument("--channels",
|
137 |
+
type=int,
|
138 |
+
default=64,
|
139 |
+
help="Number of channels for the first encoder layer")
|
140 |
+
parser.add_argument("--kernel_size",
|
141 |
+
type=int,
|
142 |
+
default=8,
|
143 |
+
help="Kernel size for the (transposed) convolutions")
|
144 |
+
parser.add_argument("--conv_stride",
|
145 |
+
type=int,
|
146 |
+
default=4,
|
147 |
+
help="Stride for the (transposed) convolutions")
|
148 |
+
parser.add_argument("--context",
|
149 |
+
type=int,
|
150 |
+
default=3,
|
151 |
+
help="Context size for the decoder convolutions "
|
152 |
+
"before the transposed convolutions")
|
153 |
+
parser.add_argument("--rescale",
|
154 |
+
type=float,
|
155 |
+
default=0.1,
|
156 |
+
help="Initial weight rescale reference")
|
157 |
+
parser.add_argument("--no_resample", action="store_false",
|
158 |
+
default=True, dest="resample",
|
159 |
+
help="No Resampling of the input/output x2")
|
160 |
+
parser.add_argument("--no_glu",
|
161 |
+
action="store_false",
|
162 |
+
default=True,
|
163 |
+
dest="glu",
|
164 |
+
help="Replace all GLUs by ReLUs")
|
165 |
+
parser.add_argument("--no_rewrite",
|
166 |
+
action="store_false",
|
167 |
+
default=True,
|
168 |
+
dest="rewrite",
|
169 |
+
help="No 1x1 rewrite convolutions")
|
170 |
+
parser.add_argument("--normalize", action="store_true")
|
171 |
+
parser.add_argument("--no_norm_wav", action="store_false", dest='norm_wav', default=True)
|
172 |
+
|
173 |
+
# Tasnet options
|
174 |
+
parser.add_argument("--tasnet", action="store_true")
|
175 |
+
parser.add_argument("--split_valid",
|
176 |
+
action="store_true",
|
177 |
+
help="Predict chunks by chunks for valid and test. Required for tasnet")
|
178 |
+
parser.add_argument("--X", type=int, default=8)
|
179 |
+
|
180 |
+
# Other options
|
181 |
+
parser.add_argument("--show",
|
182 |
+
action="store_true",
|
183 |
+
help="Show model architecture, size and exit")
|
184 |
+
parser.add_argument("--save_model", action="store_true",
|
185 |
+
help="Skip traning, just save final model "
|
186 |
+
"for the current checkpoint value.")
|
187 |
+
parser.add_argument("--save_state",
|
188 |
+
help="Skip training, just save state "
|
189 |
+
"for the current checkpoint value. You should "
|
190 |
+
"provide a model name as argument.")
|
191 |
+
|
192 |
+
# Quantization options
|
193 |
+
parser.add_argument("--q-min-size", type=float, default=1,
|
194 |
+
help="Only quantize layers over this size (in MB)")
|
195 |
+
parser.add_argument(
|
196 |
+
"--qat", type=int, help="If provided, use QAT training with that many bits.")
|
197 |
+
|
198 |
+
parser.add_argument("--diffq", type=float, default=0)
|
199 |
+
parser.add_argument(
|
200 |
+
"--ms-target", type=float, default=162,
|
201 |
+
help="Model size target in MB, when using DiffQ. Best model will be kept "
|
202 |
+
"only if it is smaller than this target.")
|
203 |
+
|
204 |
+
return parser
|
205 |
+
|
206 |
+
|
207 |
+
def get_name(parser, args):
|
208 |
+
"""
|
209 |
+
Return the name of an experiment given the args. Some parameters are ignored,
|
210 |
+
for instance --workers, as they do not impact the final result.
|
211 |
+
"""
|
212 |
+
ignore_args = set([
|
213 |
+
"checkpoints",
|
214 |
+
"deterministic",
|
215 |
+
"eval",
|
216 |
+
"evals",
|
217 |
+
"eval_cpu",
|
218 |
+
"eval_workers",
|
219 |
+
"logs",
|
220 |
+
"master",
|
221 |
+
"rank",
|
222 |
+
"restart",
|
223 |
+
"save",
|
224 |
+
"save_model",
|
225 |
+
"save_state",
|
226 |
+
"show",
|
227 |
+
"workers",
|
228 |
+
"world_size",
|
229 |
+
])
|
230 |
+
parts = []
|
231 |
+
name_args = dict(args.__dict__)
|
232 |
+
for name, value in name_args.items():
|
233 |
+
if name in ignore_args:
|
234 |
+
continue
|
235 |
+
if value != parser.get_default(name):
|
236 |
+
if isinstance(value, Path):
|
237 |
+
parts.append(f"{name}={value.name}")
|
238 |
+
else:
|
239 |
+
parts.append(f"{name}={value}")
|
240 |
+
if parts:
|
241 |
+
name = " ".join(parts)
|
242 |
+
else:
|
243 |
+
name = "default"
|
244 |
+
return name
|
demucs/pretrained.py
ADDED
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
# author: adefossez
|
7 |
+
|
8 |
+
import logging
|
9 |
+
|
10 |
+
from diffq import DiffQuantizer
|
11 |
+
import torch.hub
|
12 |
+
|
13 |
+
from .model import Demucs
|
14 |
+
from .tasnet import ConvTasNet
|
15 |
+
from .utils import set_state
|
16 |
+
|
17 |
+
logger = logging.getLogger(__name__)
|
18 |
+
ROOT = "https://dl.fbaipublicfiles.com/demucs/v3.0/"
|
19 |
+
|
20 |
+
PRETRAINED_MODELS = {
|
21 |
+
'demucs': 'e07c671f',
|
22 |
+
'demucs48_hq': '28a1282c',
|
23 |
+
'demucs_extra': '3646af93',
|
24 |
+
'demucs_quantized': '07afea75',
|
25 |
+
'tasnet': 'beb46fac',
|
26 |
+
'tasnet_extra': 'df3777b2',
|
27 |
+
'demucs_unittest': '09ebc15f',
|
28 |
+
}
|
29 |
+
|
30 |
+
SOURCES = ["drums", "bass", "other", "vocals"]
|
31 |
+
|
32 |
+
|
33 |
+
def get_url(name):
|
34 |
+
sig = PRETRAINED_MODELS[name]
|
35 |
+
return ROOT + name + "-" + sig[:8] + ".th"
|
36 |
+
|
37 |
+
|
38 |
+
def is_pretrained(name):
|
39 |
+
return name in PRETRAINED_MODELS
|
40 |
+
|
41 |
+
|
42 |
+
def load_pretrained(name):
|
43 |
+
if name == "demucs":
|
44 |
+
return demucs(pretrained=True)
|
45 |
+
elif name == "demucs48_hq":
|
46 |
+
return demucs(pretrained=True, hq=True, channels=48)
|
47 |
+
elif name == "demucs_extra":
|
48 |
+
return demucs(pretrained=True, extra=True)
|
49 |
+
elif name == "demucs_quantized":
|
50 |
+
return demucs(pretrained=True, quantized=True)
|
51 |
+
elif name == "demucs_unittest":
|
52 |
+
return demucs_unittest(pretrained=True)
|
53 |
+
elif name == "tasnet":
|
54 |
+
return tasnet(pretrained=True)
|
55 |
+
elif name == "tasnet_extra":
|
56 |
+
return tasnet(pretrained=True, extra=True)
|
57 |
+
else:
|
58 |
+
raise ValueError(f"Invalid pretrained name {name}")
|
59 |
+
|
60 |
+
|
61 |
+
def _load_state(name, model, quantizer=None):
|
62 |
+
url = get_url(name)
|
63 |
+
state = torch.hub.load_state_dict_from_url(url, map_location='cpu', check_hash=True)
|
64 |
+
set_state(model, quantizer, state)
|
65 |
+
if quantizer:
|
66 |
+
quantizer.detach()
|
67 |
+
|
68 |
+
|
69 |
+
def demucs_unittest(pretrained=True):
|
70 |
+
model = Demucs(channels=4, sources=SOURCES)
|
71 |
+
if pretrained:
|
72 |
+
_load_state('demucs_unittest', model)
|
73 |
+
return model
|
74 |
+
|
75 |
+
|
76 |
+
def demucs(pretrained=True, extra=False, quantized=False, hq=False, channels=64):
|
77 |
+
if not pretrained and (extra or quantized or hq):
|
78 |
+
raise ValueError("if extra or quantized is True, pretrained must be True.")
|
79 |
+
model = Demucs(sources=SOURCES, channels=channels)
|
80 |
+
if pretrained:
|
81 |
+
name = 'demucs'
|
82 |
+
if channels != 64:
|
83 |
+
name += str(channels)
|
84 |
+
quantizer = None
|
85 |
+
if sum([extra, quantized, hq]) > 1:
|
86 |
+
raise ValueError("Only one of extra, quantized, hq, can be True.")
|
87 |
+
if quantized:
|
88 |
+
quantizer = DiffQuantizer(model, group_size=8, min_size=1)
|
89 |
+
name += '_quantized'
|
90 |
+
if extra:
|
91 |
+
name += '_extra'
|
92 |
+
if hq:
|
93 |
+
name += '_hq'
|
94 |
+
_load_state(name, model, quantizer)
|
95 |
+
return model
|
96 |
+
|
97 |
+
|
98 |
+
def tasnet(pretrained=True, extra=False):
|
99 |
+
if not pretrained and extra:
|
100 |
+
raise ValueError("if extra is True, pretrained must be True.")
|
101 |
+
model = ConvTasNet(X=10, sources=SOURCES)
|
102 |
+
if pretrained:
|
103 |
+
name = 'tasnet'
|
104 |
+
if extra:
|
105 |
+
name = 'tasnet_extra'
|
106 |
+
_load_state(name, model)
|
107 |
+
return model
|
demucs/raw.py
ADDED
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import argparse
|
8 |
+
import os
|
9 |
+
from collections import defaultdict, namedtuple
|
10 |
+
from pathlib import Path
|
11 |
+
|
12 |
+
import musdb
|
13 |
+
import numpy as np
|
14 |
+
import torch as th
|
15 |
+
import tqdm
|
16 |
+
from torch.utils.data import DataLoader
|
17 |
+
|
18 |
+
from .audio import AudioFile
|
19 |
+
|
20 |
+
ChunkInfo = namedtuple("ChunkInfo", ["file_index", "offset", "local_index"])
|
21 |
+
|
22 |
+
|
23 |
+
class Rawset:
|
24 |
+
"""
|
25 |
+
Dataset of raw, normalized, float32 audio files
|
26 |
+
"""
|
27 |
+
def __init__(self, path, samples=None, stride=None, channels=2, streams=None):
|
28 |
+
self.path = Path(path)
|
29 |
+
self.channels = channels
|
30 |
+
self.samples = samples
|
31 |
+
if stride is None:
|
32 |
+
stride = samples if samples is not None else 0
|
33 |
+
self.stride = stride
|
34 |
+
entries = defaultdict(list)
|
35 |
+
for root, folders, files in os.walk(self.path, followlinks=True):
|
36 |
+
folders.sort()
|
37 |
+
files.sort()
|
38 |
+
for file in files:
|
39 |
+
if file.endswith(".raw"):
|
40 |
+
path = Path(root) / file
|
41 |
+
name, stream = path.stem.rsplit('.', 1)
|
42 |
+
entries[(path.parent.relative_to(self.path), name)].append(int(stream))
|
43 |
+
|
44 |
+
self._entries = list(entries.keys())
|
45 |
+
|
46 |
+
sizes = []
|
47 |
+
self._lengths = []
|
48 |
+
ref_streams = sorted(entries[self._entries[0]])
|
49 |
+
assert ref_streams == list(range(len(ref_streams)))
|
50 |
+
if streams is None:
|
51 |
+
self.streams = ref_streams
|
52 |
+
else:
|
53 |
+
self.streams = streams
|
54 |
+
for entry in sorted(entries.keys()):
|
55 |
+
streams = entries[entry]
|
56 |
+
assert sorted(streams) == ref_streams
|
57 |
+
file = self._path(*entry)
|
58 |
+
length = file.stat().st_size // (4 * channels)
|
59 |
+
if samples is None:
|
60 |
+
sizes.append(1)
|
61 |
+
else:
|
62 |
+
if length < samples:
|
63 |
+
self._entries.remove(entry)
|
64 |
+
continue
|
65 |
+
sizes.append((length - samples) // stride + 1)
|
66 |
+
self._lengths.append(length)
|
67 |
+
if not sizes:
|
68 |
+
raise ValueError(f"Empty dataset {self.path}")
|
69 |
+
self._cumulative_sizes = np.cumsum(sizes)
|
70 |
+
self._sizes = sizes
|
71 |
+
|
72 |
+
def __len__(self):
|
73 |
+
return self._cumulative_sizes[-1]
|
74 |
+
|
75 |
+
@property
|
76 |
+
def total_length(self):
|
77 |
+
return sum(self._lengths)
|
78 |
+
|
79 |
+
def chunk_info(self, index):
|
80 |
+
file_index = np.searchsorted(self._cumulative_sizes, index, side='right')
|
81 |
+
if file_index == 0:
|
82 |
+
local_index = index
|
83 |
+
else:
|
84 |
+
local_index = index - self._cumulative_sizes[file_index - 1]
|
85 |
+
return ChunkInfo(offset=local_index * self.stride,
|
86 |
+
file_index=file_index,
|
87 |
+
local_index=local_index)
|
88 |
+
|
89 |
+
def _path(self, folder, name, stream=0):
|
90 |
+
return self.path / folder / (name + f'.{stream}.raw')
|
91 |
+
|
92 |
+
def __getitem__(self, index):
|
93 |
+
chunk = self.chunk_info(index)
|
94 |
+
entry = self._entries[chunk.file_index]
|
95 |
+
|
96 |
+
length = self.samples or self._lengths[chunk.file_index]
|
97 |
+
streams = []
|
98 |
+
to_read = length * self.channels * 4
|
99 |
+
for stream_index, stream in enumerate(self.streams):
|
100 |
+
offset = chunk.offset * 4 * self.channels
|
101 |
+
file = open(self._path(*entry, stream=stream), 'rb')
|
102 |
+
file.seek(offset)
|
103 |
+
content = file.read(to_read)
|
104 |
+
assert len(content) == to_read
|
105 |
+
content = np.frombuffer(content, dtype=np.float32)
|
106 |
+
content = content.copy() # make writable
|
107 |
+
streams.append(th.from_numpy(content).view(length, self.channels).t())
|
108 |
+
return th.stack(streams, dim=0)
|
109 |
+
|
110 |
+
def name(self, index):
|
111 |
+
chunk = self.chunk_info(index)
|
112 |
+
folder, name = self._entries[chunk.file_index]
|
113 |
+
return folder / name
|
114 |
+
|
115 |
+
|
116 |
+
class MusDBSet:
|
117 |
+
def __init__(self, mus, streams=slice(None), samplerate=44100, channels=2):
|
118 |
+
self.mus = mus
|
119 |
+
self.streams = streams
|
120 |
+
self.samplerate = samplerate
|
121 |
+
self.channels = channels
|
122 |
+
|
123 |
+
def __len__(self):
|
124 |
+
return len(self.mus.tracks)
|
125 |
+
|
126 |
+
def __getitem__(self, index):
|
127 |
+
track = self.mus.tracks[index]
|
128 |
+
return (track.name, AudioFile(track.path).read(channels=self.channels,
|
129 |
+
seek_time=0,
|
130 |
+
streams=self.streams,
|
131 |
+
samplerate=self.samplerate))
|
132 |
+
|
133 |
+
|
134 |
+
def build_raw(mus, destination, normalize, workers, samplerate, channels):
|
135 |
+
destination.mkdir(parents=True, exist_ok=True)
|
136 |
+
loader = DataLoader(MusDBSet(mus, channels=channels, samplerate=samplerate),
|
137 |
+
batch_size=1,
|
138 |
+
num_workers=workers,
|
139 |
+
collate_fn=lambda x: x[0])
|
140 |
+
for name, streams in tqdm.tqdm(loader):
|
141 |
+
if normalize:
|
142 |
+
ref = streams[0].mean(dim=0) # use mono mixture as reference
|
143 |
+
streams = (streams - ref.mean()) / ref.std()
|
144 |
+
for index, stream in enumerate(streams):
|
145 |
+
open(destination / (name + f'.{index}.raw'), "wb").write(stream.t().numpy().tobytes())
|
146 |
+
|
147 |
+
|
148 |
+
def main():
|
149 |
+
parser = argparse.ArgumentParser('rawset')
|
150 |
+
parser.add_argument('--workers', type=int, default=10)
|
151 |
+
parser.add_argument('--samplerate', type=int, default=44100)
|
152 |
+
parser.add_argument('--channels', type=int, default=2)
|
153 |
+
parser.add_argument('musdb', type=Path)
|
154 |
+
parser.add_argument('destination', type=Path)
|
155 |
+
|
156 |
+
args = parser.parse_args()
|
157 |
+
|
158 |
+
build_raw(musdb.DB(root=args.musdb, subsets=["train"], split="train"),
|
159 |
+
args.destination / "train",
|
160 |
+
normalize=True,
|
161 |
+
channels=args.channels,
|
162 |
+
samplerate=args.samplerate,
|
163 |
+
workers=args.workers)
|
164 |
+
build_raw(musdb.DB(root=args.musdb, subsets=["train"], split="valid"),
|
165 |
+
args.destination / "valid",
|
166 |
+
normalize=True,
|
167 |
+
samplerate=args.samplerate,
|
168 |
+
channels=args.channels,
|
169 |
+
workers=args.workers)
|
170 |
+
|
171 |
+
|
172 |
+
if __name__ == "__main__":
|
173 |
+
main()
|
demucs/repitch.py
ADDED
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import io
|
8 |
+
import random
|
9 |
+
import subprocess as sp
|
10 |
+
import tempfile
|
11 |
+
|
12 |
+
import numpy as np
|
13 |
+
import torch
|
14 |
+
from scipy.io import wavfile
|
15 |
+
|
16 |
+
|
17 |
+
def i16_pcm(wav):
|
18 |
+
if wav.dtype == np.int16:
|
19 |
+
return wav
|
20 |
+
return (wav * 2**15).clamp_(-2**15, 2**15 - 1).short()
|
21 |
+
|
22 |
+
|
23 |
+
def f32_pcm(wav):
|
24 |
+
if wav.dtype == np.float:
|
25 |
+
return wav
|
26 |
+
return wav.float() / 2**15
|
27 |
+
|
28 |
+
|
29 |
+
class RepitchedWrapper:
|
30 |
+
"""
|
31 |
+
Wrap a dataset to apply online change of pitch / tempo.
|
32 |
+
"""
|
33 |
+
def __init__(self, dataset, proba=0.2, max_pitch=2, max_tempo=12, tempo_std=5, vocals=[3]):
|
34 |
+
self.dataset = dataset
|
35 |
+
self.proba = proba
|
36 |
+
self.max_pitch = max_pitch
|
37 |
+
self.max_tempo = max_tempo
|
38 |
+
self.tempo_std = tempo_std
|
39 |
+
self.vocals = vocals
|
40 |
+
|
41 |
+
def __len__(self):
|
42 |
+
return len(self.dataset)
|
43 |
+
|
44 |
+
def __getitem__(self, index):
|
45 |
+
streams = self.dataset[index]
|
46 |
+
in_length = streams.shape[-1]
|
47 |
+
out_length = int((1 - 0.01 * self.max_tempo) * in_length)
|
48 |
+
|
49 |
+
if random.random() < self.proba:
|
50 |
+
delta_pitch = random.randint(-self.max_pitch, self.max_pitch)
|
51 |
+
delta_tempo = random.gauss(0, self.tempo_std)
|
52 |
+
delta_tempo = min(max(-self.max_tempo, delta_tempo), self.max_tempo)
|
53 |
+
outs = []
|
54 |
+
for idx, stream in enumerate(streams):
|
55 |
+
stream = repitch(
|
56 |
+
stream,
|
57 |
+
delta_pitch,
|
58 |
+
delta_tempo,
|
59 |
+
voice=idx in self.vocals)
|
60 |
+
outs.append(stream[:, :out_length])
|
61 |
+
streams = torch.stack(outs)
|
62 |
+
else:
|
63 |
+
streams = streams[..., :out_length]
|
64 |
+
return streams
|
65 |
+
|
66 |
+
|
67 |
+
def repitch(wav, pitch, tempo, voice=False, quick=False, samplerate=44100):
|
68 |
+
"""
|
69 |
+
tempo is a relative delta in percentage, so tempo=10 means tempo at 110%!
|
70 |
+
pitch is in semi tones.
|
71 |
+
Requires `soundstretch` to be installed, see
|
72 |
+
https://www.surina.net/soundtouch/soundstretch.html
|
73 |
+
"""
|
74 |
+
outfile = tempfile.NamedTemporaryFile(suffix=".wav")
|
75 |
+
in_ = io.BytesIO()
|
76 |
+
wavfile.write(in_, samplerate, i16_pcm(wav).t().numpy())
|
77 |
+
command = [
|
78 |
+
"soundstretch",
|
79 |
+
"stdin",
|
80 |
+
outfile.name,
|
81 |
+
f"-pitch={pitch}",
|
82 |
+
f"-tempo={tempo:.6f}",
|
83 |
+
]
|
84 |
+
if quick:
|
85 |
+
command += ["-quick"]
|
86 |
+
if voice:
|
87 |
+
command += ["-speech"]
|
88 |
+
try:
|
89 |
+
sp.run(command, capture_output=True, input=in_.getvalue(), check=True)
|
90 |
+
except sp.CalledProcessError as error:
|
91 |
+
raise RuntimeError(f"Could not change bpm because {error.stderr.decode('utf-8')}")
|
92 |
+
sr, wav = wavfile.read(outfile.name)
|
93 |
+
wav = wav.copy()
|
94 |
+
wav = f32_pcm(torch.from_numpy(wav).t())
|
95 |
+
assert sr == samplerate
|
96 |
+
return wav
|
demucs/separate.py
ADDED
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import argparse
|
8 |
+
import sys
|
9 |
+
from pathlib import Path
|
10 |
+
import subprocess
|
11 |
+
|
12 |
+
import julius
|
13 |
+
import torch as th
|
14 |
+
import torchaudio as ta
|
15 |
+
|
16 |
+
from .audio import AudioFile, convert_audio_channels
|
17 |
+
from .pretrained import is_pretrained, load_pretrained
|
18 |
+
from .utils import apply_model, load_model
|
19 |
+
|
20 |
+
|
21 |
+
def load_track(track, device, audio_channels, samplerate):
|
22 |
+
errors = {}
|
23 |
+
wav = None
|
24 |
+
|
25 |
+
try:
|
26 |
+
wav = AudioFile(track).read(
|
27 |
+
streams=0,
|
28 |
+
samplerate=samplerate,
|
29 |
+
channels=audio_channels).to(device)
|
30 |
+
except FileNotFoundError:
|
31 |
+
errors['ffmpeg'] = 'Ffmpeg is not installed.'
|
32 |
+
except subprocess.CalledProcessError:
|
33 |
+
errors['ffmpeg'] = 'FFmpeg could not read the file.'
|
34 |
+
|
35 |
+
if wav is None:
|
36 |
+
try:
|
37 |
+
wav, sr = ta.load(str(track))
|
38 |
+
except RuntimeError as err:
|
39 |
+
errors['torchaudio'] = err.args[0]
|
40 |
+
else:
|
41 |
+
wav = convert_audio_channels(wav, audio_channels)
|
42 |
+
wav = wav.to(device)
|
43 |
+
wav = julius.resample_frac(wav, sr, samplerate)
|
44 |
+
|
45 |
+
if wav is None:
|
46 |
+
print(f"Could not load file {track}. "
|
47 |
+
"Maybe it is not a supported file format? ")
|
48 |
+
for backend, error in errors.items():
|
49 |
+
print(f"When trying to load using {backend}, got the following error: {error}")
|
50 |
+
sys.exit(1)
|
51 |
+
return wav
|
52 |
+
|
53 |
+
|
54 |
+
def encode_mp3(wav, path, bitrate=320, samplerate=44100, channels=2, verbose=False):
|
55 |
+
try:
|
56 |
+
import lameenc
|
57 |
+
except ImportError:
|
58 |
+
print("Failed to call lame encoder. Maybe it is not installed? "
|
59 |
+
"On windows, run `python.exe -m pip install -U lameenc`, "
|
60 |
+
"on OSX/Linux, run `python3 -m pip install -U lameenc`, "
|
61 |
+
"then try again.", file=sys.stderr)
|
62 |
+
sys.exit(1)
|
63 |
+
encoder = lameenc.Encoder()
|
64 |
+
encoder.set_bit_rate(bitrate)
|
65 |
+
encoder.set_in_sample_rate(samplerate)
|
66 |
+
encoder.set_channels(channels)
|
67 |
+
encoder.set_quality(2) # 2-highest, 7-fastest
|
68 |
+
if not verbose:
|
69 |
+
encoder.silence()
|
70 |
+
wav = wav.transpose(0, 1).numpy()
|
71 |
+
mp3_data = encoder.encode(wav.tobytes())
|
72 |
+
mp3_data += encoder.flush()
|
73 |
+
with open(path, "wb") as f:
|
74 |
+
f.write(mp3_data)
|
75 |
+
|
76 |
+
|
77 |
+
def main():
|
78 |
+
parser = argparse.ArgumentParser("demucs.separate",
|
79 |
+
description="Separate the sources for the given tracks")
|
80 |
+
parser.add_argument("tracks", nargs='+', type=Path, default=[], help='Path to tracks')
|
81 |
+
parser.add_argument("-n",
|
82 |
+
"--name",
|
83 |
+
default="demucs_quantized",
|
84 |
+
help="Model name. See README.md for the list of pretrained models. "
|
85 |
+
"Default is demucs_quantized.")
|
86 |
+
parser.add_argument("-v", "--verbose", action="store_true")
|
87 |
+
parser.add_argument("-o",
|
88 |
+
"--out",
|
89 |
+
type=Path,
|
90 |
+
default=Path("separated"),
|
91 |
+
help="Folder where to put extracted tracks. A subfolder "
|
92 |
+
"with the model name will be created.")
|
93 |
+
parser.add_argument("--models",
|
94 |
+
type=Path,
|
95 |
+
default=Path("models"),
|
96 |
+
help="Path to trained models. "
|
97 |
+
"Also used to store downloaded pretrained models")
|
98 |
+
parser.add_argument("-d",
|
99 |
+
"--device",
|
100 |
+
default="cuda" if th.cuda.is_available() else "cpu",
|
101 |
+
help="Device to use, default is cuda if available else cpu")
|
102 |
+
parser.add_argument("--shifts",
|
103 |
+
default=0,
|
104 |
+
type=int,
|
105 |
+
help="Number of random shifts for equivariant stabilization."
|
106 |
+
"Increase separation time but improves quality for Demucs. 10 was used "
|
107 |
+
"in the original paper.")
|
108 |
+
parser.add_argument("--overlap",
|
109 |
+
default=0.25,
|
110 |
+
type=float,
|
111 |
+
help="Overlap between the splits.")
|
112 |
+
parser.add_argument("--no-split",
|
113 |
+
action="store_false",
|
114 |
+
dest="split",
|
115 |
+
default=True,
|
116 |
+
help="Doesn't split audio in chunks. This can use large amounts of memory.")
|
117 |
+
parser.add_argument("--float32",
|
118 |
+
action="store_true",
|
119 |
+
help="Convert the output wavefile to use pcm f32 format instead of s16. "
|
120 |
+
"This should not make a difference if you just plan on listening to the "
|
121 |
+
"audio but might be needed to compute exactly metrics like SDR etc.")
|
122 |
+
parser.add_argument("--int16",
|
123 |
+
action="store_false",
|
124 |
+
dest="float32",
|
125 |
+
help="Opposite of --float32, here for compatibility.")
|
126 |
+
parser.add_argument("--mp3", action="store_true",
|
127 |
+
help="Convert the output wavs to mp3.")
|
128 |
+
parser.add_argument("--mp3-bitrate",
|
129 |
+
default=320,
|
130 |
+
type=int,
|
131 |
+
help="Bitrate of converted mp3.")
|
132 |
+
|
133 |
+
args = parser.parse_args()
|
134 |
+
name = args.name + ".th"
|
135 |
+
model_path = args.models / name
|
136 |
+
if model_path.is_file():
|
137 |
+
model = load_model(model_path)
|
138 |
+
else:
|
139 |
+
if is_pretrained(args.name):
|
140 |
+
model = load_pretrained(args.name)
|
141 |
+
else:
|
142 |
+
print(f"No pre-trained model {args.name}", file=sys.stderr)
|
143 |
+
sys.exit(1)
|
144 |
+
model.to(args.device)
|
145 |
+
|
146 |
+
out = args.out / args.name
|
147 |
+
out.mkdir(parents=True, exist_ok=True)
|
148 |
+
print(f"Separated tracks will be stored in {out.resolve()}")
|
149 |
+
for track in args.tracks:
|
150 |
+
if not track.exists():
|
151 |
+
print(
|
152 |
+
f"File {track} does not exist. If the path contains spaces, "
|
153 |
+
"please try again after surrounding the entire path with quotes \"\".",
|
154 |
+
file=sys.stderr)
|
155 |
+
continue
|
156 |
+
print(f"Separating track {track}")
|
157 |
+
wav = load_track(track, args.device, model.audio_channels, model.samplerate)
|
158 |
+
|
159 |
+
ref = wav.mean(0)
|
160 |
+
wav = (wav - ref.mean()) / ref.std()
|
161 |
+
sources = apply_model(model, wav, shifts=args.shifts, split=args.split,
|
162 |
+
overlap=args.overlap, progress=True)
|
163 |
+
sources = sources * ref.std() + ref.mean()
|
164 |
+
|
165 |
+
track_folder = out / track.name.rsplit(".", 1)[0]
|
166 |
+
track_folder.mkdir(exist_ok=True)
|
167 |
+
for source, name in zip(sources, model.sources):
|
168 |
+
source = source / max(1.01 * source.abs().max(), 1)
|
169 |
+
if args.mp3 or not args.float32:
|
170 |
+
source = (source * 2**15).clamp_(-2**15, 2**15 - 1).short()
|
171 |
+
source = source.cpu()
|
172 |
+
stem = str(track_folder / name)
|
173 |
+
if args.mp3:
|
174 |
+
encode_mp3(source, stem + ".mp3",
|
175 |
+
bitrate=args.mp3_bitrate,
|
176 |
+
samplerate=model.samplerate,
|
177 |
+
channels=model.audio_channels,
|
178 |
+
verbose=args.verbose)
|
179 |
+
else:
|
180 |
+
wavname = str(track_folder / f"{name}.wav")
|
181 |
+
ta.save(wavname, source, sample_rate=model.samplerate)
|
182 |
+
|
183 |
+
|
184 |
+
if __name__ == "__main__":
|
185 |
+
main()
|
demucs/tasnet.py
ADDED
@@ -0,0 +1,452 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
#
|
7 |
+
# Created on 2018/12
|
8 |
+
# Author: Kaituo XU
|
9 |
+
# Modified on 2019/11 by Alexandre Defossez, added support for multiple output channels
|
10 |
+
# Here is the original license:
|
11 |
+
# The MIT License (MIT)
|
12 |
+
#
|
13 |
+
# Copyright (c) 2018 Kaituo XU
|
14 |
+
#
|
15 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
16 |
+
# of this software and associated documentation files (the "Software"), to deal
|
17 |
+
# in the Software without restriction, including without limitation the rights
|
18 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
19 |
+
# copies of the Software, and to permit persons to whom the Software is
|
20 |
+
# furnished to do so, subject to the following conditions:
|
21 |
+
#
|
22 |
+
# The above copyright notice and this permission notice shall be included in all
|
23 |
+
# copies or substantial portions of the Software.
|
24 |
+
#
|
25 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
26 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
27 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
28 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
29 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
30 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
31 |
+
# SOFTWARE.
|
32 |
+
|
33 |
+
import math
|
34 |
+
|
35 |
+
import torch
|
36 |
+
import torch.nn as nn
|
37 |
+
import torch.nn.functional as F
|
38 |
+
|
39 |
+
from .utils import capture_init
|
40 |
+
|
41 |
+
EPS = 1e-8
|
42 |
+
|
43 |
+
|
44 |
+
def overlap_and_add(signal, frame_step):
|
45 |
+
outer_dimensions = signal.size()[:-2]
|
46 |
+
frames, frame_length = signal.size()[-2:]
|
47 |
+
|
48 |
+
subframe_length = math.gcd(frame_length, frame_step) # gcd=Greatest Common Divisor
|
49 |
+
subframe_step = frame_step // subframe_length
|
50 |
+
subframes_per_frame = frame_length // subframe_length
|
51 |
+
output_size = frame_step * (frames - 1) + frame_length
|
52 |
+
output_subframes = output_size // subframe_length
|
53 |
+
|
54 |
+
subframe_signal = signal.view(*outer_dimensions, -1, subframe_length)
|
55 |
+
|
56 |
+
frame = torch.arange(0, output_subframes,
|
57 |
+
device=signal.device).unfold(0, subframes_per_frame, subframe_step)
|
58 |
+
frame = frame.long() # signal may in GPU or CPU
|
59 |
+
frame = frame.contiguous().view(-1)
|
60 |
+
|
61 |
+
result = signal.new_zeros(*outer_dimensions, output_subframes, subframe_length)
|
62 |
+
result.index_add_(-2, frame, subframe_signal)
|
63 |
+
result = result.view(*outer_dimensions, -1)
|
64 |
+
return result
|
65 |
+
|
66 |
+
|
67 |
+
class ConvTasNet(nn.Module):
|
68 |
+
@capture_init
|
69 |
+
def __init__(self,
|
70 |
+
sources,
|
71 |
+
N=256,
|
72 |
+
L=20,
|
73 |
+
B=256,
|
74 |
+
H=512,
|
75 |
+
P=3,
|
76 |
+
X=8,
|
77 |
+
R=4,
|
78 |
+
audio_channels=2,
|
79 |
+
norm_type="gLN",
|
80 |
+
causal=False,
|
81 |
+
mask_nonlinear='relu',
|
82 |
+
samplerate=44100,
|
83 |
+
segment_length=44100 * 2 * 4):
|
84 |
+
"""
|
85 |
+
Args:
|
86 |
+
sources: list of sources
|
87 |
+
N: Number of filters in autoencoder
|
88 |
+
L: Length of the filters (in samples)
|
89 |
+
B: Number of channels in bottleneck 1 × 1-conv block
|
90 |
+
H: Number of channels in convolutional blocks
|
91 |
+
P: Kernel size in convolutional blocks
|
92 |
+
X: Number of convolutional blocks in each repeat
|
93 |
+
R: Number of repeats
|
94 |
+
norm_type: BN, gLN, cLN
|
95 |
+
causal: causal or non-causal
|
96 |
+
mask_nonlinear: use which non-linear function to generate mask
|
97 |
+
"""
|
98 |
+
super(ConvTasNet, self).__init__()
|
99 |
+
# Hyper-parameter
|
100 |
+
self.sources = sources
|
101 |
+
self.C = len(sources)
|
102 |
+
self.N, self.L, self.B, self.H, self.P, self.X, self.R = N, L, B, H, P, X, R
|
103 |
+
self.norm_type = norm_type
|
104 |
+
self.causal = causal
|
105 |
+
self.mask_nonlinear = mask_nonlinear
|
106 |
+
self.audio_channels = audio_channels
|
107 |
+
self.samplerate = samplerate
|
108 |
+
self.segment_length = segment_length
|
109 |
+
# Components
|
110 |
+
self.encoder = Encoder(L, N, audio_channels)
|
111 |
+
self.separator = TemporalConvNet(
|
112 |
+
N, B, H, P, X, R, self.C, norm_type, causal, mask_nonlinear)
|
113 |
+
self.decoder = Decoder(N, L, audio_channels)
|
114 |
+
# init
|
115 |
+
for p in self.parameters():
|
116 |
+
if p.dim() > 1:
|
117 |
+
nn.init.xavier_normal_(p)
|
118 |
+
|
119 |
+
def valid_length(self, length):
|
120 |
+
return length
|
121 |
+
|
122 |
+
def forward(self, mixture):
|
123 |
+
"""
|
124 |
+
Args:
|
125 |
+
mixture: [M, T], M is batch size, T is #samples
|
126 |
+
Returns:
|
127 |
+
est_source: [M, C, T]
|
128 |
+
"""
|
129 |
+
mixture_w = self.encoder(mixture)
|
130 |
+
est_mask = self.separator(mixture_w)
|
131 |
+
est_source = self.decoder(mixture_w, est_mask)
|
132 |
+
|
133 |
+
# T changed after conv1d in encoder, fix it here
|
134 |
+
T_origin = mixture.size(-1)
|
135 |
+
T_conv = est_source.size(-1)
|
136 |
+
est_source = F.pad(est_source, (0, T_origin - T_conv))
|
137 |
+
return est_source
|
138 |
+
|
139 |
+
|
140 |
+
class Encoder(nn.Module):
|
141 |
+
"""Estimation of the nonnegative mixture weight by a 1-D conv layer.
|
142 |
+
"""
|
143 |
+
def __init__(self, L, N, audio_channels):
|
144 |
+
super(Encoder, self).__init__()
|
145 |
+
# Hyper-parameter
|
146 |
+
self.L, self.N = L, N
|
147 |
+
# Components
|
148 |
+
# 50% overlap
|
149 |
+
self.conv1d_U = nn.Conv1d(audio_channels, N, kernel_size=L, stride=L // 2, bias=False)
|
150 |
+
|
151 |
+
def forward(self, mixture):
|
152 |
+
"""
|
153 |
+
Args:
|
154 |
+
mixture: [M, T], M is batch size, T is #samples
|
155 |
+
Returns:
|
156 |
+
mixture_w: [M, N, K], where K = (T-L)/(L/2)+1 = 2T/L-1
|
157 |
+
"""
|
158 |
+
mixture_w = F.relu(self.conv1d_U(mixture)) # [M, N, K]
|
159 |
+
return mixture_w
|
160 |
+
|
161 |
+
|
162 |
+
class Decoder(nn.Module):
|
163 |
+
def __init__(self, N, L, audio_channels):
|
164 |
+
super(Decoder, self).__init__()
|
165 |
+
# Hyper-parameter
|
166 |
+
self.N, self.L = N, L
|
167 |
+
self.audio_channels = audio_channels
|
168 |
+
# Components
|
169 |
+
self.basis_signals = nn.Linear(N, audio_channels * L, bias=False)
|
170 |
+
|
171 |
+
def forward(self, mixture_w, est_mask):
|
172 |
+
"""
|
173 |
+
Args:
|
174 |
+
mixture_w: [M, N, K]
|
175 |
+
est_mask: [M, C, N, K]
|
176 |
+
Returns:
|
177 |
+
est_source: [M, C, T]
|
178 |
+
"""
|
179 |
+
# D = W * M
|
180 |
+
source_w = torch.unsqueeze(mixture_w, 1) * est_mask # [M, C, N, K]
|
181 |
+
source_w = torch.transpose(source_w, 2, 3) # [M, C, K, N]
|
182 |
+
# S = DV
|
183 |
+
est_source = self.basis_signals(source_w) # [M, C, K, ac * L]
|
184 |
+
m, c, k, _ = est_source.size()
|
185 |
+
est_source = est_source.view(m, c, k, self.audio_channels, -1).transpose(2, 3).contiguous()
|
186 |
+
est_source = overlap_and_add(est_source, self.L // 2) # M x C x ac x T
|
187 |
+
return est_source
|
188 |
+
|
189 |
+
|
190 |
+
class TemporalConvNet(nn.Module):
|
191 |
+
def __init__(self, N, B, H, P, X, R, C, norm_type="gLN", causal=False, mask_nonlinear='relu'):
|
192 |
+
"""
|
193 |
+
Args:
|
194 |
+
N: Number of filters in autoencoder
|
195 |
+
B: Number of channels in bottleneck 1 × 1-conv block
|
196 |
+
H: Number of channels in convolutional blocks
|
197 |
+
P: Kernel size in convolutional blocks
|
198 |
+
X: Number of convolutional blocks in each repeat
|
199 |
+
R: Number of repeats
|
200 |
+
C: Number of speakers
|
201 |
+
norm_type: BN, gLN, cLN
|
202 |
+
causal: causal or non-causal
|
203 |
+
mask_nonlinear: use which non-linear function to generate mask
|
204 |
+
"""
|
205 |
+
super(TemporalConvNet, self).__init__()
|
206 |
+
# Hyper-parameter
|
207 |
+
self.C = C
|
208 |
+
self.mask_nonlinear = mask_nonlinear
|
209 |
+
# Components
|
210 |
+
# [M, N, K] -> [M, N, K]
|
211 |
+
layer_norm = ChannelwiseLayerNorm(N)
|
212 |
+
# [M, N, K] -> [M, B, K]
|
213 |
+
bottleneck_conv1x1 = nn.Conv1d(N, B, 1, bias=False)
|
214 |
+
# [M, B, K] -> [M, B, K]
|
215 |
+
repeats = []
|
216 |
+
for r in range(R):
|
217 |
+
blocks = []
|
218 |
+
for x in range(X):
|
219 |
+
dilation = 2**x
|
220 |
+
padding = (P - 1) * dilation if causal else (P - 1) * dilation // 2
|
221 |
+
blocks += [
|
222 |
+
TemporalBlock(B,
|
223 |
+
H,
|
224 |
+
P,
|
225 |
+
stride=1,
|
226 |
+
padding=padding,
|
227 |
+
dilation=dilation,
|
228 |
+
norm_type=norm_type,
|
229 |
+
causal=causal)
|
230 |
+
]
|
231 |
+
repeats += [nn.Sequential(*blocks)]
|
232 |
+
temporal_conv_net = nn.Sequential(*repeats)
|
233 |
+
# [M, B, K] -> [M, C*N, K]
|
234 |
+
mask_conv1x1 = nn.Conv1d(B, C * N, 1, bias=False)
|
235 |
+
# Put together
|
236 |
+
self.network = nn.Sequential(layer_norm, bottleneck_conv1x1, temporal_conv_net,
|
237 |
+
mask_conv1x1)
|
238 |
+
|
239 |
+
def forward(self, mixture_w):
|
240 |
+
"""
|
241 |
+
Keep this API same with TasNet
|
242 |
+
Args:
|
243 |
+
mixture_w: [M, N, K], M is batch size
|
244 |
+
returns:
|
245 |
+
est_mask: [M, C, N, K]
|
246 |
+
"""
|
247 |
+
M, N, K = mixture_w.size()
|
248 |
+
score = self.network(mixture_w) # [M, N, K] -> [M, C*N, K]
|
249 |
+
score = score.view(M, self.C, N, K) # [M, C*N, K] -> [M, C, N, K]
|
250 |
+
if self.mask_nonlinear == 'softmax':
|
251 |
+
est_mask = F.softmax(score, dim=1)
|
252 |
+
elif self.mask_nonlinear == 'relu':
|
253 |
+
est_mask = F.relu(score)
|
254 |
+
else:
|
255 |
+
raise ValueError("Unsupported mask non-linear function")
|
256 |
+
return est_mask
|
257 |
+
|
258 |
+
|
259 |
+
class TemporalBlock(nn.Module):
|
260 |
+
def __init__(self,
|
261 |
+
in_channels,
|
262 |
+
out_channels,
|
263 |
+
kernel_size,
|
264 |
+
stride,
|
265 |
+
padding,
|
266 |
+
dilation,
|
267 |
+
norm_type="gLN",
|
268 |
+
causal=False):
|
269 |
+
super(TemporalBlock, self).__init__()
|
270 |
+
# [M, B, K] -> [M, H, K]
|
271 |
+
conv1x1 = nn.Conv1d(in_channels, out_channels, 1, bias=False)
|
272 |
+
prelu = nn.PReLU()
|
273 |
+
norm = chose_norm(norm_type, out_channels)
|
274 |
+
# [M, H, K] -> [M, B, K]
|
275 |
+
dsconv = DepthwiseSeparableConv(out_channels, in_channels, kernel_size, stride, padding,
|
276 |
+
dilation, norm_type, causal)
|
277 |
+
# Put together
|
278 |
+
self.net = nn.Sequential(conv1x1, prelu, norm, dsconv)
|
279 |
+
|
280 |
+
def forward(self, x):
|
281 |
+
"""
|
282 |
+
Args:
|
283 |
+
x: [M, B, K]
|
284 |
+
Returns:
|
285 |
+
[M, B, K]
|
286 |
+
"""
|
287 |
+
residual = x
|
288 |
+
out = self.net(x)
|
289 |
+
# TODO: when P = 3 here works fine, but when P = 2 maybe need to pad?
|
290 |
+
return out + residual # look like w/o F.relu is better than w/ F.relu
|
291 |
+
# return F.relu(out + residual)
|
292 |
+
|
293 |
+
|
294 |
+
class DepthwiseSeparableConv(nn.Module):
|
295 |
+
def __init__(self,
|
296 |
+
in_channels,
|
297 |
+
out_channels,
|
298 |
+
kernel_size,
|
299 |
+
stride,
|
300 |
+
padding,
|
301 |
+
dilation,
|
302 |
+
norm_type="gLN",
|
303 |
+
causal=False):
|
304 |
+
super(DepthwiseSeparableConv, self).__init__()
|
305 |
+
# Use `groups` option to implement depthwise convolution
|
306 |
+
# [M, H, K] -> [M, H, K]
|
307 |
+
depthwise_conv = nn.Conv1d(in_channels,
|
308 |
+
in_channels,
|
309 |
+
kernel_size,
|
310 |
+
stride=stride,
|
311 |
+
padding=padding,
|
312 |
+
dilation=dilation,
|
313 |
+
groups=in_channels,
|
314 |
+
bias=False)
|
315 |
+
if causal:
|
316 |
+
chomp = Chomp1d(padding)
|
317 |
+
prelu = nn.PReLU()
|
318 |
+
norm = chose_norm(norm_type, in_channels)
|
319 |
+
# [M, H, K] -> [M, B, K]
|
320 |
+
pointwise_conv = nn.Conv1d(in_channels, out_channels, 1, bias=False)
|
321 |
+
# Put together
|
322 |
+
if causal:
|
323 |
+
self.net = nn.Sequential(depthwise_conv, chomp, prelu, norm, pointwise_conv)
|
324 |
+
else:
|
325 |
+
self.net = nn.Sequential(depthwise_conv, prelu, norm, pointwise_conv)
|
326 |
+
|
327 |
+
def forward(self, x):
|
328 |
+
"""
|
329 |
+
Args:
|
330 |
+
x: [M, H, K]
|
331 |
+
Returns:
|
332 |
+
result: [M, B, K]
|
333 |
+
"""
|
334 |
+
return self.net(x)
|
335 |
+
|
336 |
+
|
337 |
+
class Chomp1d(nn.Module):
|
338 |
+
"""To ensure the output length is the same as the input.
|
339 |
+
"""
|
340 |
+
def __init__(self, chomp_size):
|
341 |
+
super(Chomp1d, self).__init__()
|
342 |
+
self.chomp_size = chomp_size
|
343 |
+
|
344 |
+
def forward(self, x):
|
345 |
+
"""
|
346 |
+
Args:
|
347 |
+
x: [M, H, Kpad]
|
348 |
+
Returns:
|
349 |
+
[M, H, K]
|
350 |
+
"""
|
351 |
+
return x[:, :, :-self.chomp_size].contiguous()
|
352 |
+
|
353 |
+
|
354 |
+
def chose_norm(norm_type, channel_size):
|
355 |
+
"""The input of normlization will be (M, C, K), where M is batch size,
|
356 |
+
C is channel size and K is sequence length.
|
357 |
+
"""
|
358 |
+
if norm_type == "gLN":
|
359 |
+
return GlobalLayerNorm(channel_size)
|
360 |
+
elif norm_type == "cLN":
|
361 |
+
return ChannelwiseLayerNorm(channel_size)
|
362 |
+
elif norm_type == "id":
|
363 |
+
return nn.Identity()
|
364 |
+
else: # norm_type == "BN":
|
365 |
+
# Given input (M, C, K), nn.BatchNorm1d(C) will accumulate statics
|
366 |
+
# along M and K, so this BN usage is right.
|
367 |
+
return nn.BatchNorm1d(channel_size)
|
368 |
+
|
369 |
+
|
370 |
+
# TODO: Use nn.LayerNorm to impl cLN to speed up
|
371 |
+
class ChannelwiseLayerNorm(nn.Module):
|
372 |
+
"""Channel-wise Layer Normalization (cLN)"""
|
373 |
+
def __init__(self, channel_size):
|
374 |
+
super(ChannelwiseLayerNorm, self).__init__()
|
375 |
+
self.gamma = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1]
|
376 |
+
self.beta = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1]
|
377 |
+
self.reset_parameters()
|
378 |
+
|
379 |
+
def reset_parameters(self):
|
380 |
+
self.gamma.data.fill_(1)
|
381 |
+
self.beta.data.zero_()
|
382 |
+
|
383 |
+
def forward(self, y):
|
384 |
+
"""
|
385 |
+
Args:
|
386 |
+
y: [M, N, K], M is batch size, N is channel size, K is length
|
387 |
+
Returns:
|
388 |
+
cLN_y: [M, N, K]
|
389 |
+
"""
|
390 |
+
mean = torch.mean(y, dim=1, keepdim=True) # [M, 1, K]
|
391 |
+
var = torch.var(y, dim=1, keepdim=True, unbiased=False) # [M, 1, K]
|
392 |
+
cLN_y = self.gamma * (y - mean) / torch.pow(var + EPS, 0.5) + self.beta
|
393 |
+
return cLN_y
|
394 |
+
|
395 |
+
|
396 |
+
class GlobalLayerNorm(nn.Module):
|
397 |
+
"""Global Layer Normalization (gLN)"""
|
398 |
+
def __init__(self, channel_size):
|
399 |
+
super(GlobalLayerNorm, self).__init__()
|
400 |
+
self.gamma = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1]
|
401 |
+
self.beta = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1]
|
402 |
+
self.reset_parameters()
|
403 |
+
|
404 |
+
def reset_parameters(self):
|
405 |
+
self.gamma.data.fill_(1)
|
406 |
+
self.beta.data.zero_()
|
407 |
+
|
408 |
+
def forward(self, y):
|
409 |
+
"""
|
410 |
+
Args:
|
411 |
+
y: [M, N, K], M is batch size, N is channel size, K is length
|
412 |
+
Returns:
|
413 |
+
gLN_y: [M, N, K]
|
414 |
+
"""
|
415 |
+
# TODO: in torch 1.0, torch.mean() support dim list
|
416 |
+
mean = y.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True) # [M, 1, 1]
|
417 |
+
var = (torch.pow(y - mean, 2)).mean(dim=1, keepdim=True).mean(dim=2, keepdim=True)
|
418 |
+
gLN_y = self.gamma * (y - mean) / torch.pow(var + EPS, 0.5) + self.beta
|
419 |
+
return gLN_y
|
420 |
+
|
421 |
+
|
422 |
+
if __name__ == "__main__":
|
423 |
+
torch.manual_seed(123)
|
424 |
+
M, N, L, T = 2, 3, 4, 12
|
425 |
+
K = 2 * T // L - 1
|
426 |
+
B, H, P, X, R, C, norm_type, causal = 2, 3, 3, 3, 2, 2, "gLN", False
|
427 |
+
mixture = torch.randint(3, (M, T))
|
428 |
+
# test Encoder
|
429 |
+
encoder = Encoder(L, N)
|
430 |
+
encoder.conv1d_U.weight.data = torch.randint(2, encoder.conv1d_U.weight.size())
|
431 |
+
mixture_w = encoder(mixture)
|
432 |
+
print('mixture', mixture)
|
433 |
+
print('U', encoder.conv1d_U.weight)
|
434 |
+
print('mixture_w', mixture_w)
|
435 |
+
print('mixture_w size', mixture_w.size())
|
436 |
+
|
437 |
+
# test TemporalConvNet
|
438 |
+
separator = TemporalConvNet(N, B, H, P, X, R, C, norm_type=norm_type, causal=causal)
|
439 |
+
est_mask = separator(mixture_w)
|
440 |
+
print('est_mask', est_mask)
|
441 |
+
|
442 |
+
# test Decoder
|
443 |
+
decoder = Decoder(N, L)
|
444 |
+
est_mask = torch.randint(2, (B, K, C, N))
|
445 |
+
est_source = decoder(mixture_w, est_mask)
|
446 |
+
print('est_source', est_source)
|
447 |
+
|
448 |
+
# test Conv-TasNet
|
449 |
+
conv_tasnet = ConvTasNet(N, L, B, H, P, X, R, C, norm_type=norm_type)
|
450 |
+
est_source = conv_tasnet(mixture)
|
451 |
+
print('est_source', est_source)
|
452 |
+
print('est_source size', est_source.size())
|
demucs/test.py
ADDED
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import gzip
|
8 |
+
import sys
|
9 |
+
from concurrent import futures
|
10 |
+
|
11 |
+
import musdb
|
12 |
+
import museval
|
13 |
+
import torch as th
|
14 |
+
import tqdm
|
15 |
+
from scipy.io import wavfile
|
16 |
+
from torch import distributed
|
17 |
+
|
18 |
+
from .audio import convert_audio
|
19 |
+
from .utils import apply_model
|
20 |
+
|
21 |
+
|
22 |
+
def evaluate(model,
|
23 |
+
musdb_path,
|
24 |
+
eval_folder,
|
25 |
+
workers=2,
|
26 |
+
device="cpu",
|
27 |
+
rank=0,
|
28 |
+
save=False,
|
29 |
+
shifts=0,
|
30 |
+
split=False,
|
31 |
+
overlap=0.25,
|
32 |
+
is_wav=False,
|
33 |
+
world_size=1):
|
34 |
+
"""
|
35 |
+
Evaluate model using museval. Run the model
|
36 |
+
on a single GPU, the bottleneck being the call to museval.
|
37 |
+
"""
|
38 |
+
|
39 |
+
output_dir = eval_folder / "results"
|
40 |
+
output_dir.mkdir(exist_ok=True, parents=True)
|
41 |
+
json_folder = eval_folder / "results/test"
|
42 |
+
json_folder.mkdir(exist_ok=True, parents=True)
|
43 |
+
|
44 |
+
# we load tracks from the original musdb set
|
45 |
+
test_set = musdb.DB(musdb_path, subsets=["test"], is_wav=is_wav)
|
46 |
+
src_rate = 44100 # hardcoded for now...
|
47 |
+
|
48 |
+
for p in model.parameters():
|
49 |
+
p.requires_grad = False
|
50 |
+
p.grad = None
|
51 |
+
|
52 |
+
pendings = []
|
53 |
+
with futures.ProcessPoolExecutor(workers or 1) as pool:
|
54 |
+
for index in tqdm.tqdm(range(rank, len(test_set), world_size), file=sys.stdout):
|
55 |
+
track = test_set.tracks[index]
|
56 |
+
|
57 |
+
out = json_folder / f"{track.name}.json.gz"
|
58 |
+
if out.exists():
|
59 |
+
continue
|
60 |
+
|
61 |
+
mix = th.from_numpy(track.audio).t().float()
|
62 |
+
ref = mix.mean(dim=0) # mono mixture
|
63 |
+
mix = (mix - ref.mean()) / ref.std()
|
64 |
+
mix = convert_audio(mix, src_rate, model.samplerate, model.audio_channels)
|
65 |
+
estimates = apply_model(model, mix.to(device),
|
66 |
+
shifts=shifts, split=split, overlap=overlap)
|
67 |
+
estimates = estimates * ref.std() + ref.mean()
|
68 |
+
|
69 |
+
estimates = estimates.transpose(1, 2)
|
70 |
+
references = th.stack(
|
71 |
+
[th.from_numpy(track.targets[name].audio).t() for name in model.sources])
|
72 |
+
references = convert_audio(references, src_rate,
|
73 |
+
model.samplerate, model.audio_channels)
|
74 |
+
references = references.transpose(1, 2).numpy()
|
75 |
+
estimates = estimates.cpu().numpy()
|
76 |
+
win = int(1. * model.samplerate)
|
77 |
+
hop = int(1. * model.samplerate)
|
78 |
+
if save:
|
79 |
+
folder = eval_folder / "wav/test" / track.name
|
80 |
+
folder.mkdir(exist_ok=True, parents=True)
|
81 |
+
for name, estimate in zip(model.sources, estimates):
|
82 |
+
wavfile.write(str(folder / (name + ".wav")), 44100, estimate)
|
83 |
+
|
84 |
+
if workers:
|
85 |
+
pendings.append((track.name, pool.submit(
|
86 |
+
museval.evaluate, references, estimates, win=win, hop=hop)))
|
87 |
+
else:
|
88 |
+
pendings.append((track.name, museval.evaluate(
|
89 |
+
references, estimates, win=win, hop=hop)))
|
90 |
+
del references, mix, estimates, track
|
91 |
+
|
92 |
+
for track_name, pending in tqdm.tqdm(pendings, file=sys.stdout):
|
93 |
+
if workers:
|
94 |
+
pending = pending.result()
|
95 |
+
sdr, isr, sir, sar = pending
|
96 |
+
track_store = museval.TrackStore(win=44100, hop=44100, track_name=track_name)
|
97 |
+
for idx, target in enumerate(model.sources):
|
98 |
+
values = {
|
99 |
+
"SDR": sdr[idx].tolist(),
|
100 |
+
"SIR": sir[idx].tolist(),
|
101 |
+
"ISR": isr[idx].tolist(),
|
102 |
+
"SAR": sar[idx].tolist()
|
103 |
+
}
|
104 |
+
|
105 |
+
track_store.add_target(target_name=target, values=values)
|
106 |
+
json_path = json_folder / f"{track_name}.json.gz"
|
107 |
+
gzip.open(json_path, "w").write(track_store.json.encode('utf-8'))
|
108 |
+
if world_size > 1:
|
109 |
+
distributed.barrier()
|
demucs/train.py
ADDED
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import sys
|
8 |
+
|
9 |
+
import tqdm
|
10 |
+
from torch.utils.data import DataLoader
|
11 |
+
from torch.utils.data.distributed import DistributedSampler
|
12 |
+
|
13 |
+
from .utils import apply_model, average_metric, center_trim
|
14 |
+
|
15 |
+
|
16 |
+
def train_model(epoch,
|
17 |
+
dataset,
|
18 |
+
model,
|
19 |
+
criterion,
|
20 |
+
optimizer,
|
21 |
+
augment,
|
22 |
+
quantizer=None,
|
23 |
+
diffq=0,
|
24 |
+
repeat=1,
|
25 |
+
device="cpu",
|
26 |
+
seed=None,
|
27 |
+
workers=4,
|
28 |
+
world_size=1,
|
29 |
+
batch_size=16):
|
30 |
+
|
31 |
+
if world_size > 1:
|
32 |
+
sampler = DistributedSampler(dataset)
|
33 |
+
sampler_epoch = epoch * repeat
|
34 |
+
if seed is not None:
|
35 |
+
sampler_epoch += seed * 1000
|
36 |
+
sampler.set_epoch(sampler_epoch)
|
37 |
+
batch_size //= world_size
|
38 |
+
loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler, num_workers=workers)
|
39 |
+
else:
|
40 |
+
loader = DataLoader(dataset, batch_size=batch_size, num_workers=workers, shuffle=True)
|
41 |
+
current_loss = 0
|
42 |
+
model_size = 0
|
43 |
+
for repetition in range(repeat):
|
44 |
+
tq = tqdm.tqdm(loader,
|
45 |
+
ncols=120,
|
46 |
+
desc=f"[{epoch:03d}] train ({repetition + 1}/{repeat})",
|
47 |
+
leave=False,
|
48 |
+
file=sys.stdout,
|
49 |
+
unit=" batch")
|
50 |
+
total_loss = 0
|
51 |
+
for idx, sources in enumerate(tq):
|
52 |
+
if len(sources) < batch_size:
|
53 |
+
# skip uncomplete batch for augment.Remix to work properly
|
54 |
+
continue
|
55 |
+
sources = sources.to(device)
|
56 |
+
sources = augment(sources)
|
57 |
+
mix = sources.sum(dim=1)
|
58 |
+
|
59 |
+
estimates = model(mix)
|
60 |
+
sources = center_trim(sources, estimates)
|
61 |
+
loss = criterion(estimates, sources)
|
62 |
+
model_size = 0
|
63 |
+
if quantizer is not None:
|
64 |
+
model_size = quantizer.model_size()
|
65 |
+
|
66 |
+
train_loss = loss + diffq * model_size
|
67 |
+
train_loss.backward()
|
68 |
+
grad_norm = 0
|
69 |
+
for p in model.parameters():
|
70 |
+
if p.grad is not None:
|
71 |
+
grad_norm += p.grad.data.norm()**2
|
72 |
+
grad_norm = grad_norm**0.5
|
73 |
+
optimizer.step()
|
74 |
+
optimizer.zero_grad()
|
75 |
+
|
76 |
+
if quantizer is not None:
|
77 |
+
model_size = model_size.item()
|
78 |
+
|
79 |
+
total_loss += loss.item()
|
80 |
+
current_loss = total_loss / (1 + idx)
|
81 |
+
tq.set_postfix(loss=f"{current_loss:.4f}", ms=f"{model_size:.2f}",
|
82 |
+
grad=f"{grad_norm:.5f}")
|
83 |
+
|
84 |
+
# free some space before next round
|
85 |
+
del sources, mix, estimates, loss, train_loss
|
86 |
+
|
87 |
+
if world_size > 1:
|
88 |
+
sampler.epoch += 1
|
89 |
+
|
90 |
+
if world_size > 1:
|
91 |
+
current_loss = average_metric(current_loss)
|
92 |
+
return current_loss, model_size
|
93 |
+
|
94 |
+
|
95 |
+
def validate_model(epoch,
|
96 |
+
dataset,
|
97 |
+
model,
|
98 |
+
criterion,
|
99 |
+
device="cpu",
|
100 |
+
rank=0,
|
101 |
+
world_size=1,
|
102 |
+
shifts=0,
|
103 |
+
overlap=0.25,
|
104 |
+
split=False):
|
105 |
+
indexes = range(rank, len(dataset), world_size)
|
106 |
+
tq = tqdm.tqdm(indexes,
|
107 |
+
ncols=120,
|
108 |
+
desc=f"[{epoch:03d}] valid",
|
109 |
+
leave=False,
|
110 |
+
file=sys.stdout,
|
111 |
+
unit=" track")
|
112 |
+
current_loss = 0
|
113 |
+
for index in tq:
|
114 |
+
streams = dataset[index]
|
115 |
+
# first five minutes to avoid OOM on --upsample models
|
116 |
+
streams = streams[..., :15_000_000]
|
117 |
+
streams = streams.to(device)
|
118 |
+
sources = streams[1:]
|
119 |
+
mix = streams[0]
|
120 |
+
estimates = apply_model(model, mix, shifts=shifts, split=split, overlap=overlap)
|
121 |
+
loss = criterion(estimates, sources)
|
122 |
+
current_loss += loss.item() / len(indexes)
|
123 |
+
del estimates, streams, sources
|
124 |
+
|
125 |
+
if world_size > 1:
|
126 |
+
current_loss = average_metric(current_loss, len(indexes))
|
127 |
+
return current_loss
|