init
This view is limited to 50 files because the commit contains too many changes; see the raw diff for the complete changeset.
- .gitignore +2 -0
- README2.md +256 -0
- UVR.py +0 -0
- __version__.py +4 -0
- demucs/__init__.py +5 -0
- demucs/__main__.py +272 -0
- demucs/__pycache__/__init__.cpython-310.pyc +0 -0
- demucs/__pycache__/apply.cpython-310.pyc +0 -0
- demucs/__pycache__/demucs.cpython-310.pyc +0 -0
- demucs/__pycache__/filtering.cpython-310.pyc +0 -0
- demucs/__pycache__/hdemucs.cpython-310.pyc +0 -0
- demucs/__pycache__/model.cpython-310.pyc +0 -0
- demucs/__pycache__/model_v2.cpython-310.pyc +0 -0
- demucs/__pycache__/pretrained.cpython-310.pyc +0 -0
- demucs/__pycache__/repo.cpython-310.pyc +0 -0
- demucs/__pycache__/spec.cpython-310.pyc +0 -0
- demucs/__pycache__/states.cpython-310.pyc +0 -0
- demucs/__pycache__/tasnet_v2.cpython-310.pyc +0 -0
- demucs/__pycache__/utils.cpython-310.pyc +0 -0
- demucs/apply.py +305 -0
- demucs/demucs.py +459 -0
- demucs/filtering.py +502 -0
- demucs/hdemucs.py +796 -0
- demucs/htdemucs.py +664 -0
- demucs/model.py +218 -0
- demucs/model_v2.py +218 -0
- demucs/pretrained.py +180 -0
- demucs/repo.py +148 -0
- demucs/spec.py +53 -0
- demucs/states.py +148 -0
- demucs/tasnet.py +447 -0
- demucs/tasnet_v2.py +452 -0
- demucs/transformer.py +839 -0
- demucs/utils.py +502 -0
- gui_data/__pycache__/app_size_values.cpython-310.pyc +0 -0
- gui_data/__pycache__/constants.cpython-310.pyc +0 -0
- gui_data/__pycache__/error_handling.cpython-310.pyc +0 -0
- gui_data/__pycache__/old_data_check.cpython-310.pyc +0 -0
- gui_data/app_size_values.py +371 -0
- gui_data/change_log.txt +93 -0
- gui_data/complete_chime.wav +0 -0
- gui_data/constants.py +1584 -0
- gui_data/cr_text.txt +104 -0
- gui_data/error_handling.py +110 -0
- gui_data/fail_chime.wav +0 -0
- gui_data/fonts/Montserrat/Montserrat.ttf +0 -0
- gui_data/fonts/centurygothic/GOTHIC.ttf +0 -0
- gui_data/fonts/other/own_font_goes_here.txt +1 -0
- gui_data/img/File.png +0 -0
- gui_data/img/GUI-Icon.ico +0 -0
.gitignore
ADDED
@@ -0,0 +1,2 @@
venv/
.idea/
README2.md
ADDED
@@ -0,0 +1,256 @@
# Ultimate Vocal Remover GUI v5.6
<img src="https://raw.githubusercontent.com/Anjok07/ultimatevocalremovergui/master/gui_data/img/UVR_v5.6.png?raw=true" />

[![Release](https://img.shields.io/github/release/anjok07/ultimatevocalremovergui.svg)](https://github.com/anjok07/ultimatevocalremovergui/releases/latest)
[![Downloads](https://img.shields.io/github/downloads/anjok07/ultimatevocalremovergui/total.svg)](https://github.com/anjok07/ultimatevocalremovergui/releases)

## About

This application uses state-of-the-art source separation models to remove vocals from audio files. UVR's core developers trained all of the models provided in this package (except for the Demucs v3 and v4 4-stem models).

- **Core Developers**
    - [Anjok07](https://github.com/anjok07)
    - [aufr33](https://github.com/aufr33)

- **Support the Project**
    - [Donate](https://www.buymeacoffee.com/uvr5)

## Installation

These bundles contain the UVR interface, Python, PyTorch, and other dependencies needed to run the application effectively. No prerequisites are required.

### Windows Installation

- Please Note:
    - This installer is intended for those running Windows 10 or higher.
    - Application functionality for systems running Windows 7 or lower is not guaranteed.
    - Application functionality for Intel Pentium & Celeron CPU systems is not guaranteed.
    - You must install UVR to the main C:\ drive. Installing UVR to a secondary drive will cause instability.

- Download the UVR installer for Windows via the link below:
    - [Main Download Link](https://github.com/Anjok07/ultimatevocalremovergui/releases/download/v5.6/UVR_v5.6.0_setup.exe)
    - [Main Download Link mirror](https://www.mediafire.com/file_premium/jiatpgp0ljou52p/UVR_v5.6.0_setup.exe/file)
- If you use an **AMD Radeon or Intel Arc graphics card**, you can try the OpenCL version:
    - [OpenCL Version - Main Download Link](https://github.com/Anjok07/ultimatevocalremovergui/releases/download/v5.6/UVR_v5.6.0_setup_opencl.exe)
- Update Package instructions for those who have UVR already installed:
    - If you already have UVR installed, you can install this package over it, download it straight from the application, or [click here for the patch](https://github.com/Anjok07/ultimatevocalremovergui/releases/download/v5.6/UVR_Patch_10_6_23_4_27.exe).

<details id="WindowsManual">
<summary>Windows Manual Installation</summary>

### Manual Windows Installation

- Download and extract the repository [here](https://github.com/Anjok07/ultimatevocalremovergui/archive/refs/heads/master.zip)
- Download and install Python [here](https://www.python.org/ftp/python/3.9.8/python-3.9.8-amd64.exe)
    - Make sure to check "Add python.exe to PATH" during the install
- Run the following commands from the extracted repo directory:

```
python.exe -m pip install -r requirements.txt
```

If you have a compatible Nvidia GPU, run the following command:

```
python.exe -m pip install --upgrade torch --extra-index-url https://download.pytorch.org/whl/cu117
```
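As an optional check, you can confirm from a Python prompt (started with the same `python.exe`) that the CUDA build of PyTorch was picked up:

```
import torch
print(torch.__version__)          # a "+cu117" build indicates the command above took effect
print(torch.cuda.is_available())  # True means UVR can run GPU conversions
```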
If you do not have FFmpeg or Rubber Band installed and want to avoid going through the process of installing them the long way, follow the instructions below.

**FFmpeg Installation**

- Download the precompiled build [here](https://www.gyan.dev/ffmpeg/builds/ffmpeg-release-essentials.zip)
- From the archive, extract the following file to the UVR application directory:
    - ```ffmpeg-5.1.2-essentials_build/bin/ffmpeg.exe```

**Rubber Band Installation**

In order to use the Time Stretch or Change Pitch tool, you'll need Rubber Band.

- Download the precompiled build [here](https://breakfastquay.com/files/releases/rubberband-3.1.2-gpl-executable-windows.zip)
- From the archive, extract the following files to the UVR application directory:
    - ```rubberband-3.1.2-gpl-executable-windows/rubberband.exe```
    - ```rubberband-3.1.2-gpl-executable-windows/sndfile.dll```

</details>

### MacOS Installation

- Please Note:
    - The MacOS Sonoma mouse clicking issue has been fixed.
    - MPS (GPU) acceleration for Mac M1 has been expanded to work with Demucs v4 and all MDX-Net models.
    - This bundle is intended for those running macOS Big Sur and above.
    - Application functionality for systems running macOS Catalina or lower is not guaranteed.
    - Application functionality for older or budget Mac systems is not guaranteed.
    - Once everything is installed, the application may take 5-10 minutes to start for the first time (depending on your MacBook).

- Download the UVR dmg for MacOS via one of the links below:
    - Mac M1 (arm64) users:
        - [Main Download Link](https://github.com/Anjok07/ultimatevocalremovergui/releases/download/v5.6/Ultimate_Vocal_Remover_v5_6_MacOS_arm64.dmg)
        - [Main Download Link mirror](https://www.mediafire.com/file_premium/u3rk54wsqadpy93/Ultimate_Vocal_Remover_v5_6_MacOS_arm64.dmg/file)

    - Mac Intel (x86_64) users:
        - [Main Download Link](https://github.com/Anjok07/ultimatevocalremovergui/releases/download/v5.6/Ultimate_Vocal_Remover_v5_6_MacOS_x86_64.dmg)
        - [Main Download Link mirror](https://www.mediafire.com/file_premium/2gf1werx5ly5ylz/Ultimate_Vocal_Remover_v5_6_MacOS_x86_64.dmg/file)

<details id="CannotOpen">
<summary>MacOS Users: Having Trouble Opening UVR?</summary>

> Due to Apple's strict application security, you may need to follow these steps to open UVR.
>
> First, run the following command via Terminal.app to allow applications to run from all sources (it's recommended that you re-enable this once UVR opens properly):
>
> ```bash
> sudo spctl --master-disable
> ```
>
> Second, run the following command to bypass Notarization:
>
> ```bash
> sudo xattr -rd com.apple.quarantine /Applications/Ultimate\ Vocal\ Remover.app
> ```

</details>

<details id="MacInstall">
<summary>Manual MacOS Installation</summary>

### Manual MacOS Installation

- Download and save this repository [here](https://github.com/Anjok07/ultimatevocalremovergui/archive/refs/heads/master.zip)
- Download and install Python 3.10 [here](https://www.python.org/ftp/python/3.10.9/python-3.10.9-macos11.pkg)
- From the saved directory, run the following:

```
pip3 install -r requirements.txt
```

- If your Mac is running an M1, run the following command next. If not, skip this step:

```
cp /Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/_soundfile_data/libsndfile_arm64.dylib /Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/_soundfile_data/libsndfile.dylib
```

**FFmpeg Installation**

- Once everything is done installing, download the correct FFmpeg binary for your system [here](http://www.osxexperts.net) and place it into the main application directory.

**Rubber Band Installation**

In order to use the Time Stretch or Change Pitch tool, you'll need Rubber Band.

- Download the precompiled build [here](https://breakfastquay.com/files/releases/rubberband-3.1.2-gpl-executable-windows.zip)
- From the archive, extract the following files to the UVR/lib_v5 application directory:
    - ```rubberband-3.1.2-gpl-executable-macos/rubberband```

This process has been tested on a MacBook Pro 2021 (using M1) and a MacBook Air 2017 and is confirmed to be working on both.

</details>

### Linux Installation

<details id="LinuxInstall">
<summary>See Linux Installation Instructions</summary>

<br />

**These install instructions are for Debian & Arch based Linux systems.**

- Download and save this repository [here](https://github.com/Anjok07/ultimatevocalremovergui/archive/refs/heads/master.zip)
- From the saved directory, run the following commands in this order:

**For Debian Based (Ubuntu, Mint, etc.):**
```
sudo apt update && sudo apt upgrade
sudo apt-get update
sudo apt install ffmpeg
sudo apt install python3-pip
sudo apt-get -y install python3-tk
pip3 install -r requirements.txt
python3 UVR.py
```

**For Arch Based (EndeavourOS):**
```
sudo pacman -Syu
sudo pacman -Sy
sudo pacman -S python-pip
sudo pacman -S --noconfirm tk
sudo pacman -S ffmpeg
```

To bypass environment setup and proceed with the installation, use the command below. Take caution; this modifies system files.

```
sudo rm /usr/lib/python3.11/EXTERNALLY-MANAGED
```

Then proceed with the following in order:

```
chmod +x install_packages.sh
./install_packages.sh
python UVR.py
```

</details>

### Other Application Notes
- An Nvidia GTX 1060 6GB is the minimum requirement for GPU conversions.
- Nvidia GPUs with at least 8GB of VRAM are recommended.
- AMD Radeon GPU support is limited at this time.
    - There is currently a working branch for AMD GPU users [here](https://github.com/Anjok07/ultimatevocalremovergui/tree/v5.6-amd-gpu)
- This application is only compatible with 64-bit platforms.
- This application relies on the Rubber Band library for the Time-Stretch and Pitch-Shift options.
- This application relies on FFmpeg to process non-WAV audio files.
- The application will automatically remember your settings when closed.
- Conversion times will significantly depend on your hardware.
- These models are computationally intensive.

### Performance
- Model load times are faster.
- Importing/exporting audio files is faster.

## Troubleshooting

### Common Issues

- If FFmpeg is not installed, the application will throw an error if the user attempts to convert a non-WAV file.
- Memory allocation errors can usually be resolved by lowering the "Segment" or "Window" sizes.

#### MacOS Sonoma Left-click Bug
There is a known issue on MacOS Sonoma where left-clicks weren't registering correctly within the app. This was impacting all applications built with Tkinter on Sonoma and has since been resolved. Please download the latest version via the following [link](https://github.com/Anjok07/ultimatevocalremovergui/releases/tag/v5.6) if you are still experiencing issues.

This issue was being tracked [here](https://github.com/Anjok07/ultimatevocalremovergui/issues/840).

### Issue Reporting

Please be as detailed as possible when posting a new issue.

If possible, click the "Settings Button" to the left of the "Start Processing" button and click the "Error Log" button for detailed error information that can be provided to us.

## License

The **Ultimate Vocal Remover GUI** code is [MIT-licensed](LICENSE).

- **Please Note:** For all third-party application developers who wish to use our models, please honor the MIT license by providing credit to UVR and its developers.

## Credits
- [ZFTurbo](https://github.com/ZFTurbo) - Created & trained the weights for the new MDX23C models.
- [DilanBoskan](https://github.com/DilanBoskan) - Your contributions at the start of this project were essential to the success of UVR. Thank you!
- [Bas Curtiz](https://www.youtube.com/user/bascurtiz) - Designed the official UVR logo, icon, banner, and splash screen.
- [tsurumeso](https://github.com/tsurumeso) - Developed the original VR Architecture code.
- [Kuielab & Woosung Choi](https://github.com/kuielab) - Developed the original MDX-Net AI code.
- [Adefossez & Demucs](https://github.com/facebookresearch/demucs) - Developed the original Demucs AI code.
- [KimberleyJSN](https://github.com/KimberleyJensen) - Advised and aided the implementation of the training scripts for MDX-Net and Demucs. Thank you!
- [Hv](https://github.com/NaJeongMo/Colab-for-MDX_B) - Helped implement chunks into the MDX-Net AI code. Thank you!

## Contributing

- For anyone interested in the ongoing development of **Ultimate Vocal Remover GUI**, please send us a pull request, and we will review it.
- This project is 100% open-source and free for anyone to use and modify as they wish.
- We only maintain the development and support for the **Ultimate Vocal Remover GUI** and the models provided.

## References
- [1] Takahashi et al., "Multi-scale Multi-band DenseNets for Audio Source Separation", https://arxiv.org/pdf/1706.09588.pdf
UVR.py
ADDED
The diff for this file is too large to render; see the raw diff.
__version__.py
ADDED
@@ -0,0 +1,4 @@
VERSION = 'v5.6.0'
PATCH = 'UVR_Patch_9_29_23_1_39'
PATCH_MAC = 'UVR_Patch_9_29_23_1_39'
PATCH_LINUX = 'UVR_Patch_9_29_23_1_39'
demucs/__init__.py
ADDED
@@ -0,0 +1,5 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
demucs/__main__.py
ADDED
@@ -0,0 +1,272 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import json
import os
import sys
import time
from dataclasses import dataclass, field
from fractions import Fraction

import torch as th
from torch import distributed, nn
from torch.nn.parallel.distributed import DistributedDataParallel

from .augment import FlipChannels, FlipSign, Remix, Shift
from .compressed import StemsSet, build_musdb_metadata, get_musdb_tracks
from .model import Demucs
from .parser import get_name, get_parser
from .raw import Rawset
from .tasnet import ConvTasNet
from .test import evaluate
from .train import train_model, validate_model
from .utils import human_seconds, load_model, save_model, sizeof_fmt


@dataclass
class SavedState:
    metrics: list = field(default_factory=list)
    last_state: dict = None
    best_state: dict = None
    optimizer: dict = None


def main():
    parser = get_parser()
    args = parser.parse_args()
    name = get_name(parser, args)
    print(f"Experiment {name}")

    if args.musdb is None and args.rank == 0:
        print(
            "You must provide the path to the MusDB dataset with the --musdb flag. "
            "To download the MusDB dataset, see https://sigsep.github.io/datasets/musdb.html.",
            file=sys.stderr)
        sys.exit(1)

    eval_folder = args.evals / name
    eval_folder.mkdir(exist_ok=True, parents=True)
    args.logs.mkdir(exist_ok=True)
    metrics_path = args.logs / f"{name}.json"
    eval_folder.mkdir(exist_ok=True, parents=True)
    args.checkpoints.mkdir(exist_ok=True, parents=True)
    args.models.mkdir(exist_ok=True, parents=True)

    if args.device is None:
        device = "cpu"
        if th.cuda.is_available():
            device = "cuda"
    else:
        device = args.device

    th.manual_seed(args.seed)
    # Prevents too many threads from being started when running `museval` as it can be quite
    # inefficient on NUMA architectures.
    os.environ["OMP_NUM_THREADS"] = "1"

    if args.world_size > 1:
        if device != "cuda" and args.rank == 0:
            print("Error: distributed training is only available with cuda device", file=sys.stderr)
            sys.exit(1)
        th.cuda.set_device(args.rank % th.cuda.device_count())
        distributed.init_process_group(backend="nccl",
                                       init_method="tcp://" + args.master,
                                       rank=args.rank,
                                       world_size=args.world_size)

    checkpoint = args.checkpoints / f"{name}.th"
    checkpoint_tmp = args.checkpoints / f"{name}.th.tmp"
    if args.restart and checkpoint.exists():
        checkpoint.unlink()

    if args.test:
        args.epochs = 1
        args.repeat = 0
        model = load_model(args.models / args.test)
    elif args.tasnet:
        model = ConvTasNet(audio_channels=args.audio_channels, samplerate=args.samplerate, X=args.X)
    else:
        model = Demucs(
            audio_channels=args.audio_channels,
            channels=args.channels,
            context=args.context,
            depth=args.depth,
            glu=args.glu,
            growth=args.growth,
            kernel_size=args.kernel_size,
            lstm_layers=args.lstm_layers,
            rescale=args.rescale,
            rewrite=args.rewrite,
            sources=4,
            stride=args.conv_stride,
            upsample=args.upsample,
            samplerate=args.samplerate
        )
    model.to(device)
    if args.show:
        print(model)
        size = sizeof_fmt(4 * sum(p.numel() for p in model.parameters()))
        print(f"Model size {size}")
        return

    optimizer = th.optim.Adam(model.parameters(), lr=args.lr)

    try:
        saved = th.load(checkpoint, map_location='cpu')
    except IOError:
        saved = SavedState()
    else:
        model.load_state_dict(saved.last_state)
        optimizer.load_state_dict(saved.optimizer)

    if args.save_model:
        if args.rank == 0:
            model.to("cpu")
            model.load_state_dict(saved.best_state)
            save_model(model, args.models / f"{name}.th")
        return

    if args.rank == 0:
        done = args.logs / f"{name}.done"
        if done.exists():
            done.unlink()

    if args.augment:
        augment = nn.Sequential(FlipSign(), FlipChannels(), Shift(args.data_stride),
                                Remix(group_size=args.remix_group_size)).to(device)
    else:
        augment = Shift(args.data_stride)

    if args.mse:
        criterion = nn.MSELoss()
    else:
        criterion = nn.L1Loss()

    # Setting number of samples so that all convolution windows are full.
    # Prevents hard to debug mistakes with the prediction being shifted compared
    # to the input mixture.
    samples = model.valid_length(args.samples)
    print(f"Number of training samples adjusted to {samples}")

    if args.raw:
        train_set = Rawset(args.raw / "train",
                           samples=samples + args.data_stride,
                           channels=args.audio_channels,
                           streams=[0, 1, 2, 3, 4],
                           stride=args.data_stride)

        valid_set = Rawset(args.raw / "valid", channels=args.audio_channels)
    else:
        if not args.metadata.is_file() and args.rank == 0:
            build_musdb_metadata(args.metadata, args.musdb, args.workers)
        if args.world_size > 1:
            distributed.barrier()
        metadata = json.load(open(args.metadata))
        duration = Fraction(samples + args.data_stride, args.samplerate)
        stride = Fraction(args.data_stride, args.samplerate)
        train_set = StemsSet(get_musdb_tracks(args.musdb, subsets=["train"], split="train"),
                             metadata,
                             duration=duration,
                             stride=stride,
                             samplerate=args.samplerate,
                             channels=args.audio_channels)
        valid_set = StemsSet(get_musdb_tracks(args.musdb, subsets=["train"], split="valid"),
                             metadata,
                             samplerate=args.samplerate,
                             channels=args.audio_channels)

    best_loss = float("inf")
    for epoch, metrics in enumerate(saved.metrics):
        print(f"Epoch {epoch:03d}: "
              f"train={metrics['train']:.8f} "
              f"valid={metrics['valid']:.8f} "
              f"best={metrics['best']:.4f} "
              f"duration={human_seconds(metrics['duration'])}")
        best_loss = metrics['best']

    if args.world_size > 1:
        dmodel = DistributedDataParallel(model,
                                         device_ids=[th.cuda.current_device()],
                                         output_device=th.cuda.current_device())
    else:
        dmodel = model

    for epoch in range(len(saved.metrics), args.epochs):
        begin = time.time()
        model.train()
        train_loss = train_model(epoch,
                                 train_set,
                                 dmodel,
                                 criterion,
                                 optimizer,
                                 augment,
                                 batch_size=args.batch_size,
                                 device=device,
                                 repeat=args.repeat,
                                 seed=args.seed,
                                 workers=args.workers,
                                 world_size=args.world_size)
        model.eval()
        valid_loss = validate_model(epoch,
                                    valid_set,
                                    model,
                                    criterion,
                                    device=device,
                                    rank=args.rank,
                                    split=args.split_valid,
                                    world_size=args.world_size)

        duration = time.time() - begin
        if valid_loss < best_loss:
            best_loss = valid_loss
            saved.best_state = {
                key: value.to("cpu").clone()
                for key, value in model.state_dict().items()
            }
        saved.metrics.append({
            "train": train_loss,
            "valid": valid_loss,
            "best": best_loss,
            "duration": duration
        })
        if args.rank == 0:
            json.dump(saved.metrics, open(metrics_path, "w"))

        saved.last_state = model.state_dict()
        saved.optimizer = optimizer.state_dict()
        if args.rank == 0 and not args.test:
            th.save(saved, checkpoint_tmp)
            checkpoint_tmp.rename(checkpoint)

        print(f"Epoch {epoch:03d}: "
              f"train={train_loss:.8f} valid={valid_loss:.8f} best={best_loss:.4f} "
              f"duration={human_seconds(duration)}")

    del dmodel
    model.load_state_dict(saved.best_state)
    if args.eval_cpu:
        device = "cpu"
        model.to(device)
    model.eval()
    evaluate(model,
             args.musdb,
             eval_folder,
             rank=args.rank,
             world_size=args.world_size,
             device=device,
             save=args.save,
             split=args.split_valid,
             shifts=args.shifts,
             workers=args.eval_workers)
    model.to("cpu")
    save_model(model, args.models / f"{name}.th")
    if args.rank == 0:
        print("done")
        done.write_text("done")


if __name__ == "__main__":
    main()
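The training loop above persists progress with an atomic checkpoint pattern: state is written to a temporary file and then renamed over the real checkpoint, and a missing checkpoint on startup simply falls back to a fresh `SavedState`. A minimal, self-contained sketch of that pattern follows; the `model` and `optimizer` objects are stand-ins, not the full argument plumbing of `__main__.py`:

```python
from dataclasses import dataclass, field
from pathlib import Path

import torch as th
from torch import nn


@dataclass
class SavedState:
    metrics: list = field(default_factory=list)
    last_state: dict = None
    best_state: dict = None
    optimizer: dict = None


def load_or_init(checkpoint: Path, model: nn.Module, optimizer) -> SavedState:
    # Missing/unreadable checkpoint -> start fresh, mirroring the try/except in main().
    try:
        saved = th.load(checkpoint, map_location="cpu")
    except IOError:
        return SavedState()
    model.load_state_dict(saved.last_state)
    optimizer.load_state_dict(saved.optimizer)
    return saved


def save_atomic(saved: SavedState, checkpoint: Path):
    # Write to a temporary file first, then rename, so a crash mid-save
    # never leaves a truncated checkpoint behind.
    tmp = checkpoint.parent / (checkpoint.name + ".tmp")
    th.save(saved, tmp)
    tmp.rename(checkpoint)


if __name__ == "__main__":
    model = nn.Linear(4, 4)  # stand-in for Demucs
    optimizer = th.optim.Adam(model.parameters(), lr=3e-4)
    ckpt = Path("demo.th")
    saved = load_or_init(ckpt, model, optimizer)
    saved.last_state = model.state_dict()
    saved.optimizer = optimizer.state_dict()
    save_atomic(saved, ckpt)
```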
demucs/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (166 Bytes).

demucs/__pycache__/apply.cpython-310.pyc
ADDED
Binary file (8.15 kB).

demucs/__pycache__/demucs.cpython-310.pyc
ADDED
Binary file (14.1 kB).

demucs/__pycache__/filtering.cpython-310.pyc
ADDED
Binary file (17 kB).

demucs/__pycache__/hdemucs.cpython-310.pyc
ADDED
Binary file (21.3 kB).

demucs/__pycache__/model.cpython-310.pyc
ADDED
Binary file (6.25 kB).

demucs/__pycache__/model_v2.cpython-310.pyc
ADDED
Binary file (6.33 kB).

demucs/__pycache__/pretrained.cpython-310.pyc
ADDED
Binary file (5.07 kB).

demucs/__pycache__/repo.cpython-310.pyc
ADDED
Binary file (5.97 kB).

demucs/__pycache__/spec.cpython-310.pyc
ADDED
Binary file (1.26 kB).

demucs/__pycache__/states.cpython-310.pyc
ADDED
Binary file (4.47 kB).

demucs/__pycache__/tasnet_v2.cpython-310.pyc
ADDED
Binary file (12 kB).

demucs/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (13.9 kB).
demucs/apply.py
ADDED
@@ -0,0 +1,305 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Code to apply a model to a mix. It will handle chunking with overlaps and
interpolation between chunks, as well as the "shift trick".
"""
from concurrent.futures import ThreadPoolExecutor
import random
import typing as tp
from multiprocessing import Process, Queue, Pipe

import torch as th
from torch import nn
from torch.nn import functional as F
import tqdm
import tkinter as tk

from .demucs import Demucs
from .hdemucs import HDemucs
from .utils import center_trim, DummyPoolExecutor

Model = tp.Union[Demucs, HDemucs]

progress_bar_num = 0


class BagOfModels(nn.Module):
    def __init__(self, models: tp.List[Model],
                 weights: tp.Optional[tp.List[tp.List[float]]] = None,
                 segment: tp.Optional[float] = None):
        """
        Represents a bag of models with specific weights.
        You should call `apply_model` rather than calling directly the forward here for
        optimal performance.

        Args:
            models (list[nn.Module]): list of Demucs/HDemucs models.
            weights (list[list[float]]): list of weights. If None, assumed to
                be all ones, otherwise it should be a list of N lists (N number of models),
                each containing S floats (S number of sources).
            segment (None or float): overrides the `segment` attribute of each model
                (this is performed inplace, be careful if you reuse the models passed).
        """
        super().__init__()
        assert len(models) > 0
        first = models[0]
        for other in models:
            assert other.sources == first.sources
            assert other.samplerate == first.samplerate
            assert other.audio_channels == first.audio_channels
            if segment is not None:
                other.segment = segment

        self.audio_channels = first.audio_channels
        self.samplerate = first.samplerate
        self.sources = first.sources
        self.models = nn.ModuleList(models)

        if weights is None:
            weights = [[1. for _ in first.sources] for _ in models]
        else:
            assert len(weights) == len(models)
            for weight in weights:
                assert len(weight) == len(first.sources)
        self.weights = weights

    def forward(self, x):
        raise NotImplementedError("Call `apply_model` on this.")


class TensorChunk:
    def __init__(self, tensor, offset=0, length=None):
        total_length = tensor.shape[-1]
        assert offset >= 0
        assert offset < total_length

        if length is None:
            length = total_length - offset
        else:
            length = min(total_length - offset, length)

        if isinstance(tensor, TensorChunk):
            self.tensor = tensor.tensor
            self.offset = offset + tensor.offset
        else:
            self.tensor = tensor
            self.offset = offset
        self.length = length
        self.device = tensor.device

    @property
    def shape(self):
        shape = list(self.tensor.shape)
        shape[-1] = self.length
        return shape

    def padded(self, target_length):
        delta = target_length - self.length
        total_length = self.tensor.shape[-1]
        assert delta >= 0

        start = self.offset - delta // 2
        end = start + target_length

        correct_start = max(0, start)
        correct_end = min(total_length, end)

        pad_left = correct_start - start
        pad_right = end - correct_end

        out = F.pad(self.tensor[..., correct_start:correct_end], (pad_left, pad_right))
        assert out.shape[-1] == target_length
        return out


def tensor_chunk(tensor_or_chunk):
    if isinstance(tensor_or_chunk, TensorChunk):
        return tensor_or_chunk
    else:
        assert isinstance(tensor_or_chunk, th.Tensor)
        return TensorChunk(tensor_or_chunk)


def apply_model(model,
                mix,
                shifts=1,
                split=True,
                overlap=0.25,
                transition_power=1.,
                static_shifts=1,
                set_progress_bar=None,
                device=None,
                progress=False,
                num_workers=0,
                pool=None):
    """
    Apply model to a given mixture.

    Args:
        shifts (int): if > 0, will shift in time `mix` by a random amount between 0 and 0.5 sec
            and apply the opposite shift to the output. This is repeated `shifts` times and
            all predictions are averaged. This effectively makes the model time equivariant
            and improves SDR by up to 0.2 points.
        split (bool): if True, the input will be broken down into 8-second extracts
            and predictions will be performed individually on each and concatenated.
            Useful for models with a large memory footprint like Tasnet.
        progress (bool): if True, show a progress bar (requires split=True)
        device (torch.device, str, or None): if provided, device on which to
            execute the computation, otherwise `mix.device` is assumed.
            When `device` is different from `mix.device`, only local computations will
            be on `device`, while the entire tracks will be stored on `mix.device`.
    """

    global fut_length
    global bag_num
    global prog_bar

    if device is None:
        device = mix.device
    else:
        device = th.device(device)
    if pool is None:
        if num_workers > 0 and device.type == 'cpu':
            pool = ThreadPoolExecutor(num_workers)
        else:
            pool = DummyPoolExecutor()

    kwargs = {
        'shifts': shifts,
        'split': split,
        'overlap': overlap,
        'transition_power': transition_power,
        'progress': progress,
        'device': device,
        'pool': pool,
        'set_progress_bar': set_progress_bar,
        'static_shifts': static_shifts,
    }

    if isinstance(model, BagOfModels):
        # Special treatment for bag of models.
        # We explicitly apply `apply_model` multiple times so that the random shifts
        # are different for each model.

        estimates = 0
        totals = [0] * len(model.sources)
        bag_num = len(model.models)
        fut_length = 0
        prog_bar = 0
        current_model = 0  # (bag_num + 1)
        for sub_model, weight in zip(model.models, model.weights):
            original_model_device = next(iter(sub_model.parameters())).device
            sub_model.to(device)
            fut_length += fut_length
            current_model += 1
            out = apply_model(sub_model, mix, **kwargs)
            sub_model.to(original_model_device)
            for k, inst_weight in enumerate(weight):
                out[:, k, :, :] *= inst_weight
                totals[k] += inst_weight
            estimates += out
            del out

        for k in range(estimates.shape[1]):
            estimates[:, k, :, :] /= totals[k]
        return estimates

    model.to(device)
    model.eval()
    assert transition_power >= 1, "transition_power < 1 leads to weird behavior."
    batch, channels, length = mix.shape

    if shifts:
        kwargs['shifts'] = 0
        max_shift = int(0.5 * model.samplerate)
        mix = tensor_chunk(mix)
        padded_mix = mix.padded(length + 2 * max_shift)
        out = 0
        for _ in range(shifts):
            offset = random.randint(0, max_shift)
            shifted = TensorChunk(padded_mix, offset, length + max_shift - offset)
            shifted_out = apply_model(model, shifted, **kwargs)
            out += shifted_out[..., max_shift - offset:]
        out /= shifts
        return out
    elif split:
        kwargs['split'] = False
        out = th.zeros(batch, len(model.sources), channels, length, device=mix.device)
        sum_weight = th.zeros(length, device=mix.device)
        segment = int(model.samplerate * model.segment)
        stride = int((1 - overlap) * segment)
        offsets = range(0, length, stride)
        scale = float(format(stride / model.samplerate, ".2f"))
        # We start from a triangle shaped weight, with maximal weight in the middle
        # of the segment. Then we normalize and take to the power `transition_power`.
        # Large values of transition power will lead to sharper transitions.
        weight = th.cat([th.arange(1, segment // 2 + 1, device=device),
                         th.arange(segment - segment // 2, 0, -1, device=device)])
        assert len(weight) == segment
        # If the overlap < 50%, this will translate to linear transition when
        # transition_power is 1.
        weight = (weight / weight.max())**transition_power
        futures = []
        for offset in offsets:
            chunk = TensorChunk(mix, offset, segment)
            future = pool.submit(apply_model, model, chunk, **kwargs)
            futures.append((future, offset))
            offset += segment
        if progress:
            futures = tqdm.tqdm(futures, unit_scale=scale, ncols=120, unit='seconds')
        for future, offset in futures:
            if set_progress_bar:
                fut_length = (len(futures) * bag_num * static_shifts)
                prog_bar += 1
                set_progress_bar(0.1, (0.8/fut_length*prog_bar))
            chunk_out = future.result()
            chunk_length = chunk_out.shape[-1]
            out[..., offset:offset + segment] += (weight[:chunk_length] * chunk_out).to(mix.device)
            sum_weight[offset:offset + segment] += weight[:chunk_length].to(mix.device)
        assert sum_weight.min() > 0
        out /= sum_weight
        return out
    else:
        if hasattr(model, 'valid_length'):
            valid_length = model.valid_length(length)
        else:
            valid_length = length
        mix = tensor_chunk(mix)
        padded_mix = mix.padded(valid_length).to(device)
        with th.no_grad():
            out = model(padded_mix)
        return center_trim(out, length)


def demucs_segments(demucs_segment, demucs_model):

    if demucs_segment == 'Default':
        segment = None
        if isinstance(demucs_model, BagOfModels):
            if segment is not None:
                for sub in demucs_model.models:
                    sub.segment = segment
        else:
            if segment is not None:
                sub.segment = segment
    else:
        try:
            segment = int(demucs_segment)
            if isinstance(demucs_model, BagOfModels):
                if segment is not None:
                    for sub in demucs_model.models:
                        sub.segment = segment
            else:
                if segment is not None:
                    sub.segment = segment
        except:
            segment = None
            if isinstance(demucs_model, BagOfModels):
                if segment is not None:
                    for sub in demucs_model.models:
                        sub.segment = segment
            else:
                if segment is not None:
                    sub.segment = segment

    return demucs_model
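When `split=True`, `apply_model` blends the per-chunk predictions with a triangular weight raised to `transition_power` and then divides by the accumulated weight, so overlapping regions cross-fade instead of clicking at the seams. A small, self-contained sketch of that overlap-add scheme on a dummy signal; the identity function stands in for the separation model, and the segment/overlap values are illustrative rather than UVR defaults:

```python
import torch as th


def overlap_add(mix: th.Tensor, segment: int, overlap: float = 0.25,
                transition_power: float = 1.0) -> th.Tensor:
    """Blend per-chunk "predictions" the same way apply_model(split=True) does,
    with the identity standing in for the model."""
    length = mix.shape[-1]
    out = th.zeros_like(mix)
    sum_weight = th.zeros(length)
    stride = int((1 - overlap) * segment)

    # Triangle-shaped weight, maximal in the middle of the segment,
    # normalized and raised to `transition_power` (sharper transitions when > 1).
    weight = th.cat([th.arange(1, segment // 2 + 1),
                     th.arange(segment - segment // 2, 0, -1)]).float()
    weight = (weight / weight.max()) ** transition_power

    for offset in range(0, length, stride):
        chunk = mix[..., offset:offset + segment]   # "model output" for this chunk
        n = chunk.shape[-1]
        out[..., offset:offset + n] += weight[:n] * chunk
        sum_weight[offset:offset + n] += weight[:n]
    return out / sum_weight


if __name__ == "__main__":
    x = th.randn(2, 44100)               # 1 second of stereo noise at 44.1 kHz
    y = overlap_add(x, segment=4096)
    print(th.allclose(x, y, atol=1e-5))  # identity "model" => reconstruction is exact
```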
demucs/demucs.py
ADDED
@@ -0,0 +1,459 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import math
import typing as tp

import julius
import torch
from torch import nn
from torch.nn import functional as F

from .states import capture_init
from .utils import center_trim, unfold


class BLSTM(nn.Module):
    """
    BiLSTM with same hidden units as input dim.
    If `max_steps` is not None, input will be split into overlapping
    chunks and the LSTM applied separately on each chunk.
    """
    def __init__(self, dim, layers=1, max_steps=None, skip=False):
        super().__init__()
        assert max_steps is None or max_steps % 4 == 0
        self.max_steps = max_steps
        self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim)
        self.linear = nn.Linear(2 * dim, dim)
        self.skip = skip

    def forward(self, x):
        B, C, T = x.shape
        y = x
        framed = False
        if self.max_steps is not None and T > self.max_steps:
            width = self.max_steps
            stride = width // 2
            frames = unfold(x, width, stride)
            nframes = frames.shape[2]
            framed = True
            x = frames.permute(0, 2, 1, 3).reshape(-1, C, width)

        x = x.permute(2, 0, 1)

        x = self.lstm(x)[0]
        x = self.linear(x)
        x = x.permute(1, 2, 0)
        if framed:
            out = []
            frames = x.reshape(B, -1, C, width)
            limit = stride // 2
            for k in range(nframes):
                if k == 0:
                    out.append(frames[:, k, :, :-limit])
                elif k == nframes - 1:
                    out.append(frames[:, k, :, limit:])
                else:
                    out.append(frames[:, k, :, limit:-limit])
            out = torch.cat(out, -1)
            out = out[..., :T]
            x = out
        if self.skip:
            x = x + y
        return x


def rescale_conv(conv, reference):
    """Rescale initial weight scale. It is unclear why it helps but it certainly does.
    """
    std = conv.weight.std().detach()
    scale = (std / reference)**0.5
    conv.weight.data /= scale
    if conv.bias is not None:
        conv.bias.data /= scale


def rescale_module(module, reference):
    for sub in module.modules():
        if isinstance(sub, (nn.Conv1d, nn.ConvTranspose1d, nn.Conv2d, nn.ConvTranspose2d)):
            rescale_conv(sub, reference)


class LayerScale(nn.Module):
    """Layer scale from [Touvron et al 2021] (https://arxiv.org/pdf/2103.17239.pdf).
    This diagonally rescales residual outputs close to 0 initially, then learnt.
    """
    def __init__(self, channels: int, init: float = 0):
        super().__init__()
        self.scale = nn.Parameter(torch.zeros(channels, requires_grad=True))
        self.scale.data[:] = init

    def forward(self, x):
        return self.scale[:, None] * x


class DConv(nn.Module):
    """
    New residual branches in each encoder layer.
    This alternates dilated convolutions, potentially with LSTMs and attention.
    Also before entering each residual branch, dimension is projected on a smaller subspace,
    e.g. of dim `channels // compress`.
    """
    def __init__(self, channels: int, compress: float = 4, depth: int = 2, init: float = 1e-4,
                 norm=True, attn=False, heads=4, ndecay=4, lstm=False, gelu=True,
                 kernel=3, dilate=True):
        """
        Args:
            channels: input/output channels for residual branch.
            compress: amount of channel compression inside the branch.
            depth: number of layers in the residual branch. Each layer has its own
                projection, and potentially LSTM and attention.
            init: initial scale for LayerScale.
            norm: use GroupNorm.
            attn: use LocalAttention.
            heads: number of heads for the LocalAttention.
            ndecay: number of decay controls in the LocalAttention.
            lstm: use LSTM.
            gelu: Use GELU activation.
            kernel: kernel size for the (dilated) convolutions.
            dilate: if true, use dilation, increasing with the depth.
        """
        super().__init__()
        assert kernel % 2 == 1
        self.channels = channels
        self.compress = compress
        self.depth = abs(depth)
        dilate = depth > 0

        norm_fn: tp.Callable[[int], nn.Module]
        norm_fn = lambda d: nn.Identity()  # noqa
        if norm:
            norm_fn = lambda d: nn.GroupNorm(1, d)  # noqa

        hidden = int(channels / compress)

        act: tp.Type[nn.Module]
        if gelu:
            act = nn.GELU
        else:
            act = nn.ReLU

        self.layers = nn.ModuleList([])
        for d in range(self.depth):
            dilation = 2 ** d if dilate else 1
            padding = dilation * (kernel // 2)
            mods = [
                nn.Conv1d(channels, hidden, kernel, dilation=dilation, padding=padding),
                norm_fn(hidden), act(),
                nn.Conv1d(hidden, 2 * channels, 1),
                norm_fn(2 * channels), nn.GLU(1),
                LayerScale(channels, init),
            ]
            if attn:
                mods.insert(3, LocalState(hidden, heads=heads, ndecay=ndecay))
            if lstm:
                mods.insert(3, BLSTM(hidden, layers=2, max_steps=200, skip=True))
            layer = nn.Sequential(*mods)
            self.layers.append(layer)

    def forward(self, x):
        for layer in self.layers:
            x = x + layer(x)
        return x


class LocalState(nn.Module):
    """Local state allows to have attention based only on data (no positional embedding),
    but while setting a constraint on the time window (e.g. decaying penalty term).

    Also a failed experiment with trying to provide some frequency based attention.
    """
    def __init__(self, channels: int, heads: int = 4, nfreqs: int = 0, ndecay: int = 4):
        super().__init__()
        assert channels % heads == 0, (channels, heads)
        self.heads = heads
        self.nfreqs = nfreqs
        self.ndecay = ndecay
        self.content = nn.Conv1d(channels, channels, 1)
        self.query = nn.Conv1d(channels, channels, 1)
        self.key = nn.Conv1d(channels, channels, 1)
        if nfreqs:
            self.query_freqs = nn.Conv1d(channels, heads * nfreqs, 1)
        if ndecay:
            self.query_decay = nn.Conv1d(channels, heads * ndecay, 1)
            # Initialize decay close to zero (there is a sigmoid), for maximum initial window.
            self.query_decay.weight.data *= 0.01
            assert self.query_decay.bias is not None  # stupid type checker
            self.query_decay.bias.data[:] = -2
        self.proj = nn.Conv1d(channels + heads * nfreqs, channels, 1)

    def forward(self, x):
        B, C, T = x.shape
        heads = self.heads
        indexes = torch.arange(T, device=x.device, dtype=x.dtype)
        # left index are keys, right index are queries
        delta = indexes[:, None] - indexes[None, :]

        queries = self.query(x).view(B, heads, -1, T)
        keys = self.key(x).view(B, heads, -1, T)
        # t are keys, s are queries
        dots = torch.einsum("bhct,bhcs->bhts", keys, queries)
        dots /= keys.shape[2]**0.5
        if self.nfreqs:
            periods = torch.arange(1, self.nfreqs + 1, device=x.device, dtype=x.dtype)
            freq_kernel = torch.cos(2 * math.pi * delta / periods.view(-1, 1, 1))
            freq_q = self.query_freqs(x).view(B, heads, -1, T) / self.nfreqs ** 0.5
            dots += torch.einsum("fts,bhfs->bhts", freq_kernel, freq_q)
        if self.ndecay:
            decays = torch.arange(1, self.ndecay + 1, device=x.device, dtype=x.dtype)
            decay_q = self.query_decay(x).view(B, heads, -1, T)
            decay_q = torch.sigmoid(decay_q) / 2
            decay_kernel = - decays.view(-1, 1, 1) * delta.abs() / self.ndecay**0.5
            dots += torch.einsum("fts,bhfs->bhts", decay_kernel, decay_q)

        # Kill self reference.
        dots.masked_fill_(torch.eye(T, device=dots.device, dtype=torch.bool), -100)
        weights = torch.softmax(dots, dim=2)

        content = self.content(x).view(B, heads, -1, T)
        result = torch.einsum("bhts,bhct->bhcs", weights, content)
        if self.nfreqs:
            time_sig = torch.einsum("bhts,fts->bhfs", weights, freq_kernel)
            result = torch.cat([result, time_sig], 2)
        result = result.reshape(B, -1, T)
        return x + self.proj(result)


class Demucs(nn.Module):
    @capture_init
    def __init__(self,
                 sources,
                 # Channels
                 audio_channels=2,
                 channels=64,
                 growth=2.,
                 # Main structure
                 depth=6,
                 rewrite=True,
                 lstm_layers=0,
                 # Convolutions
                 kernel_size=8,
                 stride=4,
                 context=1,
                 # Activations
                 gelu=True,
                 glu=True,
                 # Normalization
                 norm_starts=4,
                 norm_groups=4,
                 # DConv residual branch
                 dconv_mode=1,
                 dconv_depth=2,
                 dconv_comp=4,
                 dconv_attn=4,
                 dconv_lstm=4,
                 dconv_init=1e-4,
                 # Pre/post processing
                 normalize=True,
                 resample=True,
                 # Weight init
                 rescale=0.1,
                 # Metadata
                 samplerate=44100,
                 segment=4 * 10):
        """
        Args:
            sources (list[str]): list of source names
            audio_channels (int): stereo or mono
            channels (int): first convolution channels
            depth (int): number of encoder/decoder layers
            growth (float): multiply (resp divide) number of channels by that
                for each layer of the encoder (resp decoder)
            rewrite (bool): add 1x1 convolution to each layer.
            lstm_layers (int): number of lstm layers, 0 = no lstm. Deactivated
                by default, as this is now replaced by the smaller and faster small LSTMs
                in the DConv branches.
            kernel_size (int): kernel size for convolutions
            stride (int): stride for convolutions
            context (int): kernel size of the convolution in the
                decoder before the transposed convolution. If > 1,
                will provide some context from neighboring time steps.
            gelu: use GELU activation function.
            glu (bool): use glu instead of ReLU for the 1x1 rewrite conv.
            norm_starts: layer at which group norm starts being used.
                decoder layers are numbered in reverse order.
            norm_groups: number of groups for group norm.
            dconv_mode: if 1: dconv in encoder only, 2: decoder only, 3: both.
            dconv_depth: depth of residual DConv branch.
            dconv_comp: compression of DConv branch.
            dconv_attn: adds attention layers in DConv branch starting at this layer.
            dconv_lstm: adds a LSTM layer in DConv branch starting at this layer.
            dconv_init: initial scale for the DConv branch LayerScale.
            normalize (bool): normalizes the input audio on the fly, and scales back
                the output by the same amount.
            resample (bool): upsample x2 the input and downsample /2 the output.
            rescale (float): rescale initial weights of convolutions
                to get their standard deviation closer to `rescale`.
            samplerate (int): stored as meta information for easing
                future evaluations of the model.
            segment (float): duration of the chunks of audio to ideally evaluate the model on.
                This is used by `demucs.apply.apply_model`.
        """
        super().__init__()
        self.audio_channels = audio_channels
        self.sources = sources
        self.kernel_size = kernel_size
        self.context = context
        self.stride = stride
        self.depth = depth
        self.resample = resample
        self.channels = channels
        self.normalize = normalize
        self.samplerate = samplerate
        self.segment = segment
        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()
        self.skip_scales = nn.ModuleList()

        if glu:
            activation = nn.GLU(dim=1)
            ch_scale = 2
        else:
            activation = nn.ReLU()
            ch_scale = 1
        if gelu:
            act2 = nn.GELU
        else:
            act2 = nn.ReLU

        in_channels = audio_channels
        padding = 0
        for index in range(depth):
            norm_fn = lambda d: nn.Identity()  # noqa
            if index >= norm_starts:
                norm_fn = lambda d: nn.GroupNorm(norm_groups, d)  # noqa

            encode = []
            encode += [
                nn.Conv1d(in_channels, channels, kernel_size, stride),
                norm_fn(channels),
                act2(),
            ]
            attn = index >= dconv_attn
            lstm = index >= dconv_lstm
            if dconv_mode & 1:
                encode += [DConv(channels, depth=dconv_depth, init=dconv_init,
                                 compress=dconv_comp, attn=attn, lstm=lstm)]
            if rewrite:
                encode += [
                    nn.Conv1d(channels, ch_scale * channels, 1),
                    norm_fn(ch_scale * channels), activation]
            self.encoder.append(nn.Sequential(*encode))

            decode = []
            if index > 0:
                out_channels = in_channels
            else:
                out_channels = len(self.sources) * audio_channels
            if rewrite:
                decode += [
                    nn.Conv1d(channels, ch_scale * channels, 2 * context + 1, padding=context),
                    norm_fn(ch_scale * channels), activation]
            if dconv_mode & 2:
                decode += [DConv(channels, depth=dconv_depth, init=dconv_init,
                                 compress=dconv_comp, attn=attn, lstm=lstm)]
            decode += [nn.ConvTranspose1d(channels, out_channels,
                                          kernel_size, stride, padding=padding)]
            if index > 0:
                decode += [norm_fn(out_channels), act2()]
            self.decoder.insert(0, nn.Sequential(*decode))
            in_channels = channels
            channels = int(growth * channels)

        channels = in_channels
        if lstm_layers:
            self.lstm = BLSTM(channels, lstm_layers)
        else:
            self.lstm = None

        if rescale:
            rescale_module(self, reference=rescale)

    def valid_length(self, length):
        """
        Return the nearest valid length to use with the model so that
        there are no time steps left over in a convolution, e.g. for all
        layers, size of the input - kernel_size % stride = 0.

        Note that inputs are automatically padded if necessary to ensure that the output
        has the same length as the input.
        """
        if self.resample:
            length *= 2

        for _ in range(self.depth):
            length = math.ceil((length - self.kernel_size) / self.stride) + 1
            length = max(1, length)
|
403 |
+
|
404 |
+
for idx in range(self.depth):
|
405 |
+
length = (length - 1) * self.stride + self.kernel_size
|
406 |
+
|
407 |
+
if self.resample:
|
408 |
+
length = math.ceil(length / 2)
|
409 |
+
return int(length)
|
410 |
+
|
411 |
+
def forward(self, mix):
|
412 |
+
x = mix
|
413 |
+
length = x.shape[-1]
|
414 |
+
|
415 |
+
if self.normalize:
|
416 |
+
mono = mix.mean(dim=1, keepdim=True)
|
417 |
+
mean = mono.mean(dim=-1, keepdim=True)
|
418 |
+
std = mono.std(dim=-1, keepdim=True)
|
419 |
+
x = (x - mean) / (1e-5 + std)
|
420 |
+
else:
|
421 |
+
mean = 0
|
422 |
+
std = 1
|
423 |
+
|
424 |
+
delta = self.valid_length(length) - length
|
425 |
+
x = F.pad(x, (delta // 2, delta - delta // 2))
|
426 |
+
|
427 |
+
if self.resample:
|
428 |
+
x = julius.resample_frac(x, 1, 2)
|
429 |
+
|
430 |
+
saved = []
|
431 |
+
for encode in self.encoder:
|
432 |
+
x = encode(x)
|
433 |
+
saved.append(x)
|
434 |
+
|
435 |
+
if self.lstm:
|
436 |
+
x = self.lstm(x)
|
437 |
+
|
438 |
+
for decode in self.decoder:
|
439 |
+
skip = saved.pop(-1)
|
440 |
+
skip = center_trim(skip, x)
|
441 |
+
x = decode(x + skip)
|
442 |
+
|
443 |
+
if self.resample:
|
444 |
+
x = julius.resample_frac(x, 2, 1)
|
445 |
+
x = x * std + mean
|
446 |
+
x = center_trim(x, length)
|
447 |
+
x = x.view(x.size(0), len(self.sources), self.audio_channels, x.size(-1))
|
448 |
+
return x
|
449 |
+
|
450 |
+
def load_state_dict(self, state, strict=True):
|
451 |
+
# fix a mismatch with previous generation Demucs models.
|
452 |
+
for idx in range(self.depth):
|
453 |
+
for a in ['encoder', 'decoder']:
|
454 |
+
for b in ['bias', 'weight']:
|
455 |
+
new = f'{a}.{idx}.3.{b}'
|
456 |
+
old = f'{a}.{idx}.2.{b}'
|
457 |
+
if old in state and new not in state:
|
458 |
+
state[new] = state.pop(old)
|
459 |
+
super().load_state_dict(state, strict=strict)
|
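Editor's note: a minimal, hypothetical usage sketch of the waveform Demucs class above, not part of the original file. The source names and the 10-second stereo input are illustrative assumptions; the class depends on DConv, rescale_module, center_trim, julius, etc. defined elsewhere in demucs/demucs.py.

    import torch

    # Instantiate the time-domain Demucs model with assumed source names.
    model = Demucs(sources=["drums", "bass", "other", "vocals"], audio_channels=2)
    model.eval()

    # 10 seconds of stereo audio: (batch, channels, samples).
    mix = torch.randn(1, 2, 44100 * 10)
    with torch.no_grad():
        out = model(mix)

    # Output is (batch, n_sources, channels, samples), same length as the input.
    print(out.shape)  # torch.Size([1, 4, 2, 441000])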
demucs/filtering.py
ADDED
@@ -0,0 +1,502 @@
from typing import Optional
import torch
import torch.nn as nn
from torch import Tensor
from torch.utils.data import DataLoader

def atan2(y, x):
    r"""Element-wise arctangent function of y/x.
    Returns a new tensor with signed angles in radians.
    It is an alternative implementation of torch.atan2

    Args:
        y (Tensor): First input tensor
        x (Tensor): Second input tensor [shape=y.shape]

    Returns:
        Tensor: [shape=y.shape].
    """
    pi = 2 * torch.asin(torch.tensor(1.0))
    x += ((x == 0) & (y == 0)) * 1.0
    out = torch.atan(y / x)
    out += ((y >= 0) & (x < 0)) * pi
    out -= ((y < 0) & (x < 0)) * pi
    out *= 1 - ((y > 0) & (x == 0)) * 1.0
    out += ((y > 0) & (x == 0)) * (pi / 2)
    out *= 1 - ((y < 0) & (x == 0)) * 1.0
    out += ((y < 0) & (x == 0)) * (-pi / 2)
    return out


# Define basic complex operations on torch.Tensor objects whose last dimension
# consists in the concatenation of the real and imaginary parts.


def _norm(x: torch.Tensor) -> torch.Tensor:
    r"""Computes the norm value of a torch Tensor, assuming that it
    comes as real and imaginary part in its last dimension.

    Args:
        x (Tensor): Input Tensor of shape [shape=(..., 2)]

    Returns:
        Tensor: shape as x excluding the last dimension.
    """
    return torch.abs(x[..., 0]) ** 2 + torch.abs(x[..., 1]) ** 2


def _mul_add(a: torch.Tensor, b: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Element-wise multiplication of two complex Tensors described
    through their real and imaginary parts.
    The result is added to the `out` tensor"""

    # check `out` and allocate it if needed
    target_shape = torch.Size([max(sa, sb) for (sa, sb) in zip(a.shape, b.shape)])
    if out is None or out.shape != target_shape:
        out = torch.zeros(target_shape, dtype=a.dtype, device=a.device)
    if out is a:
        real_a = a[..., 0]
        out[..., 0] = out[..., 0] + (real_a * b[..., 0] - a[..., 1] * b[..., 1])
        out[..., 1] = out[..., 1] + (real_a * b[..., 1] + a[..., 1] * b[..., 0])
    else:
        out[..., 0] = out[..., 0] + (a[..., 0] * b[..., 0] - a[..., 1] * b[..., 1])
        out[..., 1] = out[..., 1] + (a[..., 0] * b[..., 1] + a[..., 1] * b[..., 0])
    return out


def _mul(a: torch.Tensor, b: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Element-wise multiplication of two complex Tensors described
    through their real and imaginary parts
    can work in place in case out is a only"""
    target_shape = torch.Size([max(sa, sb) for (sa, sb) in zip(a.shape, b.shape)])
    if out is None or out.shape != target_shape:
        out = torch.zeros(target_shape, dtype=a.dtype, device=a.device)
    if out is a:
        real_a = a[..., 0]
        out[..., 0] = real_a * b[..., 0] - a[..., 1] * b[..., 1]
        out[..., 1] = real_a * b[..., 1] + a[..., 1] * b[..., 0]
    else:
        out[..., 0] = a[..., 0] * b[..., 0] - a[..., 1] * b[..., 1]
        out[..., 1] = a[..., 0] * b[..., 1] + a[..., 1] * b[..., 0]
    return out


def _inv(z: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Element-wise multiplicative inverse of a Tensor with complex
    entries described through their real and imaginary parts.
    can work in place in case out is z"""
    ez = _norm(z)
    if out is None or out.shape != z.shape:
        out = torch.zeros_like(z)
    out[..., 0] = z[..., 0] / ez
    out[..., 1] = -z[..., 1] / ez
    return out


def _conj(z, out: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Element-wise complex conjugate of a Tensor with complex entries
    described through their real and imaginary parts.
    can work in place in case out is z"""
    if out is None or out.shape != z.shape:
        out = torch.zeros_like(z)
    out[..., 0] = z[..., 0]
    out[..., 1] = -z[..., 1]
    return out


def _invert(M: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor:
    """
    Invert 1x1 or 2x2 matrices

    Will generate errors if the matrices are singular: user must handle this
    through his own regularization schemes.

    Args:
        M (Tensor): [shape=(..., nb_channels, nb_channels, 2)]
            matrices to invert: must be square along dimensions -3 and -2

    Returns:
        invM (Tensor): [shape=M.shape]
            inverses of M
    """
    nb_channels = M.shape[-2]

    if out is None or out.shape != M.shape:
        out = torch.empty_like(M)

    if nb_channels == 1:
        # scalar case
        out = _inv(M, out)
    elif nb_channels == 2:
        # two channels case: analytical expression

        # first compute the determinent
        det = _mul(M[..., 0, 0, :], M[..., 1, 1, :])
        det = det - _mul(M[..., 0, 1, :], M[..., 1, 0, :])
        # invert it
        invDet = _inv(det)

        # then fill out the matrix with the inverse
        out[..., 0, 0, :] = _mul(invDet, M[..., 1, 1, :], out[..., 0, 0, :])
        out[..., 1, 0, :] = _mul(-invDet, M[..., 1, 0, :], out[..., 1, 0, :])
        out[..., 0, 1, :] = _mul(-invDet, M[..., 0, 1, :], out[..., 0, 1, :])
        out[..., 1, 1, :] = _mul(invDet, M[..., 0, 0, :], out[..., 1, 1, :])
    else:
        raise Exception("Only 2 channels are supported for the torch version.")
    return out


# Now define the signal-processing low-level functions used by the Separator


def expectation_maximization(
    y: torch.Tensor,
    x: torch.Tensor,
    iterations: int = 2,
    eps: float = 1e-10,
    batch_size: int = 200,
):
    r"""Expectation maximization algorithm, for refining source separation
    estimates.

    This algorithm allows to make source separation results better by
    enforcing multichannel consistency for the estimates. This usually means
    a better perceptual quality in terms of spatial artifacts.

    The implementation follows the details presented in [1]_, taking
    inspiration from the original EM algorithm proposed in [2]_ and its
    weighted refinement proposed in [3]_, [4]_.
    It works by iteratively:

     * Re-estimate source parameters (power spectral densities and spatial
       covariance matrices) through :func:`get_local_gaussian_model`.

     * Separate again the mixture with the new parameters by first computing
       the new modelled mixture covariance matrices with :func:`get_mix_model`,
       prepare the Wiener filters through :func:`wiener_gain` and apply them
       with :func:`apply_filter``.

    References
    ----------
    .. [1] S. Uhlich and M. Porcu and F. Giron and M. Enenkl and T. Kemp and
        N. Takahashi and Y. Mitsufuji, "Improving music source separation based
        on deep neural networks through data augmentation and network
        blending." 2017 IEEE International Conference on Acoustics, Speech
        and Signal Processing (ICASSP). IEEE, 2017.

    .. [2] N.Q. Duong and E. Vincent and R.Gribonval. "Under-determined
        reverberant audio source separation using a full-rank spatial
        covariance model." IEEE Transactions on Audio, Speech, and Language
        Processing 18.7 (2010): 1830-1840.

    .. [3] A. Nugraha and A. Liutkus and E. Vincent. "Multichannel audio source
        separation with deep neural networks." IEEE/ACM Transactions on Audio,
        Speech, and Language Processing 24.9 (2016): 1652-1664.

    .. [4] A. Nugraha and A. Liutkus and E. Vincent. "Multichannel music
        separation with deep neural networks." 2016 24th European Signal
        Processing Conference (EUSIPCO). IEEE, 2016.

    .. [5] A. Liutkus and R. Badeau and G. Richard "Kernel additive models for
        source separation." IEEE Transactions on Signal Processing
        62.16 (2014): 4298-4310.

    Args:
        y (Tensor): [shape=(nb_frames, nb_bins, nb_channels, 2, nb_sources)]
            initial estimates for the sources
        x (Tensor): [shape=(nb_frames, nb_bins, nb_channels, 2)]
            complex STFT of the mixture signal
        iterations (int): [scalar]
            number of iterations for the EM algorithm.
        eps (float or None): [scalar]
            The epsilon value to use for regularization and filters.

    Returns:
        y (Tensor): [shape=(nb_frames, nb_bins, nb_channels, 2, nb_sources)]
            estimated sources after iterations
        v (Tensor): [shape=(nb_frames, nb_bins, nb_sources)]
            estimated power spectral densities
        R (Tensor): [shape=(nb_bins, nb_channels, nb_channels, 2, nb_sources)]
            estimated spatial covariance matrices

    Notes:
        * You need an initial estimate for the sources to apply this
          algorithm. This is precisely what the :func:`wiener` function does.
        * This algorithm *is not* an implementation of the "exact" EM
          proposed in [1]_. In particular, it does compute the posterior
          covariance matrices the same (exact) way. Instead, it uses the
          simplified approximate scheme initially proposed in [5]_ and further
          refined in [3]_, [4]_, that boils down to just take the empirical
          covariance of the recent source estimates, followed by a weighted
          average for the update of the spatial covariance matrix. It has been
          empirically demonstrated that this simplified algorithm is more
          robust for music separation.

    Warning:
        It is *very* important to make sure `x.dtype` is `torch.float64`
        if you want double precision, because this function will **not**
        do such conversion for you from `torch.complex32`, in case you want the
        smaller RAM usage on purpose.

        It is usually always better in terms of quality to have double
        precision, by e.g. calling :func:`expectation_maximization`
        with ``x.to(torch.float64)``.
    """
    # dimensions
    (nb_frames, nb_bins, nb_channels) = x.shape[:-1]
    nb_sources = y.shape[-1]

    regularization = torch.cat(
        (
            torch.eye(nb_channels, dtype=x.dtype, device=x.device)[..., None],
            torch.zeros((nb_channels, nb_channels, 1), dtype=x.dtype, device=x.device),
        ),
        dim=2,
    )
    regularization = torch.sqrt(torch.as_tensor(eps)) * (
        regularization[None, None, ...].expand((-1, nb_bins, -1, -1, -1))
    )

    # allocate the spatial covariance matrices
    R = [
        torch.zeros((nb_bins, nb_channels, nb_channels, 2), dtype=x.dtype, device=x.device)
        for j in range(nb_sources)
    ]
    weight: torch.Tensor = torch.zeros((nb_bins,), dtype=x.dtype, device=x.device)

    v: torch.Tensor = torch.zeros((nb_frames, nb_bins, nb_sources), dtype=x.dtype, device=x.device)
    for it in range(iterations):
        # constructing the mixture covariance matrix. Doing it with a loop
        # to avoid storing anytime in RAM the whole 6D tensor

        # update the PSD as the average spectrogram over channels
        v = torch.mean(torch.abs(y[..., 0, :]) ** 2 + torch.abs(y[..., 1, :]) ** 2, dim=-2)

        # update spatial covariance matrices (weighted update)
        for j in range(nb_sources):
            R[j] = torch.tensor(0.0, device=x.device)
            weight = torch.tensor(eps, device=x.device)
            pos: int = 0
            batch_size = batch_size if batch_size else nb_frames
            while pos < nb_frames:
                t = torch.arange(pos, min(nb_frames, pos + batch_size))
                pos = int(t[-1]) + 1

                R[j] = R[j] + torch.sum(_covariance(y[t, ..., j]), dim=0)
                weight = weight + torch.sum(v[t, ..., j], dim=0)
            R[j] = R[j] / weight[..., None, None, None]
            weight = torch.zeros_like(weight)

        # cloning y if we track gradient, because we're going to update it
        if y.requires_grad:
            y = y.clone()

        pos = 0
        while pos < nb_frames:
            t = torch.arange(pos, min(nb_frames, pos + batch_size))
            pos = int(t[-1]) + 1

            y[t, ...] = torch.tensor(0.0, device=x.device, dtype=x.dtype)

            # compute mix covariance matrix
            Cxx = regularization
            for j in range(nb_sources):
                Cxx = Cxx + (v[t, ..., j, None, None, None] * R[j][None, ...].clone())

            # invert it
            inv_Cxx = _invert(Cxx)

            # separate the sources
            for j in range(nb_sources):

                # create a wiener gain for this source
                gain = torch.zeros_like(inv_Cxx)

                # computes multichannel Wiener gain as v_j R_j inv_Cxx
                indices = torch.cartesian_prod(
                    torch.arange(nb_channels),
                    torch.arange(nb_channels),
                    torch.arange(nb_channels),
                )
                for index in indices:
                    gain[:, :, index[0], index[1], :] = _mul_add(
                        R[j][None, :, index[0], index[2], :].clone(),
                        inv_Cxx[:, :, index[2], index[1], :],
                        gain[:, :, index[0], index[1], :],
                    )
                gain = gain * v[t, ..., None, None, None, j]

                # apply it to the mixture
                for i in range(nb_channels):
                    y[t, ..., j] = _mul_add(gain[..., i, :], x[t, ..., i, None, :], y[t, ..., j])

    return y, v, R


def wiener(
    targets_spectrograms: torch.Tensor,
    mix_stft: torch.Tensor,
    iterations: int = 1,
    softmask: bool = False,
    residual: bool = False,
    scale_factor: float = 10.0,
    eps: float = 1e-10,
):
    """Wiener-based separation for multichannel audio.

    The method uses the (possibly multichannel) spectrograms of the
    sources to separate the (complex) Short Term Fourier Transform of the
    mix. Separation is done in a sequential way by:

    * Getting an initial estimate. This can be done in two ways: either by
      directly using the spectrograms with the mixture phase, or
      by using a softmasking strategy. This initial phase is controlled
      by the `softmask` flag.

    * If required, adding an additional residual target as the mix minus
      all targets.

    * Refinining these initial estimates through a call to
      :func:`expectation_maximization` if the number of iterations is nonzero.

    This implementation also allows to specify the epsilon value used for
    regularization. It is based on [1]_, [2]_, [3]_, [4]_.

    References
    ----------
    .. [1] S. Uhlich and M. Porcu and F. Giron and M. Enenkl and T. Kemp and
        N. Takahashi and Y. Mitsufuji, "Improving music source separation based
        on deep neural networks through data augmentation and network
        blending." 2017 IEEE International Conference on Acoustics, Speech
        and Signal Processing (ICASSP). IEEE, 2017.

    .. [2] A. Nugraha and A. Liutkus and E. Vincent. "Multichannel audio source
        separation with deep neural networks." IEEE/ACM Transactions on Audio,
        Speech, and Language Processing 24.9 (2016): 1652-1664.

    .. [3] A. Nugraha and A. Liutkus and E. Vincent. "Multichannel music
        separation with deep neural networks." 2016 24th European Signal
        Processing Conference (EUSIPCO). IEEE, 2016.

    .. [4] A. Liutkus and R. Badeau and G. Richard "Kernel additive models for
        source separation." IEEE Transactions on Signal Processing
        62.16 (2014): 4298-4310.

    Args:
        targets_spectrograms (Tensor): spectrograms of the sources
            [shape=(nb_frames, nb_bins, nb_channels, nb_sources)].
            This is a nonnegative tensor that is
            usually the output of the actual separation method of the user. The
            spectrograms may be mono, but they need to be 4-dimensional in all
            cases.
        mix_stft (Tensor): [shape=(nb_frames, nb_bins, nb_channels, complex=2)]
            STFT of the mixture signal.
        iterations (int): [scalar]
            number of iterations for the EM algorithm
        softmask (bool): Describes how the initial estimates are obtained.
            * if `False`, then the mixture phase will directly be used with the
              spectrogram as initial estimates.
            * if `True`, initial estimates are obtained by multiplying the
              complex mix element-wise with the ratio of each target spectrogram
              with the sum of them all. This strategy is better if the model are
              not really good, and worse otherwise.
        residual (bool): if `True`, an additional target is created, which is
            equal to the mixture minus the other targets, before application of
            expectation maximization
        eps (float): Epsilon value to use for computing the separations.
            This is used whenever division with a model energy is
            performed, i.e. when softmasking and when iterating the EM.
            It can be understood as the energy of the additional white noise
            that is taken out when separating.

    Returns:
        Tensor: shape=(nb_frames, nb_bins, nb_channels, complex=2, nb_sources)
            STFT of estimated sources

    Notes:
        * Be careful that you need *magnitude spectrogram estimates* for the
          case `softmask==False`.
        * `softmask=False` is recommended
        * The epsilon value will have a huge impact on performance. If it's
          large, only the parts of the signal with a significant energy will
          be kept in the sources. This epsilon then directly controls the
          energy of the reconstruction error.

    Warning:
        As in :func:`expectation_maximization`, we recommend converting the
        mixture `x` to double precision `torch.float64` *before* calling
        :func:`wiener`.
    """
    if softmask:
        # if we use softmask, we compute the ratio mask for all targets and
        # multiply by the mix stft
        y = (
            mix_stft[..., None]
            * (
                targets_spectrograms
                / (eps + torch.sum(targets_spectrograms, dim=-1, keepdim=True).to(mix_stft.dtype))
            )[..., None, :]
        )
    else:
        # otherwise, we just multiply the targets spectrograms with mix phase
        # we tacitly assume that we have magnitude estimates.
        angle = atan2(mix_stft[..., 1], mix_stft[..., 0])[..., None]
        nb_sources = targets_spectrograms.shape[-1]
        y = torch.zeros(
            mix_stft.shape + (nb_sources,), dtype=mix_stft.dtype, device=mix_stft.device
        )
        y[..., 0, :] = targets_spectrograms * torch.cos(angle)
        y[..., 1, :] = targets_spectrograms * torch.sin(angle)

    if residual:
        # if required, adding an additional target as the mix minus
        # available targets
        y = torch.cat([y, mix_stft[..., None] - y.sum(dim=-1, keepdim=True)], dim=-1)

    if iterations == 0:
        return y

    # we need to refine the estimates. Scales down the estimates for
    # numerical stability
    max_abs = torch.max(
        torch.as_tensor(1.0, dtype=mix_stft.dtype, device=mix_stft.device),
        torch.sqrt(_norm(mix_stft)).max() / scale_factor,
    )

    mix_stft = mix_stft / max_abs
    y = y / max_abs

    # call expectation maximization
    y = expectation_maximization(y, mix_stft, iterations, eps=eps)[0]

    # scale estimates up again
    y = y * max_abs
    return y


def _covariance(y_j):
    """
    Compute the empirical covariance for a source.

    Args:
        y_j (Tensor): complex stft of the source.
            [shape=(nb_frames, nb_bins, nb_channels, 2)].

    Returns:
        Cj (Tensor): [shape=(nb_frames, nb_bins, nb_channels, nb_channels, 2)]
            just y_j * conj(y_j.T): empirical covariance for each TF bin.
    """
    (nb_frames, nb_bins, nb_channels) = y_j.shape[:-1]
    Cj = torch.zeros(
        (nb_frames, nb_bins, nb_channels, nb_channels, 2),
        dtype=y_j.dtype,
        device=y_j.device,
    )
    indices = torch.cartesian_prod(torch.arange(nb_channels), torch.arange(nb_channels))
    for index in indices:
        Cj[:, :, index[0], index[1], :] = _mul_add(
            y_j[:, :, index[0], :],
            _conj(y_j[:, :, index[1], :]),
            Cj[:, :, index[0], index[1], :],
        )
    return Cj
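Editor's note: a minimal, hypothetical usage sketch of the `wiener` function above, not part of the original file. The tensor sizes are illustrative assumptions; shapes follow the `wiener` docstring, and double precision is used as the docstring recommends (assuming the module is importable as demucs.filtering).

    import torch

    nb_frames, nb_bins, nb_channels, nb_sources = 50, 1025, 2, 2

    # Nonnegative magnitude estimates for each source and the complex mixture STFT
    # stored as (real, imag) in the last dimension.
    mag_estimates = torch.rand(nb_frames, nb_bins, nb_channels, nb_sources)
    mix_stft = torch.randn(nb_frames, nb_bins, nb_channels, 2)

    # One EM refinement pass; softmask=False uses the mixture phase directly.
    y = wiener(mag_estimates.double(), mix_stft.double(), iterations=1)
    print(y.shape)  # (nb_frames, nb_bins, nb_channels, 2, nb_sources)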
demucs/hdemucs.py
ADDED
@@ -0,0 +1,796 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
This code contains the spectrogram and Hybrid version of Demucs.
"""
from copy import deepcopy
import math
import typing as tp
import torch
from torch import nn
from torch.nn import functional as F
from .filtering import wiener
from .demucs import DConv, rescale_module
from .states import capture_init
from .spec import spectro, ispectro

def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'constant', value: float = 0.):
    """Tiny wrapper around F.pad, just to allow for reflect padding on small input.
    If this is the case, we insert extra 0 padding to the right before the reflection happen."""
    x0 = x
    length = x.shape[-1]
    padding_left, padding_right = paddings
    if mode == 'reflect':
        max_pad = max(padding_left, padding_right)
        if length <= max_pad:
            extra_pad = max_pad - length + 1
            extra_pad_right = min(padding_right, extra_pad)
            extra_pad_left = extra_pad - extra_pad_right
            paddings = (padding_left - extra_pad_left, padding_right - extra_pad_right)
            x = F.pad(x, (extra_pad_left, extra_pad_right))
    out = F.pad(x, paddings, mode, value)
    assert out.shape[-1] == length + padding_left + padding_right
    assert (out[..., padding_left: padding_left + length] == x0).all()
    return out

class ScaledEmbedding(nn.Module):
    """
    Boost learning rate for embeddings (with `scale`).
    Also, can make embeddings continuous with `smooth`.
    """
    def __init__(self, num_embeddings: int, embedding_dim: int,
                 scale: float = 10., smooth=False):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        if smooth:
            weight = torch.cumsum(self.embedding.weight.data, dim=0)
            # when summing gaussian, overscale raises as sqrt(n), so we nornalize by that.
            weight = weight / torch.arange(1, num_embeddings + 1).to(weight).sqrt()[:, None]
            self.embedding.weight.data[:] = weight
        self.embedding.weight.data /= scale
        self.scale = scale

    @property
    def weight(self):
        return self.embedding.weight * self.scale

    def forward(self, x):
        out = self.embedding(x) * self.scale
        return out


class HEncLayer(nn.Module):
    def __init__(self, chin, chout, kernel_size=8, stride=4, norm_groups=1, empty=False,
                 freq=True, dconv=True, norm=True, context=0, dconv_kw={}, pad=True,
                 rewrite=True):
        """Encoder layer. This used both by the time and the frequency branch.

        Args:
            chin: number of input channels.
            chout: number of output channels.
            norm_groups: number of groups for group norm.
            empty: used to make a layer with just the first conv. this is used
                before merging the time and freq. branches.
            freq: this is acting on frequencies.
            dconv: insert DConv residual branches.
            norm: use GroupNorm.
            context: context size for the 1x1 conv.
            dconv_kw: list of kwargs for the DConv class.
            pad: pad the input. Padding is done so that the output size is
                always the input size / stride.
            rewrite: add 1x1 conv at the end of the layer.
        """
        super().__init__()
        norm_fn = lambda d: nn.Identity()  # noqa
        if norm:
            norm_fn = lambda d: nn.GroupNorm(norm_groups, d)  # noqa
        if pad:
            pad = kernel_size // 4
        else:
            pad = 0
        klass = nn.Conv1d
        self.freq = freq
        self.kernel_size = kernel_size
        self.stride = stride
        self.empty = empty
        self.norm = norm
        self.pad = pad
        if freq:
            kernel_size = [kernel_size, 1]
            stride = [stride, 1]
            pad = [pad, 0]
            klass = nn.Conv2d
        self.conv = klass(chin, chout, kernel_size, stride, pad)
        if self.empty:
            return
        self.norm1 = norm_fn(chout)
        self.rewrite = None
        if rewrite:
            self.rewrite = klass(chout, 2 * chout, 1 + 2 * context, 1, context)
            self.norm2 = norm_fn(2 * chout)

        self.dconv = None
        if dconv:
            self.dconv = DConv(chout, **dconv_kw)

    def forward(self, x, inject=None):
        """
        `inject` is used to inject the result from the time branch into the frequency branch,
        when both have the same stride.
        """
        if not self.freq and x.dim() == 4:
            B, C, Fr, T = x.shape
            x = x.view(B, -1, T)

        if not self.freq:
            le = x.shape[-1]
            if not le % self.stride == 0:
                x = F.pad(x, (0, self.stride - (le % self.stride)))
        y = self.conv(x)
        if self.empty:
            return y
        if inject is not None:
            assert inject.shape[-1] == y.shape[-1], (inject.shape, y.shape)
            if inject.dim() == 3 and y.dim() == 4:
                inject = inject[:, :, None]
            y = y + inject
        y = F.gelu(self.norm1(y))
        if self.dconv:
            if self.freq:
                B, C, Fr, T = y.shape
                y = y.permute(0, 2, 1, 3).reshape(-1, C, T)
            y = self.dconv(y)
            if self.freq:
                y = y.view(B, Fr, C, T).permute(0, 2, 1, 3)
        if self.rewrite:
            z = self.norm2(self.rewrite(y))
            z = F.glu(z, dim=1)
        else:
            z = y
        return z


class MultiWrap(nn.Module):
    """
    Takes one layer and replicate it N times. each replica will act
    on a frequency band. All is done so that if the N replica have the same weights,
    then this is exactly equivalent to applying the original module on all frequencies.

    This is a bit over-engineered to avoid edge artifacts when splitting
    the frequency bands, but it is possible the naive implementation would work as well...
    """
    def __init__(self, layer, split_ratios):
        """
        Args:
            layer: module to clone, must be either HEncLayer or HDecLayer.
            split_ratios: list of float indicating which ratio to keep for each band.
        """
        super().__init__()
        self.split_ratios = split_ratios
        self.layers = nn.ModuleList()
        self.conv = isinstance(layer, HEncLayer)
        assert not layer.norm
        assert layer.freq
        assert layer.pad
        if not self.conv:
            assert not layer.context_freq
        for k in range(len(split_ratios) + 1):
            lay = deepcopy(layer)
            if self.conv:
                lay.conv.padding = (0, 0)
            else:
                lay.pad = False
            for m in lay.modules():
                if hasattr(m, 'reset_parameters'):
                    m.reset_parameters()
            self.layers.append(lay)

    def forward(self, x, skip=None, length=None):
        B, C, Fr, T = x.shape

        ratios = list(self.split_ratios) + [1]
        start = 0
        outs = []
        for ratio, layer in zip(ratios, self.layers):
            if self.conv:
                pad = layer.kernel_size // 4
                if ratio == 1:
                    limit = Fr
                    frames = -1
                else:
                    limit = int(round(Fr * ratio))
                    le = limit - start
                    if start == 0:
                        le += pad
                    frames = round((le - layer.kernel_size) / layer.stride + 1)
                    limit = start + (frames - 1) * layer.stride + layer.kernel_size
                    if start == 0:
                        limit -= pad
                assert limit - start > 0, (limit, start)
                assert limit <= Fr, (limit, Fr)
                y = x[:, :, start:limit, :]
                if start == 0:
                    y = F.pad(y, (0, 0, pad, 0))
                if ratio == 1:
                    y = F.pad(y, (0, 0, 0, pad))
                outs.append(layer(y))
                start = limit - layer.kernel_size + layer.stride
            else:
                if ratio == 1:
                    limit = Fr
                else:
                    limit = int(round(Fr * ratio))
                last = layer.last
                layer.last = True

                y = x[:, :, start:limit]
                s = skip[:, :, start:limit]
                out, _ = layer(y, s, None)
                if outs:
                    outs[-1][:, :, -layer.stride:] += (
                        out[:, :, :layer.stride] - layer.conv_tr.bias.view(1, -1, 1, 1))
                    out = out[:, :, layer.stride:]
                if ratio == 1:
                    out = out[:, :, :-layer.stride // 2, :]
                if start == 0:
                    out = out[:, :, layer.stride // 2:, :]
                outs.append(out)
                layer.last = last
                start = limit
        out = torch.cat(outs, dim=2)
        if not self.conv and not last:
            out = F.gelu(out)
        if self.conv:
            return out
        else:
            return out, None


class HDecLayer(nn.Module):
    def __init__(self, chin, chout, last=False, kernel_size=8, stride=4, norm_groups=1, empty=False,
                 freq=True, dconv=True, norm=True, context=1, dconv_kw={}, pad=True,
                 context_freq=True, rewrite=True):
        """
        Same as HEncLayer but for decoder. See `HEncLayer` for documentation.
        """
        super().__init__()
        norm_fn = lambda d: nn.Identity()  # noqa
        if norm:
            norm_fn = lambda d: nn.GroupNorm(norm_groups, d)  # noqa
        if pad:
            pad = kernel_size // 4
        else:
            pad = 0
        self.pad = pad
        self.last = last
        self.freq = freq
        self.chin = chin
        self.empty = empty
        self.stride = stride
        self.kernel_size = kernel_size
        self.norm = norm
        self.context_freq = context_freq
        klass = nn.Conv1d
        klass_tr = nn.ConvTranspose1d
        if freq:
            kernel_size = [kernel_size, 1]
            stride = [stride, 1]
            klass = nn.Conv2d
            klass_tr = nn.ConvTranspose2d
        self.conv_tr = klass_tr(chin, chout, kernel_size, stride)
        self.norm2 = norm_fn(chout)
        if self.empty:
            return
        self.rewrite = None
        if rewrite:
            if context_freq:
                self.rewrite = klass(chin, 2 * chin, 1 + 2 * context, 1, context)
            else:
                self.rewrite = klass(chin, 2 * chin, [1, 1 + 2 * context], 1,
                                     [0, context])
            self.norm1 = norm_fn(2 * chin)

        self.dconv = None
        if dconv:
            self.dconv = DConv(chin, **dconv_kw)

    def forward(self, x, skip, length):
        if self.freq and x.dim() == 3:
            B, C, T = x.shape
            x = x.view(B, self.chin, -1, T)

        if not self.empty:
            x = x + skip

            if self.rewrite:
                y = F.glu(self.norm1(self.rewrite(x)), dim=1)
            else:
                y = x
            if self.dconv:
                if self.freq:
                    B, C, Fr, T = y.shape
                    y = y.permute(0, 2, 1, 3).reshape(-1, C, T)
                y = self.dconv(y)
                if self.freq:
                    y = y.view(B, Fr, C, T).permute(0, 2, 1, 3)
        else:
            y = x
            assert skip is None
        z = self.norm2(self.conv_tr(y))
        if self.freq:
            if self.pad:
                z = z[..., self.pad:-self.pad, :]
        else:
            z = z[..., self.pad:self.pad + length]
            assert z.shape[-1] == length, (z.shape[-1], length)
        if not self.last:
            z = F.gelu(z)
        return z, y


class HDemucs(nn.Module):
    """
    Spectrogram and hybrid Demucs model.
    The spectrogram model has the same structure as Demucs, except the first few layers are over the
    frequency axis, until there is only 1 frequency, and then it moves to time convolutions.
    Frequency layers can still access information across time steps thanks to the DConv residual.

    Hybrid model have a parallel time branch. At some layer, the time branch has the same stride
    as the frequency branch and then the two are combined. The opposite happens in the decoder.

    Models can either use naive iSTFT from masking, Wiener filtering ([Ulhih et al. 2017]),
    or complex as channels (CaC) [Choi et al. 2020]. Wiener filtering is based on
    Open Unmix implementation [Stoter et al. 2019].

    The loss is always on the temporal domain, by backpropagating through the above
    output methods and iSTFT. This allows to define hybrid models nicely. However, this breaks
    a bit Wiener filtering, as doing more iteration at test time will change the spectrogram
    contribution, without changing the one from the waveform, which will lead to worse performance.
    I tried using the residual option in OpenUnmix Wiener implementation, but it didn't improve.
    CaC on the other hand provides similar performance for hybrid, and works naturally with
    hybrid models.

    This model also uses frequency embeddings are used to improve efficiency on convolutions
    over the freq. axis, following [Isik et al. 2020] (https://arxiv.org/pdf/2008.04470.pdf).

    Unlike classic Demucs, there is no resampling here, and normalization is always applied.
    """
    @capture_init
    def __init__(self,
                 sources,
                 # Channels
                 audio_channels=2,
                 channels=48,
                 channels_time=None,
                 growth=2,
                 # STFT
                 nfft=4096,
                 wiener_iters=0,
                 end_iters=0,
                 wiener_residual=False,
                 cac=True,
                 # Main structure
                 depth=6,
                 rewrite=True,
                 hybrid=True,
                 hybrid_old=False,
                 # Frequency branch
                 multi_freqs=None,
                 multi_freqs_depth=2,
                 freq_emb=0.2,
                 emb_scale=10,
                 emb_smooth=True,
                 # Convolutions
                 kernel_size=8,
                 time_stride=2,
                 stride=4,
                 context=1,
                 context_enc=0,
                 # Normalization
                 norm_starts=4,
                 norm_groups=4,
                 # DConv residual branch
                 dconv_mode=1,
                 dconv_depth=2,
                 dconv_comp=4,
                 dconv_attn=4,
                 dconv_lstm=4,
                 dconv_init=1e-4,
                 # Weight init
                 rescale=0.1,
                 # Metadata
                 samplerate=44100,
                 segment=4 * 10):

        """
        Args:
            sources (list[str]): list of source names.
            audio_channels (int): input/output audio channels.
            channels (int): initial number of hidden channels.
            channels_time: if not None, use a different `channels` value for the time branch.
            growth: increase the number of hidden channels by this factor at each layer.
            nfft: number of fft bins. Note that changing this require careful computation of
                various shape parameters and will not work out of the box for hybrid models.
            wiener_iters: when using Wiener filtering, number of iterations at test time.
            end_iters: same but at train time. For a hybrid model, must be equal to `wiener_iters`.
            wiener_residual: add residual source before wiener filtering.
            cac: uses complex as channels, i.e. complex numbers are 2 channels each
                in input and output. no further processing is done before ISTFT.
            depth (int): number of layers in the encoder and in the decoder.
            rewrite (bool): add 1x1 convolution to each layer.
            hybrid (bool): make a hybrid time/frequency domain, otherwise frequency only.
            hybrid_old: some models trained for MDX had a padding bug. This replicates
                this bug to avoid retraining them.
            multi_freqs: list of frequency ratios for splitting frequency bands with `MultiWrap`.
            multi_freqs_depth: how many layers to wrap with `MultiWrap`. Only the outermost
                layers will be wrapped.
            freq_emb: add frequency embedding after the first frequency layer if > 0,
                the actual value controls the weight of the embedding.
            emb_scale: equivalent to scaling the embedding learning rate
            emb_smooth: initialize the embedding with a smooth one (with respect to frequencies).
            kernel_size: kernel_size for encoder and decoder layers.
            stride: stride for encoder and decoder layers.
            time_stride: stride for the final time layer, after the merge.
            context: context for 1x1 conv in the decoder.
            context_enc: context for 1x1 conv in the encoder.
            norm_starts: layer at which group norm starts being used.
                decoder layers are numbered in reverse order.
            norm_groups: number of groups for group norm.
            dconv_mode: if 1: dconv in encoder only, 2: decoder only, 3: both.
            dconv_depth: depth of residual DConv branch.
            dconv_comp: compression of DConv branch.
            dconv_attn: adds attention layers in DConv branch starting at this layer.
            dconv_lstm: adds a LSTM layer in DConv branch starting at this layer.
            dconv_init: initial scale for the DConv branch LayerScale.
            rescale: weight recaling trick

        """
        super().__init__()

        self.cac = cac
        self.wiener_residual = wiener_residual
        self.audio_channels = audio_channels
        self.sources = sources
        self.kernel_size = kernel_size
        self.context = context
        self.stride = stride
        self.depth = depth
        self.channels = channels
        self.samplerate = samplerate
        self.segment = segment

        self.nfft = nfft
        self.hop_length = nfft // 4
        self.wiener_iters = wiener_iters
        self.end_iters = end_iters
        self.freq_emb = None
        self.hybrid = hybrid
        self.hybrid_old = hybrid_old
        if hybrid_old:
            assert hybrid, "hybrid_old must come with hybrid=True"
        if hybrid:
            assert wiener_iters == end_iters

        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()

        if hybrid:
            self.tencoder = nn.ModuleList()
            self.tdecoder = nn.ModuleList()

        chin = audio_channels
        chin_z = chin  # number of channels for the freq branch
        if self.cac:
            chin_z *= 2
        chout = channels_time or channels
        chout_z = channels
        freqs = nfft // 2

        for index in range(depth):
            lstm = index >= dconv_lstm
            attn = index >= dconv_attn
            norm = index >= norm_starts
            freq = freqs > 1
            stri = stride
            ker = kernel_size
            if not freq:
                assert freqs == 1
                ker = time_stride * 2
                stri = time_stride

            pad = True
            last_freq = False
            if freq and freqs <= kernel_size:
                ker = freqs
                pad = False
                last_freq = True

            kw = {
                'kernel_size': ker,
                'stride': stri,
                'freq': freq,
                'pad': pad,
                'norm': norm,
                'rewrite': rewrite,
                'norm_groups': norm_groups,
                'dconv_kw': {
                    'lstm': lstm,
                    'attn': attn,
                    'depth': dconv_depth,
                    'compress': dconv_comp,
                    'init': dconv_init,
                    'gelu': True,
                }
            }
            kwt = dict(kw)
            kwt['freq'] = 0
            kwt['kernel_size'] = kernel_size
            kwt['stride'] = stride
            kwt['pad'] = True
            kw_dec = dict(kw)
            multi = False
            if multi_freqs and index < multi_freqs_depth:
                multi = True
                kw_dec['context_freq'] = False

            if last_freq:
                chout_z = max(chout, chout_z)
                chout = chout_z

            enc = HEncLayer(chin_z, chout_z,
                            dconv=dconv_mode & 1, context=context_enc, **kw)
            if hybrid and freq:
                tenc = HEncLayer(chin, chout, dconv=dconv_mode & 1, context=context_enc,
                                 empty=last_freq, **kwt)
                self.tencoder.append(tenc)

            if multi:
                enc = MultiWrap(enc, multi_freqs)
            self.encoder.append(enc)
            if index == 0:
                chin = self.audio_channels * len(self.sources)
                chin_z = chin
                if self.cac:
                    chin_z *= 2
            dec = HDecLayer(chout_z, chin_z, dconv=dconv_mode & 2,
                            last=index == 0, context=context, **kw_dec)
            if multi:
                dec = MultiWrap(dec, multi_freqs)
            if hybrid and freq:
                tdec = HDecLayer(chout, chin, dconv=dconv_mode & 2, empty=last_freq,
                                 last=index == 0, context=context, **kwt)
                self.tdecoder.insert(0, tdec)
            self.decoder.insert(0, dec)

            chin = chout
            chin_z = chout_z
            chout = int(growth * chout)
            chout_z = int(growth * chout_z)
            if freq:
                if freqs <= kernel_size:
                    freqs = 1
                else:
                    freqs //= stride
            if index == 0 and freq_emb:
                self.freq_emb = ScaledEmbedding(
                    freqs, chin_z, smooth=emb_smooth, scale=emb_scale)
                self.freq_emb_scale = freq_emb

        if rescale:
            rescale_module(self, reference=rescale)

    def _spec(self, x):
        hl = self.hop_length
        nfft = self.nfft
        x0 = x  # noqa

        if self.hybrid:
            # We re-pad the signal in order to keep the property
            # that the size of the output is exactly the size of the input
            # divided by the stride (here hop_length), when divisible.
            # This is achieved by padding by 1/4th of the kernel size (here nfft).
            # which is not supported by torch.stft.
            # Having all convolution operations follow this convention allow to easily
            # align the time and frequency branches later on.
            assert hl == nfft // 4
            le = int(math.ceil(x.shape[-1] / hl))
            pad = hl // 2 * 3
            if not self.hybrid_old:
                x = pad1d(x, (pad, pad + le * hl - x.shape[-1]), mode='reflect')
            else:
                x = pad1d(x, (pad, pad + le * hl - x.shape[-1]))

        z = spectro(x, nfft, hl)[..., :-1, :]
        if self.hybrid:
            assert z.shape[-1] == le + 4, (z.shape, x.shape, le)
            z = z[..., 2:2+le]
        return z

    def _ispec(self, z, length=None, scale=0):
        hl = self.hop_length // (4 ** scale)
        z = F.pad(z, (0, 0, 0, 1))
        if self.hybrid:
            z = F.pad(z, (2, 2))
            pad = hl // 2 * 3
            if not self.hybrid_old:
                le = hl * int(math.ceil(length / hl)) + 2 * pad
            else:
                le = hl * int(math.ceil(length / hl))
            x = ispectro(z, hl, length=le)
            if not self.hybrid_old:
                x = x[..., pad:pad + length]
            else:
                x = x[..., :length]
        else:
            x = ispectro(z, hl, length)
        return x

    def _magnitude(self, z):
        # return the magnitude of the spectrogram, except when cac is True,
        # in which case we just move the complex dimension to the channel one.
        if self.cac:
            B, C, Fr, T = z.shape
            m = torch.view_as_real(z).permute(0, 1, 4, 2, 3)
            m = m.reshape(B, C * 2, Fr, T)
        else:
            m = z.abs()
        return m

    def _mask(self, z, m):
        # Apply masking given the mixture spectrogram `z` and the estimated mask `m`.
        # If `cac` is True, `m` is actually a full spectrogram and `z` is ignored.
        niters = self.wiener_iters
        if self.cac:
            B, S, C, Fr, T = m.shape
            out = m.view(B, S, -1, 2, Fr, T).permute(0, 1, 2, 4, 5, 3)
            out = torch.view_as_complex(out.contiguous())
            return out
        if self.training:
            niters = self.end_iters
        if niters < 0:
            z = z[:, None]
            return z / (1e-8 + z.abs()) * m
        else:
            return self._wiener(m, z, niters)

    def _wiener(self, mag_out, mix_stft, niters):
        # apply wiener filtering from OpenUnmix.
        init = mix_stft.dtype
        wiener_win_len = 300
        residual = self.wiener_residual

        B, S, C, Fq, T = mag_out.shape
        mag_out = mag_out.permute(0, 4, 3, 2, 1)
        mix_stft = torch.view_as_real(mix_stft.permute(0, 3, 2, 1))

        outs = []
        for sample in range(B):
            pos = 0
            out = []
            for pos in range(0, T, wiener_win_len):
                frame = slice(pos, pos + wiener_win_len)
                z_out = wiener(
                    mag_out[sample, frame], mix_stft[sample, frame], niters,
                    residual=residual)
                out.append(z_out.transpose(-1, -2))
            outs.append(torch.cat(out, dim=0))
        out = torch.view_as_complex(torch.stack(outs, 0))
        out = out.permute(0, 4, 3, 2, 1).contiguous()
        if residual:
            out = out[:, :-1]
        assert list(out.shape) == [B, S, C, Fq, T]
        return out.to(init)

    def forward(self, mix):
        x = mix
        length = x.shape[-1]

        z = self._spec(mix)
        mag = self._magnitude(z).to(mix.device)
        x = mag

        B, C, Fq, T = x.shape

        # unlike previous Demucs, we always normalize because it is easier.
        mean = x.mean(dim=(1, 2, 3), keepdim=True)
        std = x.std(dim=(1, 2, 3), keepdim=True)
        x = (x - mean) / (1e-5 + std)
        # x will be the freq. branch input.

        if self.hybrid:
            # Prepare the time branch input.
            xt = mix
            meant = xt.mean(dim=(1, 2), keepdim=True)
            stdt = xt.std(dim=(1, 2), keepdim=True)
            xt = (xt - meant) / (1e-5 + stdt)

        # okay, this is a giant mess I know...
        saved = []  # skip connections, freq.
        saved_t = []  # skip connections, time.
        lengths = []  # saved lengths to properly remove padding, freq branch.
        lengths_t = []  # saved lengths for time branch.
        for idx, encode in enumerate(self.encoder):
            lengths.append(x.shape[-1])
            inject = None
            if self.hybrid and idx < len(self.tencoder):
                # we have not yet merged branches.
                lengths_t.append(xt.shape[-1])
                tenc = self.tencoder[idx]
                xt = tenc(xt)
                if not tenc.empty:
                    # save for skip connection
                    saved_t.append(xt)
                else:
                    # tenc contains just the first conv., so that now time and freq.
                    # branches have the same shape and can be merged.
                    inject = xt
            x = encode(x, inject)
            if idx == 0 and self.freq_emb is not None:
|
732 |
+
# add frequency embedding to allow for non equivariant convolutions
|
733 |
+
# over the frequency axis.
|
734 |
+
frs = torch.arange(x.shape[-2], device=x.device)
|
735 |
+
emb = self.freq_emb(frs).t()[None, :, :, None].expand_as(x)
|
736 |
+
x = x + self.freq_emb_scale * emb
|
737 |
+
|
738 |
+
saved.append(x)
|
739 |
+
|
740 |
+
x = torch.zeros_like(x)
|
741 |
+
if self.hybrid:
|
742 |
+
xt = torch.zeros_like(x)
|
743 |
+
# initialize everything to zero (signal will go through u-net skips).
|
744 |
+
|
745 |
+
for idx, decode in enumerate(self.decoder):
|
746 |
+
skip = saved.pop(-1)
|
747 |
+
x, pre = decode(x, skip, lengths.pop(-1))
|
748 |
+
# `pre` contains the output just before final transposed convolution,
|
749 |
+
# which is used when the freq. and time branch separate.
|
750 |
+
|
751 |
+
if self.hybrid:
|
752 |
+
offset = self.depth - len(self.tdecoder)
|
753 |
+
if self.hybrid and idx >= offset:
|
754 |
+
tdec = self.tdecoder[idx - offset]
|
755 |
+
length_t = lengths_t.pop(-1)
|
756 |
+
if tdec.empty:
|
757 |
+
assert pre.shape[2] == 1, pre.shape
|
758 |
+
pre = pre[:, :, 0]
|
759 |
+
xt, _ = tdec(pre, None, length_t)
|
760 |
+
else:
|
761 |
+
skip = saved_t.pop(-1)
|
762 |
+
xt, _ = tdec(xt, skip, length_t)
|
763 |
+
|
764 |
+
# Let's make sure we used all stored skip connections.
|
765 |
+
assert len(saved) == 0
|
766 |
+
assert len(lengths_t) == 0
|
767 |
+
assert len(saved_t) == 0
|
768 |
+
|
769 |
+
S = len(self.sources)
|
770 |
+
x = x.view(B, S, -1, Fq, T)
|
771 |
+
x = x * std[:, None] + mean[:, None]
|
772 |
+
|
773 |
+
# to cpu as non-cuda GPUs don't support complex numbers
|
774 |
+
# demucs issue #435 ##432
|
775 |
+
# NOTE: in this case z already is on cpu
|
776 |
+
# TODO: remove this when mps supports complex numbers
|
777 |
+
|
778 |
+
device_type = x.device.type
|
779 |
+
device_load = f"{device_type}:{x.device.index}" if not device_type == 'mps' else device_type
|
780 |
+
x_is_other_gpu = not device_type in ["cuda", "cpu"]
|
781 |
+
|
782 |
+
if x_is_other_gpu:
|
783 |
+
x = x.cpu()
|
784 |
+
|
785 |
+
zout = self._mask(z, x)
|
786 |
+
x = self._ispec(zout, length)
|
787 |
+
|
788 |
+
# back to other device
|
789 |
+
if x_is_other_gpu:
|
790 |
+
x = x.to(device_load)
|
791 |
+
|
792 |
+
if self.hybrid:
|
793 |
+
xt = xt.view(B, S, -1, length)
|
794 |
+
xt = xt * stdt[:, None] + meant[:, None]
|
795 |
+
x = xt + x
|
796 |
+
return x
|
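A minimal sketch of the padding arithmetic used by `_spec` above, not part of the diff itself; the values for `nfft` and `length` are arbitrary examples. It only checks the bookkeeping: padding by `3 * hl // 2` on the left and up to the next multiple of `hl` on the right gives a padded length of `le * hl + 2 * pad`, so the STFT yields `le + 4` frames and `_spec` keeps exactly `ceil(length / hl)` of them.

import math

nfft = 4096
hl = nfft // 4                    # hop_length, as set in HDemucs.__init__
length = 44100                    # example: one second at 44.1 kHz

le = int(math.ceil(length / hl))  # target number of kept STFT frames
pad = hl // 2 * 3                 # left padding applied by _spec
right = pad + le * hl - length    # right padding applied by _spec

# total length seen by the STFT; _spec then drops 2 frames on each side
assert pad + length + right == le * hl + 2 * pad
print(le, pad, right)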
demucs/htdemucs.py
ADDED
@@ -0,0 +1,664 @@
# Copyright (c) Meta, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# First author is Simon Rouard.
"""
This code contains the spectrogram and Hybrid version of Demucs.
"""
import math

from .filtering import wiener
import torch
from torch import nn
from torch.nn import functional as F
from fractions import Fraction
from einops import rearrange

from .transformer import CrossTransformerEncoder

from .demucs import rescale_module
from .states import capture_init
from .spec import spectro, ispectro
from .hdemucs import pad1d, ScaledEmbedding, HEncLayer, MultiWrap, HDecLayer


class HTDemucs(nn.Module):
    """
    Spectrogram and hybrid Demucs model.
    The spectrogram model has the same structure as Demucs, except the first few layers are over the
    frequency axis, until there is only 1 frequency, and then it moves to time convolutions.
    Frequency layers can still access information across time steps thanks to the DConv residual.

    Hybrid model have a parallel time branch. At some layer, the time branch has the same stride
    as the frequency branch and then the two are combined. The opposite happens in the decoder.

    Models can either use naive iSTFT from masking, Wiener filtering ([Ulhih et al. 2017]),
    or complex as channels (CaC) [Choi et al. 2020]. Wiener filtering is based on
    Open Unmix implementation [Stoter et al. 2019].

    The loss is always on the temporal domain, by backpropagating through the above
    output methods and iSTFT. This allows to define hybrid models nicely. However, this breaks
    a bit Wiener filtering, as doing more iteration at test time will change the spectrogram
    contribution, without changing the one from the waveform, which will lead to worse performance.
    I tried using the residual option in OpenUnmix Wiener implementation, but it didn't improve.
    CaC on the other hand provides similar performance for hybrid, and works naturally with
    hybrid models.

    This model also uses frequency embeddings are used to improve efficiency on convolutions
    over the freq. axis, following [Isik et al. 2020] (https://arxiv.org/pdf/2008.04470.pdf).

    Unlike classic Demucs, there is no resampling here, and normalization is always applied.
    """

    @capture_init
    def __init__(
        self,
        sources,
        # Channels
        audio_channels=2,
        channels=48,
        channels_time=None,
        growth=2,
        # STFT
        nfft=4096,
        wiener_iters=0,
        end_iters=0,
        wiener_residual=False,
        cac=True,
        # Main structure
        depth=4,
        rewrite=True,
        # Frequency branch
        multi_freqs=None,
        multi_freqs_depth=3,
        freq_emb=0.2,
        emb_scale=10,
        emb_smooth=True,
        # Convolutions
        kernel_size=8,
        time_stride=2,
        stride=4,
        context=1,
        context_enc=0,
        # Normalization
        norm_starts=4,
        norm_groups=4,
        # DConv residual branch
        dconv_mode=1,
        dconv_depth=2,
        dconv_comp=8,
        dconv_init=1e-3,
        # Before the Transformer
        bottom_channels=0,
        # Transformer
        t_layers=5,
        t_emb="sin",
        t_hidden_scale=4.0,
        t_heads=8,
        t_dropout=0.0,
        t_max_positions=10000,
        t_norm_in=True,
        t_norm_in_group=False,
        t_group_norm=False,
        t_norm_first=True,
        t_norm_out=True,
        t_max_period=10000.0,
        t_weight_decay=0.0,
        t_lr=None,
        t_layer_scale=True,
        t_gelu=True,
        t_weight_pos_embed=1.0,
        t_sin_random_shift=0,
        t_cape_mean_normalize=True,
        t_cape_augment=True,
        t_cape_glob_loc_scale=[5000.0, 1.0, 1.4],
        t_sparse_self_attn=False,
        t_sparse_cross_attn=False,
        t_mask_type="diag",
        t_mask_random_seed=42,
        t_sparse_attn_window=500,
        t_global_window=100,
        t_sparsity=0.95,
        t_auto_sparsity=False,
        # ------ Particuliar parameters
        t_cross_first=False,
        # Weight init
        rescale=0.1,
        # Metadata
        samplerate=44100,
        segment=10,
        use_train_segment=True,
    ):
        """
        Args:
            sources (list[str]): list of source names.
            audio_channels (int): input/output audio channels.
            channels (int): initial number of hidden channels.
            channels_time: if not None, use a different `channels` value for the time branch.
            growth: increase the number of hidden channels by this factor at each layer.
            nfft: number of fft bins. Note that changing this require careful computation of
                various shape parameters and will not work out of the box for hybrid models.
            wiener_iters: when using Wiener filtering, number of iterations at test time.
            end_iters: same but at train time. For a hybrid model, must be equal to `wiener_iters`.
            wiener_residual: add residual source before wiener filtering.
            cac: uses complex as channels, i.e. complex numbers are 2 channels each
                in input and output. no further processing is done before ISTFT.
            depth (int): number of layers in the encoder and in the decoder.
            rewrite (bool): add 1x1 convolution to each layer.
            multi_freqs: list of frequency ratios for splitting frequency bands with `MultiWrap`.
            multi_freqs_depth: how many layers to wrap with `MultiWrap`. Only the outermost
                layers will be wrapped.
            freq_emb: add frequency embedding after the first frequency layer if > 0,
                the actual value controls the weight of the embedding.
            emb_scale: equivalent to scaling the embedding learning rate
            emb_smooth: initialize the embedding with a smooth one (with respect to frequencies).
            kernel_size: kernel_size for encoder and decoder layers.
            stride: stride for encoder and decoder layers.
            time_stride: stride for the final time layer, after the merge.
            context: context for 1x1 conv in the decoder.
            context_enc: context for 1x1 conv in the encoder.
            norm_starts: layer at which group norm starts being used.
                decoder layers are numbered in reverse order.
            norm_groups: number of groups for group norm.
            dconv_mode: if 1: dconv in encoder only, 2: decoder only, 3: both.
            dconv_depth: depth of residual DConv branch.
            dconv_comp: compression of DConv branch.
            dconv_attn: adds attention layers in DConv branch starting at this layer.
            dconv_lstm: adds a LSTM layer in DConv branch starting at this layer.
            dconv_init: initial scale for the DConv branch LayerScale.
            bottom_channels: if >0 it adds a linear layer (1x1 Conv) before and after the
                transformer in order to change the number of channels
            t_layers: number of layers in each branch (waveform and spec) of the transformer
            t_emb: "sin", "cape" or "scaled"
            t_hidden_scale: the hidden scale of the Feedforward parts of the transformer
                for instance if C = 384 (the number of channels in the transformer) and
                t_hidden_scale = 4.0 then the intermediate layer of the FFN has dimension
                384 * 4 = 1536
            t_heads: number of heads for the transformer
            t_dropout: dropout in the transformer
            t_max_positions: max_positions for the "scaled" positional embedding, only
                useful if t_emb="scaled"
            t_norm_in: (bool) norm before addinf positional embedding and getting into the
                transformer layers
            t_norm_in_group: (bool) if True while t_norm_in=True, the norm is on all the
                timesteps (GroupNorm with group=1)
            t_group_norm: (bool) if True, the norms of the Encoder Layers are on all the
                timesteps (GroupNorm with group=1)
            t_norm_first: (bool) if True the norm is before the attention and before the FFN
            t_norm_out: (bool) if True, there is a GroupNorm (group=1) at the end of each layer
            t_max_period: (float) denominator in the sinusoidal embedding expression
            t_weight_decay: (float) weight decay for the transformer
            t_lr: (float) specific learning rate for the transformer
            t_layer_scale: (bool) Layer Scale for the transformer
            t_gelu: (bool) activations of the transformer are GeLU if True, ReLU else
            t_weight_pos_embed: (float) weighting of the positional embedding
            t_cape_mean_normalize: (bool) if t_emb="cape", normalisation of positional embeddings
                see: https://arxiv.org/abs/2106.03143
            t_cape_augment: (bool) if t_emb="cape", must be True during training and False
                during the inference, see: https://arxiv.org/abs/2106.03143
            t_cape_glob_loc_scale: (list of 3 floats) if t_emb="cape", CAPE parameters
                see: https://arxiv.org/abs/2106.03143
            t_sparse_self_attn: (bool) if True, the self attentions are sparse
            t_sparse_cross_attn: (bool) if True, the cross-attentions are sparse (don't use it
                unless you designed really specific masks)
            t_mask_type: (str) can be "diag", "jmask", "random", "global" or any combination
                with '_' between: i.e. "diag_jmask_random" (note that this is permutation
                invariant i.e. "diag_jmask_random" is equivalent to "jmask_random_diag")
            t_mask_random_seed: (int) if "random" is in t_mask_type, controls the seed
                that generated the random part of the mask
            t_sparse_attn_window: (int) if "diag" is in t_mask_type, for a query (i), and
                a key (j), the mask is True id |i-j|<=t_sparse_attn_window
            t_global_window: (int) if "global" is in t_mask_type, mask[:t_global_window, :]
                and mask[:, :t_global_window] will be True
            t_sparsity: (float) if "random" is in t_mask_type, t_sparsity is the sparsity
                level of the random part of the mask.
            t_cross_first: (bool) if True cross attention is the first layer of the
                transformer (False seems to be better)
            rescale: weight rescaling trick
            use_train_segment: (bool) if True, the actual size that is used during the
                training is used during inference.
        """
        super().__init__()
        self.cac = cac
        self.wiener_residual = wiener_residual
        self.audio_channels = audio_channels
        self.sources = sources
        self.kernel_size = kernel_size
        self.context = context
        self.stride = stride
        self.depth = depth
        self.bottom_channels = bottom_channels
        self.channels = channels
        self.samplerate = samplerate
        self.segment = segment
        self.use_train_segment = use_train_segment
        self.nfft = nfft
        self.hop_length = nfft // 4
        self.wiener_iters = wiener_iters
        self.end_iters = end_iters
        self.freq_emb = None
        assert wiener_iters == end_iters

        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()

        self.tencoder = nn.ModuleList()
        self.tdecoder = nn.ModuleList()

        chin = audio_channels
        chin_z = chin  # number of channels for the freq branch
        if self.cac:
            chin_z *= 2
        chout = channels_time or channels
        chout_z = channels
        freqs = nfft // 2

        for index in range(depth):
            norm = index >= norm_starts
            freq = freqs > 1
            stri = stride
            ker = kernel_size
            if not freq:
                assert freqs == 1
                ker = time_stride * 2
                stri = time_stride

            pad = True
            last_freq = False
            if freq and freqs <= kernel_size:
                ker = freqs
                pad = False
                last_freq = True

            kw = {
                "kernel_size": ker,
                "stride": stri,
                "freq": freq,
                "pad": pad,
                "norm": norm,
                "rewrite": rewrite,
                "norm_groups": norm_groups,
                "dconv_kw": {
                    "depth": dconv_depth,
                    "compress": dconv_comp,
                    "init": dconv_init,
                    "gelu": True,
                },
            }
            kwt = dict(kw)
            kwt["freq"] = 0
            kwt["kernel_size"] = kernel_size
            kwt["stride"] = stride
            kwt["pad"] = True
            kw_dec = dict(kw)
            multi = False
            if multi_freqs and index < multi_freqs_depth:
                multi = True
                kw_dec["context_freq"] = False

            if last_freq:
                chout_z = max(chout, chout_z)
                chout = chout_z

            enc = HEncLayer(
                chin_z, chout_z, dconv=dconv_mode & 1, context=context_enc, **kw
            )
            if freq:
                tenc = HEncLayer(
                    chin,
                    chout,
                    dconv=dconv_mode & 1,
                    context=context_enc,
                    empty=last_freq,
                    **kwt
                )
                self.tencoder.append(tenc)

            if multi:
                enc = MultiWrap(enc, multi_freqs)
            self.encoder.append(enc)
            if index == 0:
                chin = self.audio_channels * len(self.sources)
                chin_z = chin
                if self.cac:
                    chin_z *= 2
            dec = HDecLayer(
                chout_z,
                chin_z,
                dconv=dconv_mode & 2,
                last=index == 0,
                context=context,
                **kw_dec
            )
            if multi:
                dec = MultiWrap(dec, multi_freqs)
            if freq:
                tdec = HDecLayer(
                    chout,
                    chin,
                    dconv=dconv_mode & 2,
                    empty=last_freq,
                    last=index == 0,
                    context=context,
                    **kwt
                )
                self.tdecoder.insert(0, tdec)
            self.decoder.insert(0, dec)

            chin = chout
            chin_z = chout_z
            chout = int(growth * chout)
            chout_z = int(growth * chout_z)
            if freq:
                if freqs <= kernel_size:
                    freqs = 1
                else:
                    freqs //= stride
            if index == 0 and freq_emb:
                self.freq_emb = ScaledEmbedding(
                    freqs, chin_z, smooth=emb_smooth, scale=emb_scale
                )
                self.freq_emb_scale = freq_emb

        if rescale:
            rescale_module(self, reference=rescale)

        transformer_channels = channels * growth ** (depth - 1)
        if bottom_channels:
            self.channel_upsampler = nn.Conv1d(transformer_channels, bottom_channels, 1)
            self.channel_downsampler = nn.Conv1d(
                bottom_channels, transformer_channels, 1
            )
            self.channel_upsampler_t = nn.Conv1d(
                transformer_channels, bottom_channels, 1
            )
            self.channel_downsampler_t = nn.Conv1d(
                bottom_channels, transformer_channels, 1
            )

            transformer_channels = bottom_channels

        if t_layers > 0:
            self.crosstransformer = CrossTransformerEncoder(
                dim=transformer_channels,
                emb=t_emb,
                hidden_scale=t_hidden_scale,
                num_heads=t_heads,
                num_layers=t_layers,
                cross_first=t_cross_first,
                dropout=t_dropout,
                max_positions=t_max_positions,
                norm_in=t_norm_in,
                norm_in_group=t_norm_in_group,
                group_norm=t_group_norm,
                norm_first=t_norm_first,
                norm_out=t_norm_out,
                max_period=t_max_period,
                weight_decay=t_weight_decay,
                lr=t_lr,
                layer_scale=t_layer_scale,
                gelu=t_gelu,
                sin_random_shift=t_sin_random_shift,
                weight_pos_embed=t_weight_pos_embed,
                cape_mean_normalize=t_cape_mean_normalize,
                cape_augment=t_cape_augment,
                cape_glob_loc_scale=t_cape_glob_loc_scale,
                sparse_self_attn=t_sparse_self_attn,
                sparse_cross_attn=t_sparse_cross_attn,
                mask_type=t_mask_type,
                mask_random_seed=t_mask_random_seed,
                sparse_attn_window=t_sparse_attn_window,
                global_window=t_global_window,
                sparsity=t_sparsity,
                auto_sparsity=t_auto_sparsity,
            )
        else:
            self.crosstransformer = None

    def _spec(self, x):
        hl = self.hop_length
        nfft = self.nfft
        x0 = x  # noqa

        # We re-pad the signal in order to keep the property
        # that the size of the output is exactly the size of the input
        # divided by the stride (here hop_length), when divisible.
        # This is achieved by padding by 1/4th of the kernel size (here nfft).
        # which is not supported by torch.stft.
        # Having all convolution operations follow this convention allow to easily
        # align the time and frequency branches later on.
        assert hl == nfft // 4
        le = int(math.ceil(x.shape[-1] / hl))
        pad = hl // 2 * 3
        x = pad1d(x, (pad, pad + le * hl - x.shape[-1]), mode="reflect")

        z = spectro(x, nfft, hl)[..., :-1, :]
        assert z.shape[-1] == le + 4, (z.shape, x.shape, le)
        z = z[..., 2: 2 + le]
        return z

    def _ispec(self, z, length=None, scale=0):
        hl = self.hop_length // (4**scale)
        z = F.pad(z, (0, 0, 0, 1))
        z = F.pad(z, (2, 2))
        pad = hl // 2 * 3
        le = hl * int(math.ceil(length / hl)) + 2 * pad
        x = ispectro(z, hl, length=le)
        x = x[..., pad: pad + length]
        return x

    def _magnitude(self, z):
        # return the magnitude of the spectrogram, except when cac is True,
        # in which case we just move the complex dimension to the channel one.
        if self.cac:
            B, C, Fr, T = z.shape
            m = torch.view_as_real(z).permute(0, 1, 4, 2, 3)
            m = m.reshape(B, C * 2, Fr, T)
        else:
            m = z.abs()
        return m

    def _mask(self, z, m):
        # Apply masking given the mixture spectrogram `z` and the estimated mask `m`.
        # If `cac` is True, `m` is actually a full spectrogram and `z` is ignored.
        niters = self.wiener_iters
        if self.cac:
            B, S, C, Fr, T = m.shape
            out = m.view(B, S, -1, 2, Fr, T).permute(0, 1, 2, 4, 5, 3)
            out = torch.view_as_complex(out.contiguous())
            return out
        if self.training:
            niters = self.end_iters
        if niters < 0:
            z = z[:, None]
            return z / (1e-8 + z.abs()) * m
        else:
            return self._wiener(m, z, niters)

    def _wiener(self, mag_out, mix_stft, niters):
        # apply wiener filtering from OpenUnmix.
        init = mix_stft.dtype
        wiener_win_len = 300
        residual = self.wiener_residual

        B, S, C, Fq, T = mag_out.shape
        mag_out = mag_out.permute(0, 4, 3, 2, 1)
        mix_stft = torch.view_as_real(mix_stft.permute(0, 3, 2, 1))

        outs = []
        for sample in range(B):
            pos = 0
            out = []
            for pos in range(0, T, wiener_win_len):
                frame = slice(pos, pos + wiener_win_len)
                z_out = wiener(
                    mag_out[sample, frame],
                    mix_stft[sample, frame],
                    niters,
                    residual=residual,
                )
                out.append(z_out.transpose(-1, -2))
            outs.append(torch.cat(out, dim=0))
        out = torch.view_as_complex(torch.stack(outs, 0))
        out = out.permute(0, 4, 3, 2, 1).contiguous()
        if residual:
            out = out[:, :-1]
        assert list(out.shape) == [B, S, C, Fq, T]
        return out.to(init)

    def valid_length(self, length: int):
        """
        Return a length that is appropriate for evaluation.
        In our case, always return the training length, unless
        it is smaller than the given length, in which case this
        raises an error.
        """
        if not self.use_train_segment:
            return length
        training_length = int(self.segment * self.samplerate)
        if training_length < length:
            raise ValueError(
                f"Given length {length} is longer than "
                f"training length {training_length}")
        return training_length

    def forward(self, mix):
        length = mix.shape[-1]
        length_pre_pad = None
        if self.use_train_segment:
            if self.training:
                self.segment = Fraction(mix.shape[-1], self.samplerate)
            else:
                training_length = int(self.segment * self.samplerate)
                if mix.shape[-1] < training_length:
                    length_pre_pad = mix.shape[-1]
                    mix = F.pad(mix, (0, training_length - length_pre_pad))
        z = self._spec(mix)
        mag = self._magnitude(z).to(mix.device)
        x = mag

        B, C, Fq, T = x.shape

        # unlike previous Demucs, we always normalize because it is easier.
        mean = x.mean(dim=(1, 2, 3), keepdim=True)
        std = x.std(dim=(1, 2, 3), keepdim=True)
        x = (x - mean) / (1e-5 + std)
        # x will be the freq. branch input.

        # Prepare the time branch input.
        xt = mix
        meant = xt.mean(dim=(1, 2), keepdim=True)
        stdt = xt.std(dim=(1, 2), keepdim=True)
        xt = (xt - meant) / (1e-5 + stdt)

        # okay, this is a giant mess I know...
        saved = []  # skip connections, freq.
        saved_t = []  # skip connections, time.
        lengths = []  # saved lengths to properly remove padding, freq branch.
        lengths_t = []  # saved lengths for time branch.
        for idx, encode in enumerate(self.encoder):
            lengths.append(x.shape[-1])
            inject = None
            if idx < len(self.tencoder):
                # we have not yet merged branches.
                lengths_t.append(xt.shape[-1])
                tenc = self.tencoder[idx]
                xt = tenc(xt)
                if not tenc.empty:
                    # save for skip connection
                    saved_t.append(xt)
                else:
                    # tenc contains just the first conv., so that now time and freq.
                    # branches have the same shape and can be merged.
                    inject = xt
            x = encode(x, inject)
            if idx == 0 and self.freq_emb is not None:
                # add frequency embedding to allow for non equivariant convolutions
                # over the frequency axis.
                frs = torch.arange(x.shape[-2], device=x.device)
                emb = self.freq_emb(frs).t()[None, :, :, None].expand_as(x)
                x = x + self.freq_emb_scale * emb

            saved.append(x)
        if self.crosstransformer:
            if self.bottom_channels:
                b, c, f, t = x.shape
                x = rearrange(x, "b c f t-> b c (f t)")
                x = self.channel_upsampler(x)
                x = rearrange(x, "b c (f t)-> b c f t", f=f)
                xt = self.channel_upsampler_t(xt)

            x, xt = self.crosstransformer(x, xt)

            if self.bottom_channels:
                x = rearrange(x, "b c f t-> b c (f t)")
                x = self.channel_downsampler(x)
                x = rearrange(x, "b c (f t)-> b c f t", f=f)
                xt = self.channel_downsampler_t(xt)

        for idx, decode in enumerate(self.decoder):
            skip = saved.pop(-1)
            x, pre = decode(x, skip, lengths.pop(-1))
            # `pre` contains the output just before final transposed convolution,
            # which is used when the freq. and time branch separate.

            offset = self.depth - len(self.tdecoder)
            if idx >= offset:
                tdec = self.tdecoder[idx - offset]
                length_t = lengths_t.pop(-1)
                if tdec.empty:
                    assert pre.shape[2] == 1, pre.shape
                    pre = pre[:, :, 0]
                    xt, _ = tdec(pre, None, length_t)
                else:
                    skip = saved_t.pop(-1)
                    xt, _ = tdec(xt, skip, length_t)

        # Let's make sure we used all stored skip connections.
        assert len(saved) == 0
        assert len(lengths_t) == 0
        assert len(saved_t) == 0

        S = len(self.sources)
        x = x.view(B, S, -1, Fq, T)
        x = x * std[:, None] + mean[:, None]

        # to cpu as non-cuda GPUs don't support complex numbers
        # demucs issue #435 ##432
        # NOTE: in this case z already is on cpu
        # TODO: remove this when mps supports complex numbers

        device_type = x.device.type
        device_load = f"{device_type}:{x.device.index}" if not device_type == 'mps' else device_type
        x_is_other_gpu = not device_type in ["cuda", "cpu"]

        if x_is_other_gpu:
            x = x.cpu()

        zout = self._mask(z, x)
        if self.use_train_segment:
            if self.training:
                x = self._ispec(zout, length)
            else:
                x = self._ispec(zout, training_length)
        else:
            x = self._ispec(zout, length)

        # back to other device
        if x_is_other_gpu:
            x = x.to(device_load)

        if self.use_train_segment:
            if self.training:
                xt = xt.view(B, S, -1, length)
            else:
                xt = xt.view(B, S, -1, training_length)
        else:
            xt = xt.view(B, S, -1, length)
        xt = xt * stdt[:, None] + meant[:, None]
        x = xt + x
        if length_pre_pad:
            x = x[..., :length_pre_pad]
        return x
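A hedged sketch of how `use_train_segment` behaves at inference time in the `forward` above; the segment value below is an arbitrary example, not one stored in a released checkpoint. When the input excerpt is shorter than the stored training segment, the mixture is zero-padded up to `training_length` and the output is trimmed back afterwards.

from fractions import Fraction

samplerate = 44100
segment = Fraction(39, 5)                 # hypothetical training segment (~7.8 s)
training_length = int(segment * samplerate)

length = 3 * samplerate                   # a 3 s excerpt passed at inference time
if length < training_length:
    pad = training_length - length        # forward() pads by this amount,
                                          # then keeps only the first `length` samples
print(training_length, pad)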
demucs/model.py
ADDED
@@ -0,0 +1,218 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import math

import torch as th
from torch import nn

from .utils import capture_init, center_trim


class BLSTM(nn.Module):
    def __init__(self, dim, layers=1):
        super().__init__()
        self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim)
        self.linear = nn.Linear(2 * dim, dim)

    def forward(self, x):
        x = x.permute(2, 0, 1)
        x = self.lstm(x)[0]
        x = self.linear(x)
        x = x.permute(1, 2, 0)
        return x


def rescale_conv(conv, reference):
    std = conv.weight.std().detach()
    scale = (std / reference)**0.5
    conv.weight.data /= scale
    if conv.bias is not None:
        conv.bias.data /= scale


def rescale_module(module, reference):
    for sub in module.modules():
        if isinstance(sub, (nn.Conv1d, nn.ConvTranspose1d)):
            rescale_conv(sub, reference)


def upsample(x, stride):
    """
    Linear upsampling, the output will be `stride` times longer.
    """
    batch, channels, time = x.size()
    weight = th.arange(stride, device=x.device, dtype=th.float) / stride
    x = x.view(batch, channels, time, 1)
    out = x[..., :-1, :] * (1 - weight) + x[..., 1:, :] * weight
    return out.reshape(batch, channels, -1)


def downsample(x, stride):
    """
    Downsample x by decimation.
    """
    return x[:, :, ::stride]


class Demucs(nn.Module):
    @capture_init
    def __init__(self,
                 sources=4,
                 audio_channels=2,
                 channels=64,
                 depth=6,
                 rewrite=True,
                 glu=True,
                 upsample=False,
                 rescale=0.1,
                 kernel_size=8,
                 stride=4,
                 growth=2.,
                 lstm_layers=2,
                 context=3,
                 samplerate=44100):
        """
        Args:
            sources (int): number of sources to separate
            audio_channels (int): stereo or mono
            channels (int): first convolution channels
            depth (int): number of encoder/decoder layers
            rewrite (bool): add 1x1 convolution to each encoder layer
                and a convolution to each decoder layer.
                For the decoder layer, `context` gives the kernel size.
            glu (bool): use glu instead of ReLU
            upsample (bool): use linear upsampling with convolutions
                Wave-U-Net style, instead of transposed convolutions
            rescale (int): rescale initial weights of convolutions
                to get their standard deviation closer to `rescale`
            kernel_size (int): kernel size for convolutions
            stride (int): stride for convolutions
            growth (float): multiply (resp divide) number of channels by that
                for each layer of the encoder (resp decoder)
            lstm_layers (int): number of lstm layers, 0 = no lstm
            context (int): kernel size of the convolution in the
                decoder before the transposed convolution. If > 1,
                will provide some context from neighboring time
                steps.
        """

        super().__init__()
        self.audio_channels = audio_channels
        self.sources = sources
        self.kernel_size = kernel_size
        self.context = context
        self.stride = stride
        self.depth = depth
        self.upsample = upsample
        self.channels = channels
        self.samplerate = samplerate

        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()

        self.final = None
        if upsample:
            self.final = nn.Conv1d(channels + audio_channels, sources * audio_channels, 1)
            stride = 1

        if glu:
            activation = nn.GLU(dim=1)
            ch_scale = 2
        else:
            activation = nn.ReLU()
            ch_scale = 1
        in_channels = audio_channels
        for index in range(depth):
            encode = []
            encode += [nn.Conv1d(in_channels, channels, kernel_size, stride), nn.ReLU()]
            if rewrite:
                encode += [nn.Conv1d(channels, ch_scale * channels, 1), activation]
            self.encoder.append(nn.Sequential(*encode))

            decode = []
            if index > 0:
                out_channels = in_channels
            else:
                if upsample:
                    out_channels = channels
                else:
                    out_channels = sources * audio_channels
            if rewrite:
                decode += [nn.Conv1d(channels, ch_scale * channels, context), activation]
            if upsample:
                decode += [
                    nn.Conv1d(channels, out_channels, kernel_size, stride=1),
                ]
            else:
                decode += [nn.ConvTranspose1d(channels, out_channels, kernel_size, stride)]
            if index > 0:
                decode.append(nn.ReLU())
            self.decoder.insert(0, nn.Sequential(*decode))
            in_channels = channels
            channels = int(growth * channels)

        channels = in_channels

        if lstm_layers:
            self.lstm = BLSTM(channels, lstm_layers)
        else:
            self.lstm = None

        if rescale:
            rescale_module(self, reference=rescale)

    def valid_length(self, length):
        """
        Return the nearest valid length to use with the model so that
        there is no time steps left over in a convolutions, e.g. for all
        layers, size of the input - kernel_size % stride = 0.

        If the mixture has a valid length, the estimated sources
        will have exactly the same length when context = 1. If context > 1,
        the two signals can be center trimmed to match.

        For training, extracts should have a valid length.For evaluation
        on full tracks we recommend passing `pad = True` to :method:`forward`.
        """
        for _ in range(self.depth):
            if self.upsample:
                length = math.ceil(length / self.stride) + self.kernel_size - 1
            else:
                length = math.ceil((length - self.kernel_size) / self.stride) + 1
            length = max(1, length)
            length += self.context - 1
        for _ in range(self.depth):
            if self.upsample:
                length = length * self.stride + self.kernel_size - 1
            else:
                length = (length - 1) * self.stride + self.kernel_size

        return int(length)

    def forward(self, mix):
        x = mix
        saved = [x]
        for encode in self.encoder:
            x = encode(x)
            saved.append(x)
            if self.upsample:
                x = downsample(x, self.stride)
        if self.lstm:
            x = self.lstm(x)
        for decode in self.decoder:
            if self.upsample:
                x = upsample(x, stride=self.stride)
            skip = center_trim(saved.pop(-1), x)
            x = x + skip
            x = decode(x)
        if self.final:
            skip = center_trim(saved.pop(-1), x)
            x = th.cat([x, skip], dim=1)
            x = self.final(x)

        x = x.view(x.size(0), self.sources, self.audio_channels, x.size(-1))
        return x
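A minimal, hedged sketch of using `valid_length` with this time-domain Demucs; the instantiation and tensor shapes are illustrative assumptions (random weights, default hyperparameters), not taken from the repository's inference code. The idea is to pad the mixture to the nearest valid length before `forward`, then center-trim the output against the original if needed.

import torch
import torch.nn.functional as F

model = Demucs(sources=4, audio_channels=2)   # hypothetical, untrained instance

mix = torch.randn(1, 2, 44100)                # (batch, channels, time), 1 s of noise
target = model.valid_length(mix.shape[-1])    # nearest length with no leftover steps
delta = target - mix.shape[-1]
mix_padded = F.pad(mix, (delta // 2, delta - delta // 2))

with torch.no_grad():
    out = model(mix_padded)                   # (batch, sources, channels, time')
# with context > 1 the output is shorter than the padded input and is
# typically center-trimmed against the reference signal.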
demucs/model_v2.py
ADDED
@@ -0,0 +1,218 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import math

import julius
from torch import nn
from .tasnet_v2 import ConvTasNet

from .utils import capture_init, center_trim


class BLSTM(nn.Module):
    def __init__(self, dim, layers=1):
        super().__init__()
        self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim)
        self.linear = nn.Linear(2 * dim, dim)

    def forward(self, x):
        x = x.permute(2, 0, 1)
        x = self.lstm(x)[0]
        x = self.linear(x)
        x = x.permute(1, 2, 0)
        return x


def rescale_conv(conv, reference):
    std = conv.weight.std().detach()
    scale = (std / reference)**0.5
    conv.weight.data /= scale
    if conv.bias is not None:
        conv.bias.data /= scale


def rescale_module(module, reference):
    for sub in module.modules():
        if isinstance(sub, (nn.Conv1d, nn.ConvTranspose1d)):
            rescale_conv(sub, reference)

def auto_load_demucs_model_v2(sources, demucs_model_name):

    if '48' in demucs_model_name:
        channels=48
    elif 'unittest' in demucs_model_name:
        channels=4
    else:
        channels=64

    if 'tasnet' in demucs_model_name:
        init_demucs_model = ConvTasNet(sources, X=10)
    else:
        init_demucs_model = Demucs(sources, channels=channels)

    return init_demucs_model

class Demucs(nn.Module):
    @capture_init
    def __init__(self,
                 sources,
                 audio_channels=2,
                 channels=64,
                 depth=6,
                 rewrite=True,
                 glu=True,
                 rescale=0.1,
                 resample=True,
                 kernel_size=8,
                 stride=4,
                 growth=2.,
                 lstm_layers=2,
                 context=3,
                 normalize=False,
                 samplerate=44100,
                 segment_length=4 * 10 * 44100):
        """
        Args:
            sources (list[str]): list of source names
            audio_channels (int): stereo or mono
            channels (int): first convolution channels
            depth (int): number of encoder/decoder layers
            rewrite (bool): add 1x1 convolution to each encoder layer
                and a convolution to each decoder layer.
                For the decoder layer, `context` gives the kernel size.
            glu (bool): use glu instead of ReLU
            resample_input (bool): upsample x2 the input and downsample /2 the output.
            rescale (int): rescale initial weights of convolutions
                to get their standard deviation closer to `rescale`
            kernel_size (int): kernel size for convolutions
            stride (int): stride for convolutions
            growth (float): multiply (resp divide) number of channels by that
                for each layer of the encoder (resp decoder)
            lstm_layers (int): number of lstm layers, 0 = no lstm
            context (int): kernel size of the convolution in the
                decoder before the transposed convolution. If > 1,
                will provide some context from neighboring time
                steps.
            samplerate (int): stored as meta information for easing
                future evaluations of the model.
            segment_length (int): stored as meta information for easing
                future evaluations of the model. Length of the segments on which
                the model was trained.
        """

        super().__init__()
        self.audio_channels = audio_channels
        self.sources = sources
        self.kernel_size = kernel_size
        self.context = context
        self.stride = stride
        self.depth = depth
        self.resample = resample
        self.channels = channels
        self.normalize = normalize
        self.samplerate = samplerate
        self.segment_length = segment_length

        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()

        if glu:
            activation = nn.GLU(dim=1)
            ch_scale = 2
        else:
            activation = nn.ReLU()
            ch_scale = 1
        in_channels = audio_channels
        for index in range(depth):
            encode = []
            encode += [nn.Conv1d(in_channels, channels, kernel_size, stride), nn.ReLU()]
            if rewrite:
                encode += [nn.Conv1d(channels, ch_scale * channels, 1), activation]
            self.encoder.append(nn.Sequential(*encode))

            decode = []
            if index > 0:
                out_channels = in_channels
            else:
                out_channels = len(self.sources) * audio_channels
            if rewrite:
                decode += [nn.Conv1d(channels, ch_scale * channels, context), activation]
            decode += [nn.ConvTranspose1d(channels, out_channels, kernel_size, stride)]
            if index > 0:
                decode.append(nn.ReLU())
            self.decoder.insert(0, nn.Sequential(*decode))
            in_channels = channels
            channels = int(growth * channels)

        channels = in_channels

        if lstm_layers:
            self.lstm = BLSTM(channels, lstm_layers)
        else:
            self.lstm = None

        if rescale:
            rescale_module(self, reference=rescale)

    def valid_length(self, length):
        """
        Return the nearest valid length to use with the model so that
        there is no time steps left over in a convolutions, e.g. for all
        layers, size of the input - kernel_size % stride = 0.

        If the mixture has a valid length, the estimated sources
        will have exactly the same length when context = 1. If context > 1,
        the two signals can be center trimmed to match.

        For training, extracts should have a valid length.For evaluation
        on full tracks we recommend passing `pad = True` to :method:`forward`.
        """
        if self.resample:
            length *= 2
        for _ in range(self.depth):
            length = math.ceil((length - self.kernel_size) / self.stride) + 1
            length = max(1, length)
            length += self.context - 1
        for _ in range(self.depth):
            length = (length - 1) * self.stride + self.kernel_size

        if self.resample:
            length = math.ceil(length / 2)
        return int(length)

    def forward(self, mix):
        x = mix

        if self.normalize:
            mono = mix.mean(dim=1, keepdim=True)
            mean = mono.mean(dim=-1, keepdim=True)
            std = mono.std(dim=-1, keepdim=True)
        else:
            mean = 0
            std = 1

        x = (x - mean) / (1e-5 + std)

        if self.resample:
            x = julius.resample_frac(x, 1, 2)

        saved = []
        for encode in self.encoder:
            x = encode(x)
            saved.append(x)
        if self.lstm:
            x = self.lstm(x)
        for decode in self.decoder:
            skip = center_trim(saved.pop(-1), x)
            x = x + skip
            x = decode(x)

        if self.resample:
            x = julius.resample_frac(x, 2, 1)
        x = x * std + mean
        x = x.view(x.size(0), len(self.sources), self.audio_channels, x.size(-1))
        return x
demucs/pretrained.py
ADDED
@@ -0,0 +1,180 @@
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
"""Loading pretrained models.
|
7 |
+
"""
|
8 |
+
|
9 |
+
import logging
|
10 |
+
from pathlib import Path
|
11 |
+
import typing as tp
|
12 |
+
|
13 |
+
#from dora.log import fatal
|
14 |
+
|
15 |
+
import logging
|
16 |
+
|
17 |
+
from diffq import DiffQuantizer
|
18 |
+
import torch.hub
|
19 |
+
|
20 |
+
from .model import Demucs
|
21 |
+
from .tasnet_v2 import ConvTasNet
|
22 |
+
from .utils import set_state
|
23 |
+
|
24 |
+
from .hdemucs import HDemucs
|
25 |
+
from .repo import RemoteRepo, LocalRepo, ModelOnlyRepo, BagOnlyRepo, AnyModelRepo, ModelLoadingError # noqa
|
26 |
+
|
27 |
+
logger = logging.getLogger(__name__)
|
28 |
+
ROOT_URL = "https://dl.fbaipublicfiles.com/demucs/mdx_final/"
|
29 |
+
REMOTE_ROOT = Path(__file__).parent / 'remote'
|
30 |
+
|
31 |
+
SOURCES = ["drums", "bass", "other", "vocals"]
|
32 |
+
|
33 |
+
|
34 |
+
def demucs_unittest():
|
35 |
+
model = HDemucs(channels=4, sources=SOURCES)
|
36 |
+
return model
|
37 |
+
|
38 |
+
|
39 |
+
def add_model_flags(parser):
|
40 |
+
group = parser.add_mutually_exclusive_group(required=False)
|
41 |
+
group.add_argument("-s", "--sig", help="Locally trained XP signature.")
|
42 |
+
group.add_argument("-n", "--name", default="mdx_extra_q",
|
43 |
+
help="Pretrained model name or signature. Default is mdx_extra_q.")
|
44 |
+
parser.add_argument("--repo", type=Path,
|
45 |
+
help="Folder containing all pre-trained models for use with -n.")
|
46 |
+
|
47 |
+
|
48 |
+
def _parse_remote_files(remote_file_list) -> tp.Dict[str, str]:
|
49 |
+
root: str = ''
|
50 |
+
models: tp.Dict[str, str] = {}
|
51 |
+
for line in remote_file_list.read_text().split('\n'):
|
52 |
+
line = line.strip()
|
53 |
+
if line.startswith('#'):
|
54 |
+
continue
|
55 |
+
elif line.startswith('root:'):
|
56 |
+
root = line.split(':', 1)[1].strip()
|
57 |
+
else:
|
58 |
+
sig = line.split('-', 1)[0]
|
59 |
+
assert sig not in models
|
60 |
+
models[sig] = ROOT_URL + root + line
|
61 |
+
return models
|
62 |
+
|
63 |
+
def get_model(name: str,
|
64 |
+
repo: tp.Optional[Path] = None):
|
65 |
+
"""`name` must be a bag of models name or a pretrained signature
|
66 |
+
from the remote AWS model repo or the specified local repo if `repo` is not None.
|
67 |
+
"""
|
68 |
+
if name == 'demucs_unittest':
|
69 |
+
return demucs_unittest()
|
70 |
+
model_repo: ModelOnlyRepo
|
71 |
+
if repo is None:
|
72 |
+
models = _parse_remote_files(REMOTE_ROOT / 'files.txt')
|
73 |
+
model_repo = RemoteRepo(models)
|
74 |
+
bag_repo = BagOnlyRepo(REMOTE_ROOT, model_repo)
|
75 |
+
else:
|
76 |
+
if not repo.is_dir():
|
77 |
+
fatal(f"{repo} must exist and be a directory.")
|
78 |
+
model_repo = LocalRepo(repo)
|
79 |
+
bag_repo = BagOnlyRepo(repo, model_repo)
|
80 |
+
any_repo = AnyModelRepo(model_repo, bag_repo)
|
81 |
+
model = any_repo.get_model(name)
|
82 |
+
model.eval()
|
83 |
+
return model
|
84 |
+
|
85 |
+
def get_model_from_args(args):
|
86 |
+
"""
|
87 |
+
Load local model package or pre-trained model.
|
88 |
+
"""
|
89 |
+
return get_model(name=args.name, repo=args.repo)
|
90 |
+
|
91 |
+
logger = logging.getLogger(__name__)
|
92 |
+
ROOT = "https://dl.fbaipublicfiles.com/demucs/v3.0/"
|
93 |
+
|
94 |
+
PRETRAINED_MODELS = {
|
95 |
+
'demucs': 'e07c671f',
|
96 |
+
'demucs48_hq': '28a1282c',
|
97 |
+
'demucs_extra': '3646af93',
|
98 |
+
'demucs_quantized': '07afea75',
|
99 |
+
'tasnet': 'beb46fac',
|
100 |
+
'tasnet_extra': 'df3777b2',
|
101 |
+
'demucs_unittest': '09ebc15f',
|
102 |
+
}
|
103 |
+
|
104 |
+
SOURCES = ["drums", "bass", "other", "vocals"]
|
105 |
+
|
106 |
+
|
107 |
+
def get_url(name):
|
108 |
+
sig = PRETRAINED_MODELS[name]
|
109 |
+
return ROOT + name + "-" + sig[:8] + ".th"
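# For example, get_url('demucs') returns
# "https://dl.fbaipublicfiles.com/demucs/v3.0/demucs-e07c671f.th"
# (ROOT + name + '-' + the first 8 characters of the signature + '.th').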
|
110 |
+
|
111 |
+
def is_pretrained(name):
|
112 |
+
return name in PRETRAINED_MODELS
|
113 |
+
|
114 |
+
|
115 |
+
def load_pretrained(name):
|
116 |
+
if name == "demucs":
|
117 |
+
return demucs(pretrained=True)
|
118 |
+
elif name == "demucs48_hq":
|
119 |
+
return demucs(pretrained=True, hq=True, channels=48)
|
120 |
+
elif name == "demucs_extra":
|
121 |
+
return demucs(pretrained=True, extra=True)
|
122 |
+
elif name == "demucs_quantized":
|
123 |
+
return demucs(pretrained=True, quantized=True)
|
124 |
+
elif name == "demucs_unittest":
|
125 |
+
return demucs_unittest(pretrained=True)
|
126 |
+
elif name == "tasnet":
|
127 |
+
return tasnet(pretrained=True)
|
128 |
+
elif name == "tasnet_extra":
|
129 |
+
return tasnet(pretrained=True, extra=True)
|
130 |
+
else:
|
131 |
+
raise ValueError(f"Invalid pretrained name {name}")
|
132 |
+
|
133 |
+
|
134 |
+
def _load_state(name, model, quantizer=None):
|
135 |
+
url = get_url(name)
|
136 |
+
state = torch.hub.load_state_dict_from_url(url, map_location='cpu', check_hash=True)
|
137 |
+
set_state(model, quantizer, state)
|
138 |
+
if quantizer:
|
139 |
+
quantizer.detach()
|
140 |
+
|
141 |
+
|
142 |
+
def demucs_unittest(pretrained=True):
|
143 |
+
model = Demucs(channels=4, sources=SOURCES)
|
144 |
+
if pretrained:
|
145 |
+
_load_state('demucs_unittest', model)
|
146 |
+
return model
|
147 |
+
|
148 |
+
|
149 |
+
def demucs(pretrained=True, extra=False, quantized=False, hq=False, channels=64):
|
150 |
+
if not pretrained and (extra or quantized or hq):
|
151 |
+
raise ValueError("if extra or quantized is True, pretrained must be True.")
|
152 |
+
model = Demucs(sources=SOURCES, channels=channels)
|
153 |
+
if pretrained:
|
154 |
+
name = 'demucs'
|
155 |
+
if channels != 64:
|
156 |
+
name += str(channels)
|
157 |
+
quantizer = None
|
158 |
+
if sum([extra, quantized, hq]) > 1:
|
159 |
+
raise ValueError("Only one of extra, quantized, hq, can be True.")
|
160 |
+
if quantized:
|
161 |
+
quantizer = DiffQuantizer(model, group_size=8, min_size=1)
|
162 |
+
name += '_quantized'
|
163 |
+
if extra:
|
164 |
+
name += '_extra'
|
165 |
+
if hq:
|
166 |
+
name += '_hq'
|
167 |
+
_load_state(name, model, quantizer)
|
168 |
+
return model
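# The name assembled above lines up with the PRETRAINED_MODELS keys below, e.g.
# demucs(pretrained=True, hq=True, channels=48) loads the checkpoint named 'demucs48_hq'.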
|
169 |
+
|
170 |
+
|
171 |
+
def tasnet(pretrained=True, extra=False):
|
172 |
+
if not pretrained and extra:
|
173 |
+
raise ValueError("if extra is True, pretrained must be True.")
|
174 |
+
model = ConvTasNet(X=10, sources=SOURCES)
|
175 |
+
if pretrained:
|
176 |
+
name = 'tasnet'
|
177 |
+
if extra:
|
178 |
+
name = 'tasnet_extra'
|
179 |
+
_load_state(name, model)
|
180 |
+
return model
|
demucs/repo.py
ADDED
@@ -0,0 +1,148 @@
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
"""Represents a model repository, including pre-trained models and bags of models.
|
7 |
+
A repo can either be the main remote repository stored in AWS, or a local repository
|
8 |
+
with your own models.
|
9 |
+
"""
|
10 |
+
|
11 |
+
from hashlib import sha256
|
12 |
+
from pathlib import Path
|
13 |
+
import typing as tp
|
14 |
+
|
15 |
+
import torch
|
16 |
+
import yaml
|
17 |
+
|
18 |
+
from .apply import BagOfModels, Model
|
19 |
+
from .states import load_model
|
20 |
+
|
21 |
+
|
22 |
+
AnyModel = tp.Union[Model, BagOfModels]
|
23 |
+
|
24 |
+
|
25 |
+
class ModelLoadingError(RuntimeError):
|
26 |
+
pass
|
27 |
+
|
28 |
+
|
29 |
+
def check_checksum(path: Path, checksum: str):
|
30 |
+
sha = sha256()
|
31 |
+
with open(path, 'rb') as file:
|
32 |
+
while True:
|
33 |
+
buf = file.read(2**20)
|
34 |
+
if not buf:
|
35 |
+
break
|
36 |
+
sha.update(buf)
|
37 |
+
actual_checksum = sha.hexdigest()[:len(checksum)]
|
38 |
+
if actual_checksum != checksum:
|
39 |
+
raise ModelLoadingError(f'Invalid checksum for file {path}, '
|
40 |
+
f'expected {checksum} but got {actual_checksum}')
|
41 |
+
|
42 |
+
class ModelOnlyRepo:
|
43 |
+
"""Base class for all model only repos.
|
44 |
+
"""
|
45 |
+
def has_model(self, sig: str) -> bool:
|
46 |
+
raise NotImplementedError()
|
47 |
+
|
48 |
+
def get_model(self, sig: str) -> Model:
|
49 |
+
raise NotImplementedError()
|
50 |
+
|
51 |
+
|
52 |
+
class RemoteRepo(ModelOnlyRepo):
|
53 |
+
def __init__(self, models: tp.Dict[str, str]):
|
54 |
+
self._models = models
|
55 |
+
|
56 |
+
def has_model(self, sig: str) -> bool:
|
57 |
+
return sig in self._models
|
58 |
+
|
59 |
+
def get_model(self, sig: str) -> Model:
|
60 |
+
try:
|
61 |
+
url = self._models[sig]
|
62 |
+
except KeyError:
|
63 |
+
raise ModelLoadingError(f'Could not find a pre-trained model with signature {sig}.')
|
64 |
+
pkg = torch.hub.load_state_dict_from_url(url, map_location='cpu', check_hash=True)
|
65 |
+
return load_model(pkg)
|
66 |
+
|
67 |
+
|
68 |
+
class LocalRepo(ModelOnlyRepo):
|
69 |
+
def __init__(self, root: Path):
|
70 |
+
self.root = root
|
71 |
+
self.scan()
|
72 |
+
|
73 |
+
def scan(self):
|
74 |
+
self._models = {}
|
75 |
+
self._checksums = {}
|
76 |
+
for file in self.root.iterdir():
|
77 |
+
if file.suffix == '.th':
|
78 |
+
if '-' in file.stem:
|
79 |
+
xp_sig, checksum = file.stem.split('-')
|
80 |
+
self._checksums[xp_sig] = checksum
|
81 |
+
else:
|
82 |
+
xp_sig = file.stem
|
83 |
+
if xp_sig in self._models:
|
84 |
+
|
85 |
+
raise ModelLoadingError(
|
86 |
+
f'Duplicate pre-trained model exist for signature {xp_sig}. '
|
87 |
+
'Please delete all but one.')
|
88 |
+
self._models[xp_sig] = file
|
89 |
+
|
90 |
+
def has_model(self, sig: str) -> bool:
|
91 |
+
return sig in self._models
|
92 |
+
|
93 |
+
def get_model(self, sig: str) -> Model:
|
94 |
+
try:
|
95 |
+
file = self._models[sig]
|
96 |
+
except KeyError:
|
97 |
+
raise ModelLoadingError(f'Could not find pre-trained model with signature {sig}.')
|
98 |
+
if sig in self._checksums:
|
99 |
+
check_checksum(file, self._checksums[sig])
|
100 |
+
return load_model(file)
|
101 |
+
|
102 |
+
|
103 |
+
class BagOnlyRepo:
|
104 |
+
"""Handles only YAML files containing bag of models, leaving the actual
|
105 |
+
model loading to some Repo.
|
106 |
+
"""
|
107 |
+
def __init__(self, root: Path, model_repo: ModelOnlyRepo):
|
108 |
+
self.root = root
|
109 |
+
self.model_repo = model_repo
|
110 |
+
self.scan()
|
111 |
+
|
112 |
+
def scan(self):
|
113 |
+
self._bags = {}
|
114 |
+
for file in self.root.iterdir():
|
115 |
+
if file.suffix == '.yaml':
|
116 |
+
self._bags[file.stem] = file
|
117 |
+
|
118 |
+
def has_model(self, name: str) -> bool:
|
119 |
+
return name in self._bags
|
120 |
+
|
121 |
+
def get_model(self, name: str) -> BagOfModels:
|
122 |
+
try:
|
123 |
+
yaml_file = self._bags[name]
|
124 |
+
except KeyError:
|
125 |
+
raise ModelLoadingError(f'{name} is neither a single pre-trained model nor '
|
126 |
+
'a bag of models.')
|
127 |
+
bag = yaml.safe_load(yaml_file.read_text())
|
128 |
+
signatures = bag['models']
|
129 |
+
models = [self.model_repo.get_model(sig) for sig in signatures]
|
130 |
+
weights = bag.get('weights')
|
131 |
+
segment = bag.get('segment')
|
132 |
+
return BagOfModels(models, weights, segment)
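# Illustrative sketch of a bag-of-models YAML as read above (signatures and values are
# hypothetical; only the 'models' list is required, 'weights' and 'segment' are optional):
#
#     models: [0d19c1c6, 7ecf8ec1]
#     weights: [[1., 1., 1., 1.], [1., 1., 1., 1.]]   # per model, per source
#     segment: 44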
|
133 |
+
|
134 |
+
|
135 |
+
class AnyModelRepo:
|
136 |
+
def __init__(self, model_repo: ModelOnlyRepo, bag_repo: BagOnlyRepo):
|
137 |
+
self.model_repo = model_repo
|
138 |
+
self.bag_repo = bag_repo
|
139 |
+
|
140 |
+
def has_model(self, name_or_sig: str) -> bool:
|
141 |
+
return self.model_repo.has_model(name_or_sig) or self.bag_repo.has_model(name_or_sig)
|
142 |
+
|
143 |
+
def get_model(self, name_or_sig: str) -> AnyModel:
|
144 |
+
|
145 |
+
if self.model_repo.has_model(name_or_sig):
|
146 |
+
return self.model_repo.get_model(name_or_sig)
|
147 |
+
else:
|
148 |
+
return self.bag_repo.get_model(name_or_sig)
|
demucs/spec.py
ADDED
@@ -0,0 +1,53 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
"""Conveniance wrapper to perform STFT and iSTFT"""
|
7 |
+
|
8 |
+
import torch as th
|
9 |
+
|
10 |
+
|
11 |
+
def spectro(x, n_fft=512, hop_length=None, pad=0):
|
12 |
+
*other, length = x.shape
|
13 |
+
x = x.reshape(-1, length)
|
14 |
+
|
15 |
+
device_type = x.device.type
|
16 |
+
is_other_gpu = device_type not in ["cuda", "cpu"]
|
17 |
+
|
18 |
+
if is_other_gpu:
|
19 |
+
x = x.cpu()
|
20 |
+
z = th.stft(x,
|
21 |
+
n_fft * (1 + pad),
|
22 |
+
hop_length or n_fft // 4,
|
23 |
+
window=th.hann_window(n_fft).to(x),
|
24 |
+
win_length=n_fft,
|
25 |
+
normalized=True,
|
26 |
+
center=True,
|
27 |
+
return_complex=True,
|
28 |
+
pad_mode='reflect')
|
29 |
+
_, freqs, frame = z.shape
|
30 |
+
return z.view(*other, freqs, frame)
|
31 |
+
|
32 |
+
|
33 |
+
def ispectro(z, hop_length=None, length=None, pad=0):
|
34 |
+
*other, freqs, frames = z.shape
|
35 |
+
n_fft = 2 * freqs - 2
|
36 |
+
z = z.view(-1, freqs, frames)
|
37 |
+
win_length = n_fft // (1 + pad)
|
38 |
+
|
39 |
+
device_type = z.device.type
|
40 |
+
is_other_gpu = device_type not in ["cuda", "cpu"]
|
41 |
+
|
42 |
+
if is_other_gpu:
|
43 |
+
z = z.cpu()
|
44 |
+
x = th.istft(z,
|
45 |
+
n_fft,
|
46 |
+
hop_length,
|
47 |
+
window=th.hann_window(win_length).to(z.real),
|
48 |
+
win_length=win_length,
|
49 |
+
normalized=True,
|
50 |
+
length=length,
|
51 |
+
center=True)
|
52 |
+
_, length = x.shape
|
53 |
+
return x.view(*other, length)
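# Hedged round-trip sketch (sizes are arbitrary; shapes follow from the code above):
#
#     x = th.randn(2, 2, 44100)             # (batch, channels, time)
#     z = spectro(x, n_fft=4096)            # -> (2, 2, 2049, frames)
#     y = ispectro(z, length=x.shape[-1])   # -> (2, 2, 44100)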
|
demucs/states.py
ADDED
@@ -0,0 +1,148 @@
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
"""
|
7 |
+
Utilities to save and load models.
|
8 |
+
"""
|
9 |
+
from contextlib import contextmanager
|
10 |
+
|
11 |
+
import functools
|
12 |
+
import hashlib
|
13 |
+
import inspect
|
14 |
+
import io
|
15 |
+
from pathlib import Path
|
16 |
+
import warnings
|
17 |
+
|
18 |
+
from omegaconf import OmegaConf
|
19 |
+
from diffq import DiffQuantizer, UniformQuantizer, restore_quantized_state
|
20 |
+
import torch
|
21 |
+
|
22 |
+
|
23 |
+
def get_quantizer(model, args, optimizer=None):
|
24 |
+
"""Return the quantizer given the XP quantization args."""
|
25 |
+
quantizer = None
|
26 |
+
if args.diffq:
|
27 |
+
quantizer = DiffQuantizer(
|
28 |
+
model, min_size=args.min_size, group_size=args.group_size)
|
29 |
+
if optimizer is not None:
|
30 |
+
quantizer.setup_optimizer(optimizer)
|
31 |
+
elif args.qat:
|
32 |
+
quantizer = UniformQuantizer(
|
33 |
+
model, bits=args.qat, min_size=args.min_size)
|
34 |
+
return quantizer
|
35 |
+
|
36 |
+
|
37 |
+
def load_model(path_or_package, strict=False):
|
38 |
+
"""Load a model from the given serialized model, either given as a dict (already loaded)
|
39 |
+
or a path to a file on disk."""
|
40 |
+
if isinstance(path_or_package, dict):
|
41 |
+
package = path_or_package
|
42 |
+
elif isinstance(path_or_package, (str, Path)):
|
43 |
+
with warnings.catch_warnings():
|
44 |
+
warnings.simplefilter("ignore")
|
45 |
+
path = path_or_package
|
46 |
+
package = torch.load(path, 'cpu')
|
47 |
+
else:
|
48 |
+
raise ValueError(f"Invalid type for {path_or_package}.")
|
49 |
+
|
50 |
+
klass = package["klass"]
|
51 |
+
args = package["args"]
|
52 |
+
kwargs = package["kwargs"]
|
53 |
+
|
54 |
+
if strict:
|
55 |
+
model = klass(*args, **kwargs)
|
56 |
+
else:
|
57 |
+
sig = inspect.signature(klass)
|
58 |
+
for key in list(kwargs):
|
59 |
+
if key not in sig.parameters:
|
60 |
+
warnings.warn("Dropping inexistant parameter " + key)
|
61 |
+
del kwargs[key]
|
62 |
+
model = klass(*args, **kwargs)
|
63 |
+
|
64 |
+
state = package["state"]
|
65 |
+
|
66 |
+
set_state(model, state)
|
67 |
+
return model
|
68 |
+
|
69 |
+
|
70 |
+
def get_state(model, quantizer, half=False):
|
71 |
+
"""Get the state from a model, potentially with quantization applied.
|
72 |
+
If `half` is True, the model is stored in half precision, which shouldn't impact performance
|
73 |
+
but halves the state size."""
|
74 |
+
if quantizer is None:
|
75 |
+
dtype = torch.half if half else None
|
76 |
+
state = {k: p.data.to(device='cpu', dtype=dtype) for k, p in model.state_dict().items()}
|
77 |
+
else:
|
78 |
+
state = quantizer.get_quantized_state()
|
79 |
+
state['__quantized'] = True
|
80 |
+
return state
|
81 |
+
|
82 |
+
|
83 |
+
def set_state(model, state, quantizer=None):
|
84 |
+
"""Set the state on a given model."""
|
85 |
+
if state.get('__quantized'):
|
86 |
+
if quantizer is not None:
|
87 |
+
quantizer.restore_quantized_state(model, state['quantized'])
|
88 |
+
else:
|
89 |
+
restore_quantized_state(model, state)
|
90 |
+
else:
|
91 |
+
model.load_state_dict(state)
|
92 |
+
return state
|
93 |
+
|
94 |
+
|
95 |
+
def save_with_checksum(content, path):
|
96 |
+
"""Save the given value on disk, along with a sha256 hash.
|
97 |
+
Should be used with the output of either `serialize_model` or `get_state`."""
|
98 |
+
buf = io.BytesIO()
|
99 |
+
torch.save(content, buf)
|
100 |
+
sig = hashlib.sha256(buf.getvalue()).hexdigest()[:8]
|
101 |
+
|
102 |
+
path = path.parent / (path.stem + "-" + sig + path.suffix)
|
103 |
+
path.write_bytes(buf.getvalue())
|
104 |
+
|
105 |
+
|
106 |
+
def serialize_model(model, training_args, quantizer=None, half=True):
|
107 |
+
args, kwargs = model._init_args_kwargs
|
108 |
+
klass = model.__class__
|
109 |
+
|
110 |
+
state = get_state(model, quantizer, half)
|
111 |
+
return {
|
112 |
+
'klass': klass,
|
113 |
+
'args': args,
|
114 |
+
'kwargs': kwargs,
|
115 |
+
'state': state,
|
116 |
+
'training_args': OmegaConf.to_container(training_args, resolve=True),
|
117 |
+
}
|
118 |
+
|
119 |
+
|
120 |
+
def copy_state(state):
|
121 |
+
return {k: v.cpu().clone() for k, v in state.items()}
|
122 |
+
|
123 |
+
|
124 |
+
@contextmanager
|
125 |
+
def swap_state(model, state):
|
126 |
+
"""
|
127 |
+
Context manager that swaps the state of a model, e.g:
|
128 |
+
|
129 |
+
# model is in old state
|
130 |
+
with swap_state(model, new_state):
|
131 |
+
# model in new state
|
132 |
+
# model back to old state
|
133 |
+
"""
|
134 |
+
old_state = copy_state(model.state_dict())
|
135 |
+
model.load_state_dict(state, strict=False)
|
136 |
+
try:
|
137 |
+
yield
|
138 |
+
finally:
|
139 |
+
model.load_state_dict(old_state)
|
140 |
+
|
141 |
+
|
142 |
+
def capture_init(init):
|
143 |
+
@functools.wraps(init)
|
144 |
+
def __init__(self, *args, **kwargs):
|
145 |
+
self._init_args_kwargs = (args, kwargs)
|
146 |
+
init(self, *args, **kwargs)
|
147 |
+
|
148 |
+
return __init__
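# Hedged usage sketch: decorating __init__ with capture_init is what lets
# serialize_model() above recover the constructor arguments (the class below is
# illustrative, not from this file):
#
#     class Tiny(torch.nn.Module):
#         @capture_init
#         def __init__(self, channels=64):
#             super().__init__()
#             self.linear = torch.nn.Linear(channels, channels)
#
#     m = Tiny(channels=32)
#     m._init_args_kwargs   # -> ((), {'channels': 32})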
|
demucs/tasnet.py
ADDED
@@ -0,0 +1,447 @@
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
#
|
7 |
+
# Created on 2018/12
|
8 |
+
# Author: Kaituo XU
|
9 |
+
# Modified on 2019/11 by Alexandre Defossez, added support for multiple output channels
|
10 |
+
# Here is the original license:
|
11 |
+
# The MIT License (MIT)
|
12 |
+
#
|
13 |
+
# Copyright (c) 2018 Kaituo XU
|
14 |
+
#
|
15 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
16 |
+
# of this software and associated documentation files (the "Software"), to deal
|
17 |
+
# in the Software without restriction, including without limitation the rights
|
18 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
19 |
+
# copies of the Software, and to permit persons to whom the Software is
|
20 |
+
# furnished to do so, subject to the following conditions:
|
21 |
+
#
|
22 |
+
# The above copyright notice and this permission notice shall be included in all
|
23 |
+
# copies or substantial portions of the Software.
|
24 |
+
#
|
25 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
26 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
27 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
28 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
29 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
30 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
31 |
+
# SOFTWARE.
|
32 |
+
|
33 |
+
import math
|
34 |
+
|
35 |
+
import torch
|
36 |
+
import torch.nn as nn
|
37 |
+
import torch.nn.functional as F
|
38 |
+
|
39 |
+
from .utils import capture_init
|
40 |
+
|
41 |
+
EPS = 1e-8
|
42 |
+
|
43 |
+
|
44 |
+
def overlap_and_add(signal, frame_step):
|
45 |
+
outer_dimensions = signal.size()[:-2]
|
46 |
+
frames, frame_length = signal.size()[-2:]
|
47 |
+
|
48 |
+
subframe_length = math.gcd(frame_length, frame_step) # gcd=Greatest Common Divisor
|
49 |
+
subframe_step = frame_step // subframe_length
|
50 |
+
subframes_per_frame = frame_length // subframe_length
|
51 |
+
output_size = frame_step * (frames - 1) + frame_length
|
52 |
+
output_subframes = output_size // subframe_length
|
53 |
+
|
54 |
+
subframe_signal = signal.view(*outer_dimensions, -1, subframe_length)
|
55 |
+
|
56 |
+
frame = torch.arange(0, output_subframes,
|
57 |
+
device=signal.device).unfold(0, subframes_per_frame, subframe_step)
|
58 |
+
frame = frame.long()  # signal may be on GPU or CPU
|
59 |
+
frame = frame.contiguous().view(-1)
|
60 |
+
|
61 |
+
result = signal.new_zeros(*outer_dimensions, output_subframes, subframe_length)
|
62 |
+
result.index_add_(-2, frame, subframe_signal)
|
63 |
+
result = result.view(*outer_dimensions, -1)
|
64 |
+
return result
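# Worked example of the arithmetic above: with frames=3, frame_length=4, frame_step=2,
# subframe_length = gcd(4, 2) = 2 and output_size = 2 * (3 - 1) + 4 = 8, i.e. three
# 4-sample frames overlap-added with a hop of 2 reconstruct an 8-sample signal.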
|
65 |
+
|
66 |
+
|
67 |
+
class ConvTasNet(nn.Module):
|
68 |
+
@capture_init
|
69 |
+
def __init__(self,
|
70 |
+
N=256,
|
71 |
+
L=20,
|
72 |
+
B=256,
|
73 |
+
H=512,
|
74 |
+
P=3,
|
75 |
+
X=8,
|
76 |
+
R=4,
|
77 |
+
C=4,
|
78 |
+
audio_channels=1,
|
79 |
+
samplerate=44100,
|
80 |
+
norm_type="gLN",
|
81 |
+
causal=False,
|
82 |
+
mask_nonlinear='relu'):
|
83 |
+
"""
|
84 |
+
Args:
|
85 |
+
N: Number of filters in autoencoder
|
86 |
+
L: Length of the filters (in samples)
|
87 |
+
B: Number of channels in bottleneck 1 × 1-conv block
|
88 |
+
H: Number of channels in convolutional blocks
|
89 |
+
P: Kernel size in convolutional blocks
|
90 |
+
X: Number of convolutional blocks in each repeat
|
91 |
+
R: Number of repeats
|
92 |
+
C: Number of speakers
|
93 |
+
norm_type: BN, gLN, cLN
|
94 |
+
causal: causal or non-causal
|
95 |
+
mask_nonlinear: use which non-linear function to generate mask
|
96 |
+
"""
|
97 |
+
super(ConvTasNet, self).__init__()
|
98 |
+
# Hyper-parameter
|
99 |
+
self.N, self.L, self.B, self.H, self.P, self.X, self.R, self.C = N, L, B, H, P, X, R, C
|
100 |
+
self.norm_type = norm_type
|
101 |
+
self.causal = causal
|
102 |
+
self.mask_nonlinear = mask_nonlinear
|
103 |
+
self.audio_channels = audio_channels
|
104 |
+
self.samplerate = samplerate
|
105 |
+
# Components
|
106 |
+
self.encoder = Encoder(L, N, audio_channels)
|
107 |
+
self.separator = TemporalConvNet(N, B, H, P, X, R, C, norm_type, causal, mask_nonlinear)
|
108 |
+
self.decoder = Decoder(N, L, audio_channels)
|
109 |
+
# init
|
110 |
+
for p in self.parameters():
|
111 |
+
if p.dim() > 1:
|
112 |
+
nn.init.xavier_normal_(p)
|
113 |
+
|
114 |
+
def valid_length(self, length):
|
115 |
+
return length
|
116 |
+
|
117 |
+
def forward(self, mixture):
|
118 |
+
"""
|
119 |
+
Args:
|
120 |
+
mixture: [M, T], M is batch size, T is #samples
|
121 |
+
Returns:
|
122 |
+
est_source: [M, C, T]
|
123 |
+
"""
|
124 |
+
mixture_w = self.encoder(mixture)
|
125 |
+
est_mask = self.separator(mixture_w)
|
126 |
+
est_source = self.decoder(mixture_w, est_mask)
|
127 |
+
|
128 |
+
# T changed after conv1d in encoder, fix it here
|
129 |
+
T_origin = mixture.size(-1)
|
130 |
+
T_conv = est_source.size(-1)
|
131 |
+
est_source = F.pad(est_source, (0, T_origin - T_conv))
|
132 |
+
return est_source
|
133 |
+
|
134 |
+
|
135 |
+
class Encoder(nn.Module):
|
136 |
+
"""Estimation of the nonnegative mixture weight by a 1-D conv layer.
|
137 |
+
"""
|
138 |
+
def __init__(self, L, N, audio_channels):
|
139 |
+
super(Encoder, self).__init__()
|
140 |
+
# Hyper-parameter
|
141 |
+
self.L, self.N = L, N
|
142 |
+
# Components
|
143 |
+
# 50% overlap
|
144 |
+
self.conv1d_U = nn.Conv1d(audio_channels, N, kernel_size=L, stride=L // 2, bias=False)
|
145 |
+
|
146 |
+
def forward(self, mixture):
|
147 |
+
"""
|
148 |
+
Args:
|
149 |
+
mixture: [M, T], M is batch size, T is #samples
|
150 |
+
Returns:
|
151 |
+
mixture_w: [M, N, K], where K = (T-L)/(L/2)+1 = 2T/L-1
|
152 |
+
"""
|
153 |
+
mixture_w = F.relu(self.conv1d_U(mixture)) # [M, N, K]
|
154 |
+
return mixture_w
|
155 |
+
|
156 |
+
|
157 |
+
class Decoder(nn.Module):
|
158 |
+
def __init__(self, N, L, audio_channels):
|
159 |
+
super(Decoder, self).__init__()
|
160 |
+
# Hyper-parameter
|
161 |
+
self.N, self.L = N, L
|
162 |
+
self.audio_channels = audio_channels
|
163 |
+
# Components
|
164 |
+
self.basis_signals = nn.Linear(N, audio_channels * L, bias=False)
|
165 |
+
|
166 |
+
def forward(self, mixture_w, est_mask):
|
167 |
+
"""
|
168 |
+
Args:
|
169 |
+
mixture_w: [M, N, K]
|
170 |
+
est_mask: [M, C, N, K]
|
171 |
+
Returns:
|
172 |
+
est_source: [M, C, T]
|
173 |
+
"""
|
174 |
+
# D = W * M
|
175 |
+
source_w = torch.unsqueeze(mixture_w, 1) * est_mask # [M, C, N, K]
|
176 |
+
source_w = torch.transpose(source_w, 2, 3) # [M, C, K, N]
|
177 |
+
# S = DV
|
178 |
+
est_source = self.basis_signals(source_w) # [M, C, K, ac * L]
|
179 |
+
m, c, k, _ = est_source.size()
|
180 |
+
est_source = est_source.view(m, c, k, self.audio_channels, -1).transpose(2, 3).contiguous()
|
181 |
+
est_source = overlap_and_add(est_source, self.L // 2) # M x C x ac x T
|
182 |
+
return est_source
|
183 |
+
|
184 |
+
|
185 |
+
class TemporalConvNet(nn.Module):
|
186 |
+
def __init__(self, N, B, H, P, X, R, C, norm_type="gLN", causal=False, mask_nonlinear='relu'):
|
187 |
+
"""
|
188 |
+
Args:
|
189 |
+
N: Number of filters in autoencoder
|
190 |
+
B: Number of channels in bottleneck 1 × 1-conv block
|
191 |
+
H: Number of channels in convolutional blocks
|
192 |
+
P: Kernel size in convolutional blocks
|
193 |
+
X: Number of convolutional blocks in each repeat
|
194 |
+
R: Number of repeats
|
195 |
+
C: Number of speakers
|
196 |
+
norm_type: BN, gLN, cLN
|
197 |
+
causal: causal or non-causal
|
198 |
+
mask_nonlinear: use which non-linear function to generate mask
|
199 |
+
"""
|
200 |
+
super(TemporalConvNet, self).__init__()
|
201 |
+
# Hyper-parameter
|
202 |
+
self.C = C
|
203 |
+
self.mask_nonlinear = mask_nonlinear
|
204 |
+
# Components
|
205 |
+
# [M, N, K] -> [M, N, K]
|
206 |
+
layer_norm = ChannelwiseLayerNorm(N)
|
207 |
+
# [M, N, K] -> [M, B, K]
|
208 |
+
bottleneck_conv1x1 = nn.Conv1d(N, B, 1, bias=False)
|
209 |
+
# [M, B, K] -> [M, B, K]
|
210 |
+
repeats = []
|
211 |
+
for r in range(R):
|
212 |
+
blocks = []
|
213 |
+
for x in range(X):
|
214 |
+
dilation = 2**x
|
215 |
+
padding = (P - 1) * dilation if causal else (P - 1) * dilation // 2
|
216 |
+
blocks += [
|
217 |
+
TemporalBlock(B,
|
218 |
+
H,
|
219 |
+
P,
|
220 |
+
stride=1,
|
221 |
+
padding=padding,
|
222 |
+
dilation=dilation,
|
223 |
+
norm_type=norm_type,
|
224 |
+
causal=causal)
|
225 |
+
]
|
226 |
+
repeats += [nn.Sequential(*blocks)]
|
227 |
+
temporal_conv_net = nn.Sequential(*repeats)
|
228 |
+
# [M, B, K] -> [M, C*N, K]
|
229 |
+
mask_conv1x1 = nn.Conv1d(B, C * N, 1, bias=False)
|
230 |
+
# Put together
|
231 |
+
self.network = nn.Sequential(layer_norm, bottleneck_conv1x1, temporal_conv_net,
|
232 |
+
mask_conv1x1)
|
233 |
+
|
234 |
+
def forward(self, mixture_w):
|
235 |
+
"""
|
236 |
+
Keep this API same with TasNet
|
237 |
+
Args:
|
238 |
+
mixture_w: [M, N, K], M is batch size
|
239 |
+
returns:
|
240 |
+
est_mask: [M, C, N, K]
|
241 |
+
"""
|
242 |
+
M, N, K = mixture_w.size()
|
243 |
+
score = self.network(mixture_w) # [M, N, K] -> [M, C*N, K]
|
244 |
+
score = score.view(M, self.C, N, K) # [M, C*N, K] -> [M, C, N, K]
|
245 |
+
if self.mask_nonlinear == 'softmax':
|
246 |
+
est_mask = F.softmax(score, dim=1)
|
247 |
+
elif self.mask_nonlinear == 'relu':
|
248 |
+
est_mask = F.relu(score)
|
249 |
+
else:
|
250 |
+
raise ValueError("Unsupported mask non-linear function")
|
251 |
+
return est_mask
|
252 |
+
|
253 |
+
|
254 |
+
class TemporalBlock(nn.Module):
|
255 |
+
def __init__(self,
|
256 |
+
in_channels,
|
257 |
+
out_channels,
|
258 |
+
kernel_size,
|
259 |
+
stride,
|
260 |
+
padding,
|
261 |
+
dilation,
|
262 |
+
norm_type="gLN",
|
263 |
+
causal=False):
|
264 |
+
super(TemporalBlock, self).__init__()
|
265 |
+
# [M, B, K] -> [M, H, K]
|
266 |
+
conv1x1 = nn.Conv1d(in_channels, out_channels, 1, bias=False)
|
267 |
+
prelu = nn.PReLU()
|
268 |
+
norm = chose_norm(norm_type, out_channels)
|
269 |
+
# [M, H, K] -> [M, B, K]
|
270 |
+
dsconv = DepthwiseSeparableConv(out_channels, in_channels, kernel_size, stride, padding,
|
271 |
+
dilation, norm_type, causal)
|
272 |
+
# Put together
|
273 |
+
self.net = nn.Sequential(conv1x1, prelu, norm, dsconv)
|
274 |
+
|
275 |
+
def forward(self, x):
|
276 |
+
"""
|
277 |
+
Args:
|
278 |
+
x: [M, B, K]
|
279 |
+
Returns:
|
280 |
+
[M, B, K]
|
281 |
+
"""
|
282 |
+
residual = x
|
283 |
+
out = self.net(x)
|
284 |
+
# TODO: when P = 3 here works fine, but when P = 2 maybe need to pad?
|
285 |
+
return out + residual  # looks like w/o F.relu is better than with F.relu
|
286 |
+
# return F.relu(out + residual)
|
287 |
+
|
288 |
+
|
289 |
+
class DepthwiseSeparableConv(nn.Module):
|
290 |
+
def __init__(self,
|
291 |
+
in_channels,
|
292 |
+
out_channels,
|
293 |
+
kernel_size,
|
294 |
+
stride,
|
295 |
+
padding,
|
296 |
+
dilation,
|
297 |
+
norm_type="gLN",
|
298 |
+
causal=False):
|
299 |
+
super(DepthwiseSeparableConv, self).__init__()
|
300 |
+
# Use `groups` option to implement depthwise convolution
|
301 |
+
# [M, H, K] -> [M, H, K]
|
302 |
+
depthwise_conv = nn.Conv1d(in_channels,
|
303 |
+
in_channels,
|
304 |
+
kernel_size,
|
305 |
+
stride=stride,
|
306 |
+
padding=padding,
|
307 |
+
dilation=dilation,
|
308 |
+
groups=in_channels,
|
309 |
+
bias=False)
|
310 |
+
if causal:
|
311 |
+
chomp = Chomp1d(padding)
|
312 |
+
prelu = nn.PReLU()
|
313 |
+
norm = chose_norm(norm_type, in_channels)
|
314 |
+
# [M, H, K] -> [M, B, K]
|
315 |
+
pointwise_conv = nn.Conv1d(in_channels, out_channels, 1, bias=False)
|
316 |
+
# Put together
|
317 |
+
if causal:
|
318 |
+
self.net = nn.Sequential(depthwise_conv, chomp, prelu, norm, pointwise_conv)
|
319 |
+
else:
|
320 |
+
self.net = nn.Sequential(depthwise_conv, prelu, norm, pointwise_conv)
|
321 |
+
|
322 |
+
def forward(self, x):
|
323 |
+
"""
|
324 |
+
Args:
|
325 |
+
x: [M, H, K]
|
326 |
+
Returns:
|
327 |
+
result: [M, B, K]
|
328 |
+
"""
|
329 |
+
return self.net(x)
|
330 |
+
|
331 |
+
|
332 |
+
class Chomp1d(nn.Module):
|
333 |
+
"""To ensure the output length is the same as the input.
|
334 |
+
"""
|
335 |
+
def __init__(self, chomp_size):
|
336 |
+
super(Chomp1d, self).__init__()
|
337 |
+
self.chomp_size = chomp_size
|
338 |
+
|
339 |
+
def forward(self, x):
|
340 |
+
"""
|
341 |
+
Args:
|
342 |
+
x: [M, H, Kpad]
|
343 |
+
Returns:
|
344 |
+
[M, H, K]
|
345 |
+
"""
|
346 |
+
return x[:, :, :-self.chomp_size].contiguous()
|
347 |
+
|
348 |
+
|
349 |
+
def chose_norm(norm_type, channel_size):
|
350 |
+
"""The input of normlization will be (M, C, K), where M is batch size,
|
351 |
+
C is channel size and K is sequence length.
|
352 |
+
"""
|
353 |
+
if norm_type == "gLN":
|
354 |
+
return GlobalLayerNorm(channel_size)
|
355 |
+
elif norm_type == "cLN":
|
356 |
+
return ChannelwiseLayerNorm(channel_size)
|
357 |
+
elif norm_type == "id":
|
358 |
+
return nn.Identity()
|
359 |
+
else: # norm_type == "BN":
|
360 |
+
# Given input (M, C, K), nn.BatchNorm1d(C) will accumulate statistics
|
361 |
+
# along M and K, so this BN usage is right.
|
362 |
+
return nn.BatchNorm1d(channel_size)
|
363 |
+
|
364 |
+
|
365 |
+
# TODO: Use nn.LayerNorm to impl cLN to speed up
|
366 |
+
class ChannelwiseLayerNorm(nn.Module):
|
367 |
+
"""Channel-wise Layer Normalization (cLN)"""
|
368 |
+
def __init__(self, channel_size):
|
369 |
+
super(ChannelwiseLayerNorm, self).__init__()
|
370 |
+
self.gamma = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1]
|
371 |
+
self.beta = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1]
|
372 |
+
self.reset_parameters()
|
373 |
+
|
374 |
+
def reset_parameters(self):
|
375 |
+
self.gamma.data.fill_(1)
|
376 |
+
self.beta.data.zero_()
|
377 |
+
|
378 |
+
def forward(self, y):
|
379 |
+
"""
|
380 |
+
Args:
|
381 |
+
y: [M, N, K], M is batch size, N is channel size, K is length
|
382 |
+
Returns:
|
383 |
+
cLN_y: [M, N, K]
|
384 |
+
"""
|
385 |
+
mean = torch.mean(y, dim=1, keepdim=True) # [M, 1, K]
|
386 |
+
var = torch.var(y, dim=1, keepdim=True, unbiased=False) # [M, 1, K]
|
387 |
+
cLN_y = self.gamma * (y - mean) / torch.pow(var + EPS, 0.5) + self.beta
|
388 |
+
return cLN_y
|
389 |
+
|
390 |
+
|
391 |
+
class GlobalLayerNorm(nn.Module):
|
392 |
+
"""Global Layer Normalization (gLN)"""
|
393 |
+
def __init__(self, channel_size):
|
394 |
+
super(GlobalLayerNorm, self).__init__()
|
395 |
+
self.gamma = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1]
|
396 |
+
self.beta = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1]
|
397 |
+
self.reset_parameters()
|
398 |
+
|
399 |
+
def reset_parameters(self):
|
400 |
+
self.gamma.data.fill_(1)
|
401 |
+
self.beta.data.zero_()
|
402 |
+
|
403 |
+
def forward(self, y):
|
404 |
+
"""
|
405 |
+
Args:
|
406 |
+
y: [M, N, K], M is batch size, N is channel size, K is length
|
407 |
+
Returns:
|
408 |
+
gLN_y: [M, N, K]
|
409 |
+
"""
|
410 |
+
# TODO: in torch 1.0, torch.mean() support dim list
|
411 |
+
mean = y.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True) # [M, 1, 1]
|
412 |
+
var = (torch.pow(y - mean, 2)).mean(dim=1, keepdim=True).mean(dim=2, keepdim=True)
|
413 |
+
gLN_y = self.gamma * (y - mean) / torch.pow(var + EPS, 0.5) + self.beta
|
414 |
+
return gLN_y
|
415 |
+
|
416 |
+
|
417 |
+
if __name__ == "__main__":
|
418 |
+
torch.manual_seed(123)
|
419 |
+
M, N, L, T = 2, 3, 4, 12
|
420 |
+
K = 2 * T // L - 1
|
421 |
+
B, H, P, X, R, C, norm_type, causal = 2, 3, 3, 3, 2, 2, "gLN", False
|
422 |
+
mixture = torch.randint(3, (M, T))
|
423 |
+
# test Encoder
|
424 |
+
encoder = Encoder(L, N)
|
425 |
+
encoder.conv1d_U.weight.data = torch.randint(2, encoder.conv1d_U.weight.size())
|
426 |
+
mixture_w = encoder(mixture)
|
427 |
+
print('mixture', mixture)
|
428 |
+
print('U', encoder.conv1d_U.weight)
|
429 |
+
print('mixture_w', mixture_w)
|
430 |
+
print('mixture_w size', mixture_w.size())
|
431 |
+
|
432 |
+
# test TemporalConvNet
|
433 |
+
separator = TemporalConvNet(N, B, H, P, X, R, C, norm_type=norm_type, causal=causal)
|
434 |
+
est_mask = separator(mixture_w)
|
435 |
+
print('est_mask', est_mask)
|
436 |
+
|
437 |
+
# test Decoder
|
438 |
+
decoder = Decoder(N, L)
|
439 |
+
est_mask = torch.randint(2, (B, K, C, N))
|
440 |
+
est_source = decoder(mixture_w, est_mask)
|
441 |
+
print('est_source', est_source)
|
442 |
+
|
443 |
+
# test Conv-TasNet
|
444 |
+
conv_tasnet = ConvTasNet(N, L, B, H, P, X, R, C, norm_type=norm_type)
|
445 |
+
est_source = conv_tasnet(mixture)
|
446 |
+
print('est_source', est_source)
|
447 |
+
print('est_source size', est_source.size())
|
demucs/tasnet_v2.py
ADDED
@@ -0,0 +1,452 @@
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
#
|
7 |
+
# Created on 2018/12
|
8 |
+
# Author: Kaituo XU
|
9 |
+
# Modified on 2019/11 by Alexandre Defossez, added support for multiple output channels
|
10 |
+
# Here is the original license:
|
11 |
+
# The MIT License (MIT)
|
12 |
+
#
|
13 |
+
# Copyright (c) 2018 Kaituo XU
|
14 |
+
#
|
15 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
16 |
+
# of this software and associated documentation files (the "Software"), to deal
|
17 |
+
# in the Software without restriction, including without limitation the rights
|
18 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
19 |
+
# copies of the Software, and to permit persons to whom the Software is
|
20 |
+
# furnished to do so, subject to the following conditions:
|
21 |
+
#
|
22 |
+
# The above copyright notice and this permission notice shall be included in all
|
23 |
+
# copies or substantial portions of the Software.
|
24 |
+
#
|
25 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
26 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
27 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
28 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
29 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
30 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
31 |
+
# SOFTWARE.
|
32 |
+
|
33 |
+
import math
|
34 |
+
|
35 |
+
import torch
|
36 |
+
import torch.nn as nn
|
37 |
+
import torch.nn.functional as F
|
38 |
+
|
39 |
+
from .utils import capture_init
|
40 |
+
|
41 |
+
EPS = 1e-8
|
42 |
+
|
43 |
+
|
44 |
+
def overlap_and_add(signal, frame_step):
|
45 |
+
outer_dimensions = signal.size()[:-2]
|
46 |
+
frames, frame_length = signal.size()[-2:]
|
47 |
+
|
48 |
+
subframe_length = math.gcd(frame_length, frame_step) # gcd=Greatest Common Divisor
|
49 |
+
subframe_step = frame_step // subframe_length
|
50 |
+
subframes_per_frame = frame_length // subframe_length
|
51 |
+
output_size = frame_step * (frames - 1) + frame_length
|
52 |
+
output_subframes = output_size // subframe_length
|
53 |
+
|
54 |
+
subframe_signal = signal.view(*outer_dimensions, -1, subframe_length)
|
55 |
+
|
56 |
+
frame = torch.arange(0, output_subframes,
|
57 |
+
device=signal.device).unfold(0, subframes_per_frame, subframe_step)
|
58 |
+
frame = frame.long()  # signal may be on GPU or CPU
|
59 |
+
frame = frame.contiguous().view(-1)
|
60 |
+
|
61 |
+
result = signal.new_zeros(*outer_dimensions, output_subframes, subframe_length)
|
62 |
+
result.index_add_(-2, frame, subframe_signal)
|
63 |
+
result = result.view(*outer_dimensions, -1)
|
64 |
+
return result
|
65 |
+
|
66 |
+
|
67 |
+
class ConvTasNet(nn.Module):
|
68 |
+
@capture_init
|
69 |
+
def __init__(self,
|
70 |
+
sources,
|
71 |
+
N=256,
|
72 |
+
L=20,
|
73 |
+
B=256,
|
74 |
+
H=512,
|
75 |
+
P=3,
|
76 |
+
X=8,
|
77 |
+
R=4,
|
78 |
+
audio_channels=2,
|
79 |
+
norm_type="gLN",
|
80 |
+
causal=False,
|
81 |
+
mask_nonlinear='relu',
|
82 |
+
samplerate=44100,
|
83 |
+
segment_length=44100 * 2 * 4):
|
84 |
+
"""
|
85 |
+
Args:
|
86 |
+
sources: list of sources
|
87 |
+
N: Number of filters in autoencoder
|
88 |
+
L: Length of the filters (in samples)
|
89 |
+
B: Number of channels in bottleneck 1 × 1-conv block
|
90 |
+
H: Number of channels in convolutional blocks
|
91 |
+
P: Kernel size in convolutional blocks
|
92 |
+
X: Number of convolutional blocks in each repeat
|
93 |
+
R: Number of repeats
|
94 |
+
norm_type: BN, gLN, cLN
|
95 |
+
causal: causal or non-causal
|
96 |
+
mask_nonlinear: use which non-linear function to generate mask
|
97 |
+
"""
|
98 |
+
super(ConvTasNet, self).__init__()
|
99 |
+
# Hyper-parameter
|
100 |
+
self.sources = sources
|
101 |
+
self.C = len(sources)
|
102 |
+
self.N, self.L, self.B, self.H, self.P, self.X, self.R = N, L, B, H, P, X, R
|
103 |
+
self.norm_type = norm_type
|
104 |
+
self.causal = causal
|
105 |
+
self.mask_nonlinear = mask_nonlinear
|
106 |
+
self.audio_channels = audio_channels
|
107 |
+
self.samplerate = samplerate
|
108 |
+
self.segment_length = segment_length
|
109 |
+
# Components
|
110 |
+
self.encoder = Encoder(L, N, audio_channels)
|
111 |
+
self.separator = TemporalConvNet(
|
112 |
+
N, B, H, P, X, R, self.C, norm_type, causal, mask_nonlinear)
|
113 |
+
self.decoder = Decoder(N, L, audio_channels)
|
114 |
+
# init
|
115 |
+
for p in self.parameters():
|
116 |
+
if p.dim() > 1:
|
117 |
+
nn.init.xavier_normal_(p)
|
118 |
+
|
119 |
+
def valid_length(self, length):
|
120 |
+
return length
|
121 |
+
|
122 |
+
def forward(self, mixture):
|
123 |
+
"""
|
124 |
+
Args:
|
125 |
+
mixture: [M, T], M is batch size, T is #samples
|
126 |
+
Returns:
|
127 |
+
est_source: [M, C, T]
|
128 |
+
"""
|
129 |
+
mixture_w = self.encoder(mixture)
|
130 |
+
est_mask = self.separator(mixture_w)
|
131 |
+
est_source = self.decoder(mixture_w, est_mask)
|
132 |
+
|
133 |
+
# T changed after conv1d in encoder, fix it here
|
134 |
+
T_origin = mixture.size(-1)
|
135 |
+
T_conv = est_source.size(-1)
|
136 |
+
est_source = F.pad(est_source, (0, T_origin - T_conv))
|
137 |
+
return est_source
|
138 |
+
|
139 |
+
|
140 |
+
class Encoder(nn.Module):
|
141 |
+
"""Estimation of the nonnegative mixture weight by a 1-D conv layer.
|
142 |
+
"""
|
143 |
+
def __init__(self, L, N, audio_channels):
|
144 |
+
super(Encoder, self).__init__()
|
145 |
+
# Hyper-parameter
|
146 |
+
self.L, self.N = L, N
|
147 |
+
# Components
|
148 |
+
# 50% overlap
|
149 |
+
self.conv1d_U = nn.Conv1d(audio_channels, N, kernel_size=L, stride=L // 2, bias=False)
|
150 |
+
|
151 |
+
def forward(self, mixture):
|
152 |
+
"""
|
153 |
+
Args:
|
154 |
+
mixture: [M, T], M is batch size, T is #samples
|
155 |
+
Returns:
|
156 |
+
mixture_w: [M, N, K], where K = (T-L)/(L/2)+1 = 2T/L-1
|
157 |
+
"""
|
158 |
+
mixture_w = F.relu(self.conv1d_U(mixture)) # [M, N, K]
|
159 |
+
return mixture_w
|
160 |
+
|
161 |
+
|
162 |
+
class Decoder(nn.Module):
|
163 |
+
def __init__(self, N, L, audio_channels):
|
164 |
+
super(Decoder, self).__init__()
|
165 |
+
# Hyper-parameter
|
166 |
+
self.N, self.L = N, L
|
167 |
+
self.audio_channels = audio_channels
|
168 |
+
# Components
|
169 |
+
self.basis_signals = nn.Linear(N, audio_channels * L, bias=False)
|
170 |
+
|
171 |
+
def forward(self, mixture_w, est_mask):
|
172 |
+
"""
|
173 |
+
Args:
|
174 |
+
mixture_w: [M, N, K]
|
175 |
+
est_mask: [M, C, N, K]
|
176 |
+
Returns:
|
177 |
+
est_source: [M, C, T]
|
178 |
+
"""
|
179 |
+
# D = W * M
|
180 |
+
source_w = torch.unsqueeze(mixture_w, 1) * est_mask # [M, C, N, K]
|
181 |
+
source_w = torch.transpose(source_w, 2, 3) # [M, C, K, N]
|
182 |
+
# S = DV
|
183 |
+
est_source = self.basis_signals(source_w) # [M, C, K, ac * L]
|
184 |
+
m, c, k, _ = est_source.size()
|
185 |
+
est_source = est_source.view(m, c, k, self.audio_channels, -1).transpose(2, 3).contiguous()
|
186 |
+
est_source = overlap_and_add(est_source, self.L // 2) # M x C x ac x T
|
187 |
+
return est_source
|
188 |
+
|
189 |
+
|
190 |
+
class TemporalConvNet(nn.Module):
|
191 |
+
def __init__(self, N, B, H, P, X, R, C, norm_type="gLN", causal=False, mask_nonlinear='relu'):
|
192 |
+
"""
|
193 |
+
Args:
|
194 |
+
N: Number of filters in autoencoder
|
195 |
+
B: Number of channels in bottleneck 1 × 1-conv block
|
196 |
+
H: Number of channels in convolutional blocks
|
197 |
+
P: Kernel size in convolutional blocks
|
198 |
+
X: Number of convolutional blocks in each repeat
|
199 |
+
R: Number of repeats
|
200 |
+
C: Number of speakers
|
201 |
+
norm_type: BN, gLN, cLN
|
202 |
+
causal: causal or non-causal
|
203 |
+
mask_nonlinear: use which non-linear function to generate mask
|
204 |
+
"""
|
205 |
+
super(TemporalConvNet, self).__init__()
|
206 |
+
# Hyper-parameter
|
207 |
+
self.C = C
|
208 |
+
self.mask_nonlinear = mask_nonlinear
|
209 |
+
# Components
|
210 |
+
# [M, N, K] -> [M, N, K]
|
211 |
+
layer_norm = ChannelwiseLayerNorm(N)
|
212 |
+
# [M, N, K] -> [M, B, K]
|
213 |
+
bottleneck_conv1x1 = nn.Conv1d(N, B, 1, bias=False)
|
214 |
+
# [M, B, K] -> [M, B, K]
|
215 |
+
repeats = []
|
216 |
+
for r in range(R):
|
217 |
+
blocks = []
|
218 |
+
for x in range(X):
|
219 |
+
dilation = 2**x
|
220 |
+
padding = (P - 1) * dilation if causal else (P - 1) * dilation // 2
|
221 |
+
blocks += [
|
222 |
+
TemporalBlock(B,
|
223 |
+
H,
|
224 |
+
P,
|
225 |
+
stride=1,
|
226 |
+
padding=padding,
|
227 |
+
dilation=dilation,
|
228 |
+
norm_type=norm_type,
|
229 |
+
causal=causal)
|
230 |
+
]
|
231 |
+
repeats += [nn.Sequential(*blocks)]
|
232 |
+
temporal_conv_net = nn.Sequential(*repeats)
|
233 |
+
# [M, B, K] -> [M, C*N, K]
|
234 |
+
mask_conv1x1 = nn.Conv1d(B, C * N, 1, bias=False)
|
235 |
+
# Put together
|
236 |
+
self.network = nn.Sequential(layer_norm, bottleneck_conv1x1, temporal_conv_net,
|
237 |
+
mask_conv1x1)
|
238 |
+
|
239 |
+
def forward(self, mixture_w):
|
240 |
+
"""
|
241 |
+
Keep this API same with TasNet
|
242 |
+
Args:
|
243 |
+
mixture_w: [M, N, K], M is batch size
|
244 |
+
returns:
|
245 |
+
est_mask: [M, C, N, K]
|
246 |
+
"""
|
247 |
+
M, N, K = mixture_w.size()
|
248 |
+
score = self.network(mixture_w) # [M, N, K] -> [M, C*N, K]
|
249 |
+
score = score.view(M, self.C, N, K) # [M, C*N, K] -> [M, C, N, K]
|
250 |
+
if self.mask_nonlinear == 'softmax':
|
251 |
+
est_mask = F.softmax(score, dim=1)
|
252 |
+
elif self.mask_nonlinear == 'relu':
|
253 |
+
est_mask = F.relu(score)
|
254 |
+
else:
|
255 |
+
raise ValueError("Unsupported mask non-linear function")
|
256 |
+
return est_mask
|
257 |
+
|
258 |
+
|
259 |
+
class TemporalBlock(nn.Module):
|
260 |
+
def __init__(self,
|
261 |
+
in_channels,
|
262 |
+
out_channels,
|
263 |
+
kernel_size,
|
264 |
+
stride,
|
265 |
+
padding,
|
266 |
+
dilation,
|
267 |
+
norm_type="gLN",
|
268 |
+
causal=False):
|
269 |
+
super(TemporalBlock, self).__init__()
|
270 |
+
# [M, B, K] -> [M, H, K]
|
271 |
+
conv1x1 = nn.Conv1d(in_channels, out_channels, 1, bias=False)
|
272 |
+
prelu = nn.PReLU()
|
273 |
+
norm = chose_norm(norm_type, out_channels)
|
274 |
+
# [M, H, K] -> [M, B, K]
|
275 |
+
dsconv = DepthwiseSeparableConv(out_channels, in_channels, kernel_size, stride, padding,
|
276 |
+
dilation, norm_type, causal)
|
277 |
+
# Put together
|
278 |
+
self.net = nn.Sequential(conv1x1, prelu, norm, dsconv)
|
279 |
+
|
280 |
+
def forward(self, x):
|
281 |
+
"""
|
282 |
+
Args:
|
283 |
+
x: [M, B, K]
|
284 |
+
Returns:
|
285 |
+
[M, B, K]
|
286 |
+
"""
|
287 |
+
residual = x
|
288 |
+
out = self.net(x)
|
289 |
+
# TODO: when P = 3 here works fine, but when P = 2 maybe need to pad?
|
290 |
+
return out + residual  # looks like w/o F.relu is better than with F.relu
|
291 |
+
# return F.relu(out + residual)
|
292 |
+
|
293 |
+
|
294 |
+
class DepthwiseSeparableConv(nn.Module):
|
295 |
+
def __init__(self,
|
296 |
+
in_channels,
|
297 |
+
out_channels,
|
298 |
+
kernel_size,
|
299 |
+
stride,
|
300 |
+
padding,
|
301 |
+
dilation,
|
302 |
+
norm_type="gLN",
|
303 |
+
causal=False):
|
304 |
+
super(DepthwiseSeparableConv, self).__init__()
|
305 |
+
# Use `groups` option to implement depthwise convolution
|
306 |
+
# [M, H, K] -> [M, H, K]
|
307 |
+
depthwise_conv = nn.Conv1d(in_channels,
|
308 |
+
in_channels,
|
309 |
+
kernel_size,
|
310 |
+
stride=stride,
|
311 |
+
padding=padding,
|
312 |
+
dilation=dilation,
|
313 |
+
groups=in_channels,
|
314 |
+
bias=False)
|
315 |
+
if causal:
|
316 |
+
chomp = Chomp1d(padding)
|
317 |
+
prelu = nn.PReLU()
|
318 |
+
norm = chose_norm(norm_type, in_channels)
|
319 |
+
# [M, H, K] -> [M, B, K]
|
320 |
+
pointwise_conv = nn.Conv1d(in_channels, out_channels, 1, bias=False)
|
321 |
+
# Put together
|
322 |
+
if causal:
|
323 |
+
self.net = nn.Sequential(depthwise_conv, chomp, prelu, norm, pointwise_conv)
|
324 |
+
else:
|
325 |
+
self.net = nn.Sequential(depthwise_conv, prelu, norm, pointwise_conv)
|
326 |
+
|
327 |
+
def forward(self, x):
|
328 |
+
"""
|
329 |
+
Args:
|
330 |
+
x: [M, H, K]
|
331 |
+
Returns:
|
332 |
+
result: [M, B, K]
|
333 |
+
"""
|
334 |
+
return self.net(x)
|
335 |
+
|
336 |
+
|
337 |
+
class Chomp1d(nn.Module):
|
338 |
+
"""To ensure the output length is the same as the input.
|
339 |
+
"""
|
340 |
+
def __init__(self, chomp_size):
|
341 |
+
super(Chomp1d, self).__init__()
|
342 |
+
self.chomp_size = chomp_size
|
343 |
+
|
344 |
+
def forward(self, x):
|
345 |
+
"""
|
346 |
+
Args:
|
347 |
+
x: [M, H, Kpad]
|
348 |
+
Returns:
|
349 |
+
[M, H, K]
|
350 |
+
"""
|
351 |
+
return x[:, :, :-self.chomp_size].contiguous()
|
352 |
+
|
353 |
+
|
354 |
+
def chose_norm(norm_type, channel_size):
|
355 |
+
"""The input of normlization will be (M, C, K), where M is batch size,
|
356 |
+
C is channel size and K is sequence length.
|
357 |
+
"""
|
358 |
+
if norm_type == "gLN":
|
359 |
+
return GlobalLayerNorm(channel_size)
|
360 |
+
elif norm_type == "cLN":
|
361 |
+
return ChannelwiseLayerNorm(channel_size)
|
362 |
+
elif norm_type == "id":
|
363 |
+
return nn.Identity()
|
364 |
+
else: # norm_type == "BN":
|
365 |
+
# Given input (M, C, K), nn.BatchNorm1d(C) will accumulate statistics
|
366 |
+
# along M and K, so this BN usage is right.
|
367 |
+
return nn.BatchNorm1d(channel_size)
|
368 |
+
|
369 |
+
|
370 |
+
# TODO: Use nn.LayerNorm to impl cLN to speed up
|
371 |
+
class ChannelwiseLayerNorm(nn.Module):
|
372 |
+
"""Channel-wise Layer Normalization (cLN)"""
|
373 |
+
def __init__(self, channel_size):
|
374 |
+
super(ChannelwiseLayerNorm, self).__init__()
|
375 |
+
self.gamma = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1]
|
376 |
+
self.beta = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1]
|
377 |
+
self.reset_parameters()
|
378 |
+
|
379 |
+
def reset_parameters(self):
|
380 |
+
self.gamma.data.fill_(1)
|
381 |
+
self.beta.data.zero_()
|
382 |
+
|
383 |
+
def forward(self, y):
|
384 |
+
"""
|
385 |
+
Args:
|
386 |
+
y: [M, N, K], M is batch size, N is channel size, K is length
|
387 |
+
Returns:
|
388 |
+
cLN_y: [M, N, K]
|
389 |
+
"""
|
390 |
+
mean = torch.mean(y, dim=1, keepdim=True) # [M, 1, K]
|
391 |
+
var = torch.var(y, dim=1, keepdim=True, unbiased=False) # [M, 1, K]
|
392 |
+
cLN_y = self.gamma * (y - mean) / torch.pow(var + EPS, 0.5) + self.beta
|
393 |
+
return cLN_y
|
394 |
+
|
395 |
+
|
396 |
+
class GlobalLayerNorm(nn.Module):
|
397 |
+
"""Global Layer Normalization (gLN)"""
|
398 |
+
def __init__(self, channel_size):
|
399 |
+
super(GlobalLayerNorm, self).__init__()
|
400 |
+
self.gamma = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1]
|
401 |
+
self.beta = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1]
|
402 |
+
self.reset_parameters()
|
403 |
+
|
404 |
+
def reset_parameters(self):
|
405 |
+
self.gamma.data.fill_(1)
|
406 |
+
self.beta.data.zero_()
|
407 |
+
|
408 |
+
def forward(self, y):
|
409 |
+
"""
|
410 |
+
Args:
|
411 |
+
y: [M, N, K], M is batch size, N is channel size, K is length
|
412 |
+
Returns:
|
413 |
+
gLN_y: [M, N, K]
|
414 |
+
"""
|
415 |
+
# TODO: in torch 1.0, torch.mean() support dim list
|
416 |
+
mean = y.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True) # [M, 1, 1]
|
417 |
+
var = (torch.pow(y - mean, 2)).mean(dim=1, keepdim=True).mean(dim=2, keepdim=True)
|
418 |
+
gLN_y = self.gamma * (y - mean) / torch.pow(var + EPS, 0.5) + self.beta
|
419 |
+
return gLN_y
|
420 |
+
|
421 |
+
|
422 |
+
if __name__ == "__main__":
|
423 |
+
torch.manual_seed(123)
|
424 |
+
M, N, L, T = 2, 3, 4, 12
|
425 |
+
K = 2 * T // L - 1
|
426 |
+
B, H, P, X, R, C, norm_type, causal = 2, 3, 3, 3, 2, 2, "gLN", False
|
427 |
+
mixture = torch.randint(3, (M, T))
|
428 |
+
# test Encoder
|
429 |
+
encoder = Encoder(L, N)
|
430 |
+
encoder.conv1d_U.weight.data = torch.randint(2, encoder.conv1d_U.weight.size())
|
431 |
+
mixture_w = encoder(mixture)
|
432 |
+
print('mixture', mixture)
|
433 |
+
print('U', encoder.conv1d_U.weight)
|
434 |
+
print('mixture_w', mixture_w)
|
435 |
+
print('mixture_w size', mixture_w.size())
|
436 |
+
|
437 |
+
# test TemporalConvNet
|
438 |
+
separator = TemporalConvNet(N, B, H, P, X, R, C, norm_type=norm_type, causal=causal)
|
439 |
+
est_mask = separator(mixture_w)
|
440 |
+
print('est_mask', est_mask)
|
441 |
+
|
442 |
+
# test Decoder
|
443 |
+
decoder = Decoder(N, L)
|
444 |
+
est_mask = torch.randint(2, (B, K, C, N))
|
445 |
+
est_source = decoder(mixture_w, est_mask)
|
446 |
+
print('est_source', est_source)
|
447 |
+
|
448 |
+
# test Conv-TasNet
|
449 |
+
conv_tasnet = ConvTasNet(N, L, B, H, P, X, R, C, norm_type=norm_type)
|
450 |
+
est_source = conv_tasnet(mixture)
|
451 |
+
print('est_source', est_source)
|
452 |
+
print('est_source size', est_source.size())
|
demucs/transformer.py
ADDED
@@ -0,0 +1,839 @@
1 |
+
# Copyright (c) 2019-present, Meta, Inc.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
# First author is Simon Rouard.
|
7 |
+
|
8 |
+
import random
|
9 |
+
import typing as tp
|
10 |
+
|
11 |
+
import torch
|
12 |
+
import torch.nn as nn
|
13 |
+
import torch.nn.functional as F
|
14 |
+
import numpy as np
|
15 |
+
import math
|
16 |
+
from einops import rearrange
|
17 |
+
|
18 |
+
|
19 |
+
def create_sin_embedding(
|
20 |
+
length: int, dim: int, shift: int = 0, device="cpu", max_period=10000
|
21 |
+
):
|
22 |
+
# We aim for TBC format
|
23 |
+
assert dim % 2 == 0
|
24 |
+
pos = shift + torch.arange(length, device=device).view(-1, 1, 1)
|
25 |
+
half_dim = dim // 2
|
26 |
+
adim = torch.arange(dim // 2, device=device).view(1, 1, -1)
|
27 |
+
phase = pos / (max_period ** (adim / (half_dim - 1)))
|
28 |
+
return torch.cat(
|
29 |
+
[
|
30 |
+
torch.cos(phase),
|
31 |
+
torch.sin(phase),
|
32 |
+
],
|
33 |
+
dim=-1,
|
34 |
+
)
|
35 |
+
|
36 |
+
|
37 |
+
def create_2d_sin_embedding(d_model, height, width, device="cpu", max_period=10000):
|
38 |
+
"""
|
39 |
+
:param d_model: dimension of the model
|
40 |
+
:param height: height of the positions
|
41 |
+
:param width: width of the positions
|
42 |
+
:return: d_model*height*width position matrix
|
43 |
+
"""
|
44 |
+
if d_model % 4 != 0:
|
45 |
+
raise ValueError(
|
46 |
+
"Cannot use sin/cos positional encoding with "
|
47 |
+
"odd dimension (got dim={:d})".format(d_model)
|
48 |
+
)
|
49 |
+
pe = torch.zeros(d_model, height, width)
|
50 |
+
# Each dimension uses half of d_model
|
51 |
+
d_model = int(d_model / 2)
|
52 |
+
div_term = torch.exp(
|
53 |
+
torch.arange(0.0, d_model, 2) * -(math.log(max_period) / d_model)
|
54 |
+
)
|
55 |
+
pos_w = torch.arange(0.0, width).unsqueeze(1)
|
56 |
+
pos_h = torch.arange(0.0, height).unsqueeze(1)
|
57 |
+
pe[0:d_model:2, :, :] = (
|
58 |
+
torch.sin(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
|
59 |
+
)
|
60 |
+
pe[1:d_model:2, :, :] = (
|
61 |
+
torch.cos(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
|
62 |
+
)
|
63 |
+
pe[d_model::2, :, :] = (
|
64 |
+
torch.sin(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
|
65 |
+
)
|
66 |
+
pe[d_model + 1:: 2, :, :] = (
|
67 |
+
torch.cos(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
|
68 |
+
)
|
69 |
+
|
70 |
+
return pe[None, :].to(device)
|
71 |
+
|
72 |
+
|
73 |
+
def create_sin_embedding_cape(
|
74 |
+
length: int,
|
75 |
+
dim: int,
|
76 |
+
batch_size: int,
|
77 |
+
mean_normalize: bool,
|
78 |
+
augment: bool, # True during training
|
79 |
+
max_global_shift: float = 0.0, # delta max
|
80 |
+
max_local_shift: float = 0.0, # epsilon max
|
81 |
+
max_scale: float = 1.0,
|
82 |
+
device: str = "cpu",
|
83 |
+
max_period: float = 10000.0,
|
84 |
+
):
|
85 |
+
# We aim for TBC format
|
86 |
+
assert dim % 2 == 0
|
87 |
+
pos = 1.0 * torch.arange(length).view(-1, 1, 1) # (length, 1, 1)
|
88 |
+
pos = pos.repeat(1, batch_size, 1) # (length, batch_size, 1)
|
89 |
+
if mean_normalize:
|
90 |
+
pos -= torch.nanmean(pos, dim=0, keepdim=True)
|
91 |
+
|
92 |
+
if augment:
|
93 |
+
delta = np.random.uniform(
|
94 |
+
-max_global_shift, +max_global_shift, size=[1, batch_size, 1]
|
95 |
+
)
|
96 |
+
delta_local = np.random.uniform(
|
97 |
+
-max_local_shift, +max_local_shift, size=[length, batch_size, 1]
|
98 |
+
)
|
99 |
+
log_lambdas = np.random.uniform(
|
100 |
+
-np.log(max_scale), +np.log(max_scale), size=[1, batch_size, 1]
|
101 |
+
)
|
102 |
+
pos = (pos + delta + delta_local) * np.exp(log_lambdas)
|
103 |
+
|
104 |
+
pos = pos.to(device)
|
105 |
+
|
106 |
+
half_dim = dim // 2
|
107 |
+
adim = torch.arange(dim // 2, device=device).view(1, 1, -1)
|
108 |
+
phase = pos / (max_period ** (adim / (half_dim - 1)))
|
109 |
+
return torch.cat(
|
110 |
+
[
|
111 |
+
torch.cos(phase),
|
112 |
+
torch.sin(phase),
|
113 |
+
],
|
114 |
+
dim=-1,
|
115 |
+
).float()
|
116 |
+
|
117 |
+
|
118 |
+
def get_causal_mask(length):
|
119 |
+
pos = torch.arange(length)
|
120 |
+
return pos > pos[:, None]
|
121 |
+
|
122 |
+
|
123 |
+
def get_elementary_mask(
|
124 |
+
T1,
|
125 |
+
T2,
|
126 |
+
mask_type,
|
127 |
+
sparse_attn_window,
|
128 |
+
global_window,
|
129 |
+
mask_random_seed,
|
130 |
+
sparsity,
|
131 |
+
device,
|
132 |
+
):
|
133 |
+
"""
|
134 |
+
When the input of the Decoder has length T1 and the output T2
|
135 |
+
The mask matrix has shape (T2, T1)
|
136 |
+
"""
|
137 |
+
assert mask_type in ["diag", "jmask", "random", "global"]
|
138 |
+
|
139 |
+
if mask_type == "global":
|
140 |
+
mask = torch.zeros(T2, T1, dtype=torch.bool)
|
141 |
+
mask[:, :global_window] = True
|
142 |
+
line_window = int(global_window * T2 / T1)
|
143 |
+
mask[:line_window, :] = True
|
144 |
+
|
145 |
+
if mask_type == "diag":
|
146 |
+
|
147 |
+
mask = torch.zeros(T2, T1, dtype=torch.bool)
|
148 |
+
rows = torch.arange(T2)[:, None]
|
149 |
+
cols = (
|
150 |
+
(T1 / T2 * rows + torch.arange(-sparse_attn_window, sparse_attn_window + 1))
|
151 |
+
.long()
|
152 |
+
.clamp(0, T1 - 1)
|
153 |
+
)
|
154 |
+
mask.scatter_(1, cols, torch.ones(1, dtype=torch.bool).expand_as(cols))
|
155 |
+
|
156 |
+
elif mask_type == "jmask":
|
157 |
+
mask = torch.zeros(T2 + 2, T1 + 2, dtype=torch.bool)
|
158 |
+
rows = torch.arange(T2 + 2)[:, None]
|
159 |
+
t = torch.arange(0, int((2 * T1) ** 0.5 + 1))
|
160 |
+
t = (t * (t + 1) / 2).int()
|
161 |
+
t = torch.cat([-t.flip(0)[:-1], t])
|
162 |
+
cols = (T1 / T2 * rows + t).long().clamp(0, T1 + 1)
|
163 |
+
mask.scatter_(1, cols, torch.ones(1, dtype=torch.bool).expand_as(cols))
|
164 |
+
mask = mask[1:-1, 1:-1]
|
165 |
+
|
166 |
+
elif mask_type == "random":
|
167 |
+
gene = torch.Generator(device=device)
|
168 |
+
gene.manual_seed(mask_random_seed)
|
169 |
+
mask = (
|
170 |
+
torch.rand(T1 * T2, generator=gene, device=device).reshape(T2, T1)
|
171 |
+
> sparsity
|
172 |
+
)
|
173 |
+
|
174 |
+
mask = mask.to(device)
|
175 |
+
return mask
|
176 |
+
|
177 |
+
|
178 |
+
def get_mask(
|
179 |
+
T1,
|
180 |
+
T2,
|
181 |
+
mask_type,
|
182 |
+
sparse_attn_window,
|
183 |
+
global_window,
|
184 |
+
mask_random_seed,
|
185 |
+
sparsity,
|
186 |
+
device,
|
187 |
+
):
|
188 |
+
"""
|
189 |
+
Return a SparseCSRTensor mask that is a combination of elementary masks
|
190 |
+
mask_type can be a combination of multiple masks: for instance "diag_jmask_random"
|
191 |
+
"""
|
192 |
+
from xformers.sparse import SparseCSRTensor
|
193 |
+
# create a list
|
194 |
+
mask_types = mask_type.split("_")
|
195 |
+
|
196 |
+
all_masks = [
|
197 |
+
get_elementary_mask(
|
198 |
+
T1,
|
199 |
+
T2,
|
200 |
+
mask,
|
201 |
+
sparse_attn_window,
|
202 |
+
global_window,
|
203 |
+
mask_random_seed,
|
204 |
+
sparsity,
|
205 |
+
device,
|
206 |
+
)
|
207 |
+
for mask in mask_types
|
208 |
+
]
|
209 |
+
|
210 |
+
final_mask = torch.stack(all_masks).sum(axis=0) > 0
|
211 |
+
|
212 |
+
return SparseCSRTensor.from_dense(final_mask[None])
|
213 |
+
|
214 |
+
|
215 |
+
class ScaledEmbedding(nn.Module):
|
216 |
+
def __init__(
|
217 |
+
self,
|
218 |
+
num_embeddings: int,
|
219 |
+
embedding_dim: int,
|
220 |
+
scale: float = 1.0,
|
221 |
+
boost: float = 3.0,
|
222 |
+
):
|
223 |
+
super().__init__()
|
224 |
+
self.embedding = nn.Embedding(num_embeddings, embedding_dim)
|
225 |
+
self.embedding.weight.data *= scale / boost
|
226 |
+
self.boost = boost
|
227 |
+
|
228 |
+
@property
|
229 |
+
def weight(self):
|
230 |
+
return self.embedding.weight * self.boost
|
231 |
+
|
232 |
+
def forward(self, x):
|
233 |
+
return self.embedding(x) * self.boost
|
234 |
+
|
235 |
+
|
236 |
+
class LayerScale(nn.Module):
|
237 |
+
"""Layer scale from [Touvron et al 2021] (https://arxiv.org/pdf/2103.17239.pdf).
|
238 |
+
This rescales residual outputs diagonally, close to 0 initially, then learnt.
|
239 |
+
"""
|
240 |
+
|
241 |
+
def __init__(self, channels: int, init: float = 0, channel_last=False):
|
242 |
+
"""
|
243 |
+
channel_last = False corresponds to (B, C, T) tensors
|
244 |
+
channel_last = True corresponds to (T, B, C) tensors
|
245 |
+
"""
|
246 |
+
super().__init__()
|
247 |
+
self.channel_last = channel_last
|
248 |
+
self.scale = nn.Parameter(torch.zeros(channels, requires_grad=True))
|
249 |
+
self.scale.data[:] = init
|
250 |
+
|
251 |
+
def forward(self, x):
|
252 |
+
if self.channel_last:
|
253 |
+
return self.scale * x
|
254 |
+
else:
|
255 |
+
return self.scale[:, None] * x
|
256 |
+
|
257 |
+
|
258 |
+
class MyGroupNorm(nn.GroupNorm):
|
259 |
+
def __init__(self, *args, **kwargs):
|
260 |
+
super().__init__(*args, **kwargs)
|
261 |
+
|
262 |
+
def forward(self, x):
|
263 |
+
"""
|
264 |
+
x: (B, T, C)
|
265 |
+
if num_groups=1: Normalisation on all T and C together for each B
|
266 |
+
"""
|
267 |
+
x = x.transpose(1, 2)
|
268 |
+
return super().forward(x).transpose(1, 2)
|
269 |
+
|
270 |
+
|
271 |
+
class MyTransformerEncoderLayer(nn.TransformerEncoderLayer):
|
272 |
+
def __init__(
|
273 |
+
self,
|
274 |
+
d_model,
|
275 |
+
nhead,
|
276 |
+
dim_feedforward=2048,
|
277 |
+
dropout=0.1,
|
278 |
+
activation=F.relu,
|
279 |
+
group_norm=0,
|
280 |
+
norm_first=False,
|
281 |
+
norm_out=False,
|
282 |
+
layer_norm_eps=1e-5,
|
283 |
+
layer_scale=False,
|
284 |
+
init_values=1e-4,
|
285 |
+
device=None,
|
286 |
+
dtype=None,
|
287 |
+
sparse=False,
|
288 |
+
mask_type="diag",
|
289 |
+
mask_random_seed=42,
|
290 |
+
sparse_attn_window=500,
|
291 |
+
global_window=50,
|
292 |
+
auto_sparsity=False,
|
293 |
+
sparsity=0.95,
|
294 |
+
batch_first=False,
|
295 |
+
):
|
296 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
297 |
+
super().__init__(
|
298 |
+
d_model=d_model,
|
299 |
+
nhead=nhead,
|
300 |
+
dim_feedforward=dim_feedforward,
|
301 |
+
dropout=dropout,
|
302 |
+
activation=activation,
|
303 |
+
layer_norm_eps=layer_norm_eps,
|
304 |
+
batch_first=batch_first,
|
305 |
+
norm_first=norm_first,
|
306 |
+
device=device,
|
307 |
+
dtype=dtype,
|
308 |
+
)
|
309 |
+
self.sparse = sparse
|
310 |
+
self.auto_sparsity = auto_sparsity
|
311 |
+
if sparse:
|
312 |
+
if not auto_sparsity:
|
313 |
+
self.mask_type = mask_type
|
314 |
+
self.sparse_attn_window = sparse_attn_window
|
315 |
+
self.global_window = global_window
|
316 |
+
self.sparsity = sparsity
|
317 |
+
if group_norm:
|
318 |
+
self.norm1 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
|
319 |
+
self.norm2 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
|
320 |
+
|
321 |
+
self.norm_out = None
|
322 |
+
if self.norm_first & norm_out:
|
323 |
+
self.norm_out = MyGroupNorm(num_groups=int(norm_out), num_channels=d_model)
|
324 |
+
self.gamma_1 = (
|
325 |
+
LayerScale(d_model, init_values, True) if layer_scale else nn.Identity()
|
326 |
+
)
|
327 |
+
self.gamma_2 = (
|
328 |
+
LayerScale(d_model, init_values, True) if layer_scale else nn.Identity()
|
329 |
+
)
|
330 |
+
|
331 |
+
if sparse:
|
332 |
+
self.self_attn = MultiheadAttention(
|
333 |
+
d_model, nhead, dropout=dropout, batch_first=batch_first,
|
334 |
+
auto_sparsity=sparsity if auto_sparsity else 0,
|
335 |
+
)
|
336 |
+
self.__setattr__("src_mask", torch.zeros(1, 1))
|
337 |
+
self.mask_random_seed = mask_random_seed
|
338 |
+
|
339 |
+
def forward(self, src, src_mask=None, src_key_padding_mask=None):
|
340 |
+
"""
|
341 |
+
if batch_first = False, src shape is (T, B, C)
|
342 |
+
the case where batch_first=True is not covered
|
343 |
+
"""
|
344 |
+
device = src.device
|
345 |
+
x = src
|
346 |
+
T, B, C = x.shape
|
347 |
+
if self.sparse and not self.auto_sparsity:
|
348 |
+
assert src_mask is None
|
349 |
+
src_mask = self.src_mask
|
350 |
+
if src_mask.shape[-1] != T:
|
351 |
+
src_mask = get_mask(
|
352 |
+
T,
|
353 |
+
T,
|
354 |
+
self.mask_type,
|
355 |
+
self.sparse_attn_window,
|
356 |
+
self.global_window,
|
357 |
+
self.mask_random_seed,
|
358 |
+
self.sparsity,
|
359 |
+
device,
|
360 |
+
)
|
361 |
+
self.__setattr__("src_mask", src_mask)
|
362 |
+
|
363 |
+
if self.norm_first:
|
364 |
+
x = x + self.gamma_1(
|
365 |
+
self._sa_block(self.norm1(x), src_mask, src_key_padding_mask)
|
366 |
+
)
|
367 |
+
x = x + self.gamma_2(self._ff_block(self.norm2(x)))
|
368 |
+
|
369 |
+
if self.norm_out:
|
370 |
+
x = self.norm_out(x)
|
371 |
+
else:
|
372 |
+
x = self.norm1(
|
373 |
+
x + self.gamma_1(self._sa_block(x, src_mask, src_key_padding_mask))
|
374 |
+
)
|
375 |
+
x = self.norm2(x + self.gamma_2(self._ff_block(x)))
|
376 |
+
|
377 |
+
return x
|
378 |
+
|
379 |
+
|
380 |
+
class CrossTransformerEncoderLayer(nn.Module):
|
381 |
+
def __init__(
|
382 |
+
self,
|
383 |
+
d_model: int,
|
384 |
+
nhead: int,
|
385 |
+
dim_feedforward: int = 2048,
|
386 |
+
dropout: float = 0.1,
|
387 |
+
activation=F.relu,
|
388 |
+
layer_norm_eps: float = 1e-5,
|
389 |
+
layer_scale: bool = False,
|
390 |
+
init_values: float = 1e-4,
|
391 |
+
norm_first: bool = False,
|
392 |
+
group_norm: bool = False,
|
393 |
+
norm_out: bool = False,
|
394 |
+
sparse=False,
|
395 |
+
mask_type="diag",
|
396 |
+
mask_random_seed=42,
|
397 |
+
sparse_attn_window=500,
|
398 |
+
global_window=50,
|
399 |
+
sparsity=0.95,
|
400 |
+
auto_sparsity=None,
|
401 |
+
device=None,
|
402 |
+
dtype=None,
|
403 |
+
batch_first=False,
|
404 |
+
):
|
405 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
406 |
+
super().__init__()
|
407 |
+
|
408 |
+
self.sparse = sparse
|
409 |
+
self.auto_sparsity = auto_sparsity
|
410 |
+
if sparse:
|
411 |
+
if not auto_sparsity:
|
412 |
+
self.mask_type = mask_type
|
413 |
+
self.sparse_attn_window = sparse_attn_window
|
414 |
+
self.global_window = global_window
|
415 |
+
self.sparsity = sparsity
|
416 |
+
|
417 |
+
self.cross_attn: nn.Module
|
418 |
+
self.cross_attn = nn.MultiheadAttention(
|
419 |
+
d_model, nhead, dropout=dropout, batch_first=batch_first)
|
420 |
+
# Implementation of Feedforward model
|
421 |
+
self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs)
|
422 |
+
self.dropout = nn.Dropout(dropout)
|
423 |
+
self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs)
|
424 |
+
|
425 |
+
self.norm_first = norm_first
|
426 |
+
self.norm1: nn.Module
|
427 |
+
self.norm2: nn.Module
|
428 |
+
self.norm3: nn.Module
|
429 |
+
if group_norm:
|
430 |
+
self.norm1 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
|
431 |
+
self.norm2 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
|
432 |
+
self.norm3 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
|
433 |
+
else:
|
434 |
+
self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
|
435 |
+
self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
|
436 |
+
self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
|
437 |
+
|
438 |
+
self.norm_out = None
|
439 |
+
if self.norm_first & norm_out:
|
440 |
+
self.norm_out = MyGroupNorm(num_groups=int(norm_out), num_channels=d_model)
|
441 |
+
|
442 |
+
self.gamma_1 = (
|
443 |
+
LayerScale(d_model, init_values, True) if layer_scale else nn.Identity()
|
444 |
+
)
|
445 |
+
self.gamma_2 = (
|
446 |
+
LayerScale(d_model, init_values, True) if layer_scale else nn.Identity()
|
447 |
+
)
|
448 |
+
|
449 |
+
self.dropout1 = nn.Dropout(dropout)
|
450 |
+
self.dropout2 = nn.Dropout(dropout)
|
451 |
+
|
452 |
+
# Legacy string support for activation function.
|
453 |
+
if isinstance(activation, str):
|
454 |
+
self.activation = self._get_activation_fn(activation)
|
455 |
+
else:
|
456 |
+
self.activation = activation
|
457 |
+
|
458 |
+
if sparse:
|
459 |
+
self.cross_attn = MultiheadAttention(
|
460 |
+
d_model, nhead, dropout=dropout, batch_first=batch_first,
|
461 |
+
auto_sparsity=sparsity if auto_sparsity else 0)
|
462 |
+
if not auto_sparsity:
|
463 |
+
self.__setattr__("mask", torch.zeros(1, 1))
|
464 |
+
self.mask_random_seed = mask_random_seed
|
465 |
+
|
466 |
+
def forward(self, q, k, mask=None):
|
467 |
+
"""
|
468 |
+
Args:
|
469 |
+
q: tensor of shape (T, B, C)
|
470 |
+
k: tensor of shape (S, B, C)
|
471 |
+
mask: tensor of shape (T, S)
|
472 |
+
|
473 |
+
"""
|
474 |
+
device = q.device
|
475 |
+
T, B, C = q.shape
|
476 |
+
S, B, C = k.shape
|
477 |
+
if self.sparse and not self.auto_sparsity:
|
478 |
+
assert mask is None
|
479 |
+
mask = self.mask
|
480 |
+
if mask.shape[-1] != S or mask.shape[-2] != T:
|
481 |
+
mask = get_mask(
|
482 |
+
S,
|
483 |
+
T,
|
484 |
+
self.mask_type,
|
485 |
+
self.sparse_attn_window,
|
486 |
+
self.global_window,
|
487 |
+
self.mask_random_seed,
|
488 |
+
self.sparsity,
|
489 |
+
device,
|
490 |
+
)
|
491 |
+
self.__setattr__("mask", mask)
|
492 |
+
|
493 |
+
if self.norm_first:
|
494 |
+
x = q + self.gamma_1(self._ca_block(self.norm1(q), self.norm2(k), mask))
|
495 |
+
x = x + self.gamma_2(self._ff_block(self.norm3(x)))
|
496 |
+
if self.norm_out:
|
497 |
+
x = self.norm_out(x)
|
498 |
+
else:
|
499 |
+
x = self.norm1(q + self.gamma_1(self._ca_block(q, k, mask)))
|
500 |
+
x = self.norm2(x + self.gamma_2(self._ff_block(x)))
|
501 |
+
|
502 |
+
return x
|
503 |
+
|
504 |
+
# self-attention block
|
505 |
+
def _ca_block(self, q, k, attn_mask=None):
|
506 |
+
x = self.cross_attn(q, k, k, attn_mask=attn_mask, need_weights=False)[0]
|
507 |
+
return self.dropout1(x)
|
508 |
+
|
509 |
+
# feed forward block
|
510 |
+
def _ff_block(self, x):
|
511 |
+
x = self.linear2(self.dropout(self.activation(self.linear1(x))))
|
512 |
+
return self.dropout2(x)
|
513 |
+
|
514 |
+
def _get_activation_fn(self, activation):
|
515 |
+
if activation == "relu":
|
516 |
+
return F.relu
|
517 |
+
elif activation == "gelu":
|
518 |
+
return F.gelu
|
519 |
+
|
520 |
+
raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
|
521 |
+
|
522 |
+
|
523 |
+
# ----------------- MULTI-BLOCKS MODELS: -----------------------
|
524 |
+
|
525 |
+
|
526 |
+
class CrossTransformerEncoder(nn.Module):
|
527 |
+
def __init__(
|
528 |
+
self,
|
529 |
+
dim: int,
|
530 |
+
emb: str = "sin",
|
531 |
+
hidden_scale: float = 4.0,
|
532 |
+
num_heads: int = 8,
|
533 |
+
num_layers: int = 6,
|
534 |
+
cross_first: bool = False,
|
535 |
+
dropout: float = 0.0,
|
536 |
+
max_positions: int = 1000,
|
537 |
+
norm_in: bool = True,
|
538 |
+
norm_in_group: bool = False,
|
539 |
+
group_norm: int = False,
|
540 |
+
norm_first: bool = False,
|
541 |
+
norm_out: bool = False,
|
542 |
+
max_period: float = 10000.0,
|
543 |
+
weight_decay: float = 0.0,
|
544 |
+
lr: tp.Optional[float] = None,
|
545 |
+
layer_scale: bool = False,
|
546 |
+
gelu: bool = True,
|
547 |
+
sin_random_shift: int = 0,
|
548 |
+
weight_pos_embed: float = 1.0,
|
549 |
+
cape_mean_normalize: bool = True,
|
550 |
+
cape_augment: bool = True,
|
551 |
+
cape_glob_loc_scale: list = [5000.0, 1.0, 1.4],
|
552 |
+
sparse_self_attn: bool = False,
|
553 |
+
sparse_cross_attn: bool = False,
|
554 |
+
mask_type: str = "diag",
|
555 |
+
mask_random_seed: int = 42,
|
556 |
+
sparse_attn_window: int = 500,
|
557 |
+
global_window: int = 50,
|
558 |
+
auto_sparsity: bool = False,
|
559 |
+
sparsity: float = 0.95,
|
560 |
+
):
|
561 |
+
super().__init__()
|
562 |
+
"""
|
563 |
+
"""
|
564 |
+
assert dim % num_heads == 0
|
565 |
+
|
566 |
+
hidden_dim = int(dim * hidden_scale)
|
567 |
+
|
568 |
+
self.num_layers = num_layers
|
569 |
+
# classic parity = 1 means that if idx%2 == 1 there is a
|
570 |
+
# classical encoder else there is a cross encoder
|
571 |
+
self.classic_parity = 1 if cross_first else 0
|
572 |
+
self.emb = emb
|
573 |
+
self.max_period = max_period
|
574 |
+
self.weight_decay = weight_decay
|
575 |
+
self.weight_pos_embed = weight_pos_embed
|
576 |
+
self.sin_random_shift = sin_random_shift
|
577 |
+
if emb == "cape":
|
578 |
+
self.cape_mean_normalize = cape_mean_normalize
|
579 |
+
self.cape_augment = cape_augment
|
580 |
+
self.cape_glob_loc_scale = cape_glob_loc_scale
|
581 |
+
if emb == "scaled":
|
582 |
+
self.position_embeddings = ScaledEmbedding(max_positions, dim, scale=0.2)
|
583 |
+
|
584 |
+
self.lr = lr
|
585 |
+
|
586 |
+
activation: tp.Any = F.gelu if gelu else F.relu
|
587 |
+
|
588 |
+
self.norm_in: nn.Module
|
589 |
+
self.norm_in_t: nn.Module
|
590 |
+
if norm_in:
|
591 |
+
self.norm_in = nn.LayerNorm(dim)
|
592 |
+
self.norm_in_t = nn.LayerNorm(dim)
|
593 |
+
elif norm_in_group:
|
594 |
+
self.norm_in = MyGroupNorm(int(norm_in_group), dim)
|
595 |
+
self.norm_in_t = MyGroupNorm(int(norm_in_group), dim)
|
596 |
+
else:
|
597 |
+
self.norm_in = nn.Identity()
|
598 |
+
self.norm_in_t = nn.Identity()
|
599 |
+
|
600 |
+
# spectrogram layers
|
601 |
+
self.layers = nn.ModuleList()
|
602 |
+
# temporal layers
|
603 |
+
self.layers_t = nn.ModuleList()
|
604 |
+
|
605 |
+
kwargs_common = {
|
606 |
+
"d_model": dim,
|
607 |
+
"nhead": num_heads,
|
608 |
+
"dim_feedforward": hidden_dim,
|
609 |
+
"dropout": dropout,
|
610 |
+
"activation": activation,
|
611 |
+
"group_norm": group_norm,
|
612 |
+
"norm_first": norm_first,
|
613 |
+
"norm_out": norm_out,
|
614 |
+
"layer_scale": layer_scale,
|
615 |
+
"mask_type": mask_type,
|
616 |
+
"mask_random_seed": mask_random_seed,
|
617 |
+
"sparse_attn_window": sparse_attn_window,
|
618 |
+
"global_window": global_window,
|
619 |
+
"sparsity": sparsity,
|
620 |
+
"auto_sparsity": auto_sparsity,
|
621 |
+
"batch_first": True,
|
622 |
+
}
|
623 |
+
|
624 |
+
kwargs_classic_encoder = dict(kwargs_common)
|
625 |
+
kwargs_classic_encoder.update({
|
626 |
+
"sparse": sparse_self_attn,
|
627 |
+
})
|
628 |
+
kwargs_cross_encoder = dict(kwargs_common)
|
629 |
+
kwargs_cross_encoder.update({
|
630 |
+
"sparse": sparse_cross_attn,
|
631 |
+
})
|
632 |
+
|
633 |
+
for idx in range(num_layers):
|
634 |
+
if idx % 2 == self.classic_parity:
|
635 |
+
|
636 |
+
self.layers.append(MyTransformerEncoderLayer(**kwargs_classic_encoder))
|
637 |
+
self.layers_t.append(
|
638 |
+
MyTransformerEncoderLayer(**kwargs_classic_encoder)
|
639 |
+
)
|
640 |
+
|
641 |
+
else:
|
642 |
+
self.layers.append(CrossTransformerEncoderLayer(**kwargs_cross_encoder))
|
643 |
+
|
644 |
+
self.layers_t.append(
|
645 |
+
CrossTransformerEncoderLayer(**kwargs_cross_encoder)
|
646 |
+
)
|
647 |
+
|
648 |
+
def forward(self, x, xt):
|
649 |
+
B, C, Fr, T1 = x.shape
|
650 |
+
pos_emb_2d = create_2d_sin_embedding(
|
651 |
+
C, Fr, T1, x.device, self.max_period
|
652 |
+
) # (1, C, Fr, T1)
|
653 |
+
pos_emb_2d = rearrange(pos_emb_2d, "b c fr t1 -> b (t1 fr) c")
|
654 |
+
x = rearrange(x, "b c fr t1 -> b (t1 fr) c")
|
655 |
+
x = self.norm_in(x)
|
656 |
+
x = x + self.weight_pos_embed * pos_emb_2d
|
657 |
+
|
658 |
+
B, C, T2 = xt.shape
|
659 |
+
xt = rearrange(xt, "b c t2 -> b t2 c") # now T2, B, C
|
660 |
+
pos_emb = self._get_pos_embedding(T2, B, C, x.device)
|
661 |
+
pos_emb = rearrange(pos_emb, "t2 b c -> b t2 c")
|
662 |
+
xt = self.norm_in_t(xt)
|
663 |
+
xt = xt + self.weight_pos_embed * pos_emb
|
664 |
+
|
665 |
+
for idx in range(self.num_layers):
|
666 |
+
if idx % 2 == self.classic_parity:
|
667 |
+
x = self.layers[idx](x)
|
668 |
+
xt = self.layers_t[idx](xt)
|
669 |
+
else:
|
670 |
+
old_x = x
|
671 |
+
x = self.layers[idx](x, xt)
|
672 |
+
xt = self.layers_t[idx](xt, old_x)
|
673 |
+
|
674 |
+
x = rearrange(x, "b (t1 fr) c -> b c fr t1", t1=T1)
|
675 |
+
xt = rearrange(xt, "b t2 c -> b c t2")
|
676 |
+
return x, xt
|
677 |
+
|
678 |
+
def _get_pos_embedding(self, T, B, C, device):
|
679 |
+
if self.emb == "sin":
|
680 |
+
shift = random.randrange(self.sin_random_shift + 1)
|
681 |
+
pos_emb = create_sin_embedding(
|
682 |
+
T, C, shift=shift, device=device, max_period=self.max_period
|
683 |
+
)
|
684 |
+
elif self.emb == "cape":
|
685 |
+
if self.training:
|
686 |
+
pos_emb = create_sin_embedding_cape(
|
687 |
+
T,
|
688 |
+
C,
|
689 |
+
B,
|
690 |
+
device=device,
|
691 |
+
max_period=self.max_period,
|
692 |
+
mean_normalize=self.cape_mean_normalize,
|
693 |
+
augment=self.cape_augment,
|
694 |
+
max_global_shift=self.cape_glob_loc_scale[0],
|
695 |
+
max_local_shift=self.cape_glob_loc_scale[1],
|
696 |
+
max_scale=self.cape_glob_loc_scale[2],
|
697 |
+
)
|
698 |
+
else:
|
699 |
+
pos_emb = create_sin_embedding_cape(
|
700 |
+
T,
|
701 |
+
C,
|
702 |
+
B,
|
703 |
+
device=device,
|
704 |
+
max_period=self.max_period,
|
705 |
+
mean_normalize=self.cape_mean_normalize,
|
706 |
+
augment=False,
|
707 |
+
)
|
708 |
+
|
709 |
+
elif self.emb == "scaled":
|
710 |
+
pos = torch.arange(T, device=device)
|
711 |
+
pos_emb = self.position_embeddings(pos)[:, None]
|
712 |
+
|
713 |
+
return pos_emb
|
714 |
+
|
715 |
+
def make_optim_group(self):
|
716 |
+
group = {"params": list(self.parameters()), "weight_decay": self.weight_decay}
|
717 |
+
if self.lr is not None:
|
718 |
+
group["lr"] = self.lr
|
719 |
+
return group
|
720 |
+
|
721 |
+
|
722 |
+
# Attention Modules
|
723 |
+
|
724 |
+
|
725 |
+
class MultiheadAttention(nn.Module):
|
726 |
+
def __init__(
|
727 |
+
self,
|
728 |
+
embed_dim,
|
729 |
+
num_heads,
|
730 |
+
dropout=0.0,
|
731 |
+
bias=True,
|
732 |
+
add_bias_kv=False,
|
733 |
+
add_zero_attn=False,
|
734 |
+
kdim=None,
|
735 |
+
vdim=None,
|
736 |
+
batch_first=False,
|
737 |
+
auto_sparsity=None,
|
738 |
+
):
|
739 |
+
super().__init__()
|
740 |
+
assert auto_sparsity is not None, "sanity check"
|
741 |
+
self.num_heads = num_heads
|
742 |
+
self.q = torch.nn.Linear(embed_dim, embed_dim, bias=bias)
|
743 |
+
self.k = torch.nn.Linear(embed_dim, embed_dim, bias=bias)
|
744 |
+
self.v = torch.nn.Linear(embed_dim, embed_dim, bias=bias)
|
745 |
+
self.attn_drop = torch.nn.Dropout(dropout)
|
746 |
+
self.proj = torch.nn.Linear(embed_dim, embed_dim, bias)
|
747 |
+
self.proj_drop = torch.nn.Dropout(dropout)
|
748 |
+
self.batch_first = batch_first
|
749 |
+
self.auto_sparsity = auto_sparsity
|
750 |
+
|
751 |
+
def forward(
|
752 |
+
self,
|
753 |
+
query,
|
754 |
+
key,
|
755 |
+
value,
|
756 |
+
key_padding_mask=None,
|
757 |
+
need_weights=True,
|
758 |
+
attn_mask=None,
|
759 |
+
average_attn_weights=True,
|
760 |
+
):
|
761 |
+
|
762 |
+
if not self.batch_first: # N, B, C
|
763 |
+
query = query.permute(1, 0, 2) # B, N_q, C
|
764 |
+
key = key.permute(1, 0, 2) # B, N_k, C
|
765 |
+
value = value.permute(1, 0, 2) # B, N_k, C
|
766 |
+
B, N_q, C = query.shape
|
767 |
+
B, N_k, C = key.shape
|
768 |
+
|
769 |
+
q = (
|
770 |
+
self.q(query)
|
771 |
+
.reshape(B, N_q, self.num_heads, C // self.num_heads)
|
772 |
+
.permute(0, 2, 1, 3)
|
773 |
+
)
|
774 |
+
q = q.flatten(0, 1)
|
775 |
+
k = (
|
776 |
+
self.k(key)
|
777 |
+
.reshape(B, N_k, self.num_heads, C // self.num_heads)
|
778 |
+
.permute(0, 2, 1, 3)
|
779 |
+
)
|
780 |
+
k = k.flatten(0, 1)
|
781 |
+
v = (
|
782 |
+
self.v(value)
|
783 |
+
.reshape(B, N_k, self.num_heads, C // self.num_heads)
|
784 |
+
.permute(0, 2, 1, 3)
|
785 |
+
)
|
786 |
+
v = v.flatten(0, 1)
|
787 |
+
|
788 |
+
if self.auto_sparsity:
|
789 |
+
assert attn_mask is None
|
790 |
+
x = dynamic_sparse_attention(q, k, v, sparsity=self.auto_sparsity)
|
791 |
+
else:
|
792 |
+
x = scaled_dot_product_attention(q, k, v, attn_mask, dropout=self.attn_drop)
|
793 |
+
x = x.reshape(B, self.num_heads, N_q, C // self.num_heads)
|
794 |
+
|
795 |
+
x = x.transpose(1, 2).reshape(B, N_q, C)
|
796 |
+
x = self.proj(x)
|
797 |
+
x = self.proj_drop(x)
|
798 |
+
if not self.batch_first:
|
799 |
+
x = x.permute(1, 0, 2)
|
800 |
+
return x, None
|
801 |
+
|
802 |
+
|
803 |
+
def scaled_query_key_softmax(q, k, att_mask):
|
804 |
+
from xformers.ops import masked_matmul
|
805 |
+
q = q / (k.size(-1)) ** 0.5
|
806 |
+
att = masked_matmul(q, k.transpose(-2, -1), att_mask)
|
807 |
+
att = torch.nn.functional.softmax(att, -1)
|
808 |
+
return att
|
809 |
+
|
810 |
+
|
811 |
+
def scaled_dot_product_attention(q, k, v, att_mask, dropout):
|
812 |
+
att = scaled_query_key_softmax(q, k, att_mask=att_mask)
|
813 |
+
att = dropout(att)
|
814 |
+
y = att @ v
|
815 |
+
return y
|
816 |
+
|
817 |
+
|
818 |
+
def _compute_buckets(x, R):
|
819 |
+
qq = torch.einsum('btf,bfhi->bhti', x, R)
|
820 |
+
qq = torch.cat([qq, -qq], dim=-1)
|
821 |
+
buckets = qq.argmax(dim=-1)
|
822 |
+
|
823 |
+
return buckets.permute(0, 2, 1).byte().contiguous()
|
824 |
+
|
825 |
+
|
826 |
+
def dynamic_sparse_attention(query, key, value, sparsity, infer_sparsity=True, attn_bias=None):
|
827 |
+
# assert False, "The code for the custom sparse kernel is not ready for release yet."
|
828 |
+
from xformers.ops import find_locations, sparse_memory_efficient_attention
|
829 |
+
n_hashes = 32
|
830 |
+
proj_size = 4
|
831 |
+
query, key, value = [x.contiguous() for x in [query, key, value]]
|
832 |
+
with torch.no_grad():
|
833 |
+
R = torch.randn(1, query.shape[-1], n_hashes, proj_size // 2, device=query.device)
|
834 |
+
bucket_query = _compute_buckets(query, R)
|
835 |
+
bucket_key = _compute_buckets(key, R)
|
836 |
+
row_offsets, column_indices = find_locations(
|
837 |
+
bucket_query, bucket_key, sparsity, infer_sparsity)
|
838 |
+
return sparse_memory_efficient_attention(
|
839 |
+
query, key, value, row_offsets, column_indices, attn_bias)
|
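
The positional-embedding helpers at the top of this file (`create_sin_embedding`, `create_2d_sin_embedding`, `get_causal_mask`) are self-contained, so they can be exercised on their own. The sketch below is added for illustration and is not part of the commit; the sizes `T, B, C` are arbitrary example values, and it assumes the repository root is on `PYTHONPATH` so that `demucs.transformer` is importable.

```python
# Illustration only: exercise the sinusoidal embeddings and the causal mask
# defined in demucs/transformer.py. T, B, C are arbitrary example sizes.
import torch
from demucs.transformer import (
    create_sin_embedding,
    create_2d_sin_embedding,
    get_causal_mask,
)

T, B, C = 128, 2, 64                          # time steps, batch, channels
pos_1d = create_sin_embedding(T, C)           # (T, 1, C), TBC layout, broadcasts over B
pos_2d = create_2d_sin_embedding(C, 32, T)    # (1, C, 32, T) for a spectrogram branch
mask = get_causal_mask(T)                     # (T, T) bool, True strictly above the diagonal

x = torch.randn(T, B, C)
x = x + pos_1d                                # add the embedding before a transformer layer
print(pos_1d.shape, pos_2d.shape, mask.shape, x.shape)
```

`CrossTransformerEncoder.forward` applies the same two embeddings internally: the 2-D one on the spectrogram branch and the 1-D one (via `_get_pos_embedding`) on the temporal branch, both scaled by `weight_pos_embed`.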
demucs/utils.py
ADDED
@@ -0,0 +1,502 @@
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
from collections import defaultdict
|
8 |
+
from contextlib import contextmanager
|
9 |
+
import math
|
10 |
+
import os
|
11 |
+
import tempfile
|
12 |
+
import typing as tp
|
13 |
+
|
14 |
+
import errno
|
15 |
+
import functools
|
16 |
+
import hashlib
|
17 |
+
import inspect
|
18 |
+
import io
|
19 |
+
import os
|
20 |
+
import random
|
21 |
+
import socket
|
22 |
+
import tempfile
|
23 |
+
import warnings
|
24 |
+
import zlib
|
25 |
+
import tkinter as tk
|
26 |
+
|
27 |
+
from diffq import UniformQuantizer, DiffQuantizer
|
28 |
+
import torch as th
|
29 |
+
import tqdm
|
30 |
+
from torch import distributed
|
31 |
+
from torch.nn import functional as F
|
32 |
+
|
33 |
+
import torch
|
34 |
+
|
35 |
+
def unfold(a, kernel_size, stride):
|
36 |
+
"""Given input of size [*OT, T], output Tensor of size [*OT, F, K]
|
37 |
+
with K the kernel size, by extracting frames with the given stride.
|
38 |
+
|
39 |
+
This will pad the input so that `F = ceil(T / K)`.
|
40 |
+
|
41 |
+
see https://github.com/pytorch/pytorch/issues/60466
|
42 |
+
"""
|
43 |
+
*shape, length = a.shape
|
44 |
+
n_frames = math.ceil(length / stride)
|
45 |
+
tgt_length = (n_frames - 1) * stride + kernel_size
|
46 |
+
a = F.pad(a, (0, tgt_length - length))
|
47 |
+
strides = list(a.stride())
|
48 |
+
assert strides[-1] == 1, 'data should be contiguous'
|
49 |
+
strides = strides[:-1] + [stride, 1]
|
50 |
+
return a.as_strided([*shape, n_frames, kernel_size], strides)
|
51 |
+
|
52 |
+
|
53 |
+
def center_trim(tensor: torch.Tensor, reference: tp.Union[torch.Tensor, int]):
|
54 |
+
"""
|
55 |
+
Center trim `tensor` with respect to `reference`, along the last dimension.
|
56 |
+
`reference` can also be a number, representing the length to trim to.
|
57 |
+
If the size difference != 0 mod 2, the extra sample is removed on the right side.
|
58 |
+
"""
|
59 |
+
ref_size: int
|
60 |
+
if isinstance(reference, torch.Tensor):
|
61 |
+
ref_size = reference.size(-1)
|
62 |
+
else:
|
63 |
+
ref_size = reference
|
64 |
+
delta = tensor.size(-1) - ref_size
|
65 |
+
if delta < 0:
|
66 |
+
raise ValueError("tensor must be larger than reference. " f"Delta is {delta}.")
|
67 |
+
if delta:
|
68 |
+
tensor = tensor[..., delta // 2:-(delta - delta // 2)]
|
69 |
+
return tensor
|
70 |
+
|
71 |
+
|
72 |
+
def pull_metric(history: tp.List[dict], name: str):
|
73 |
+
out = []
|
74 |
+
for metrics in history:
|
75 |
+
metric = metrics
|
76 |
+
for part in name.split("."):
|
77 |
+
metric = metric[part]
|
78 |
+
out.append(metric)
|
79 |
+
return out
|
80 |
+
|
81 |
+
|
82 |
+
def EMA(beta: float = 1):
|
83 |
+
"""
|
84 |
+
Exponential Moving Average callback.
|
85 |
+
Returns a single function that can be called to repeatedly update the EMA
|
86 |
+
with a dict of metrics. The callback will return
|
87 |
+
the new averaged dict of metrics.
|
88 |
+
|
89 |
+
Note that for `beta=1`, this is just plain averaging.
|
90 |
+
"""
|
91 |
+
fix: tp.Dict[str, float] = defaultdict(float)
|
92 |
+
total: tp.Dict[str, float] = defaultdict(float)
|
93 |
+
|
94 |
+
def _update(metrics: dict, weight: float = 1) -> dict:
|
95 |
+
nonlocal total, fix
|
96 |
+
for key, value in metrics.items():
|
97 |
+
total[key] = total[key] * beta + weight * float(value)
|
98 |
+
fix[key] = fix[key] * beta + weight
|
99 |
+
return {key: tot / fix[key] for key, tot in total.items()}
|
100 |
+
return _update
|
101 |
+
|
102 |
+
|
103 |
+
def sizeof_fmt(num: float, suffix: str = 'B'):
|
104 |
+
"""
|
105 |
+
Given `num` bytes, return human readable size.
|
106 |
+
Taken from https://stackoverflow.com/a/1094933
|
107 |
+
"""
|
108 |
+
for unit in ['', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi']:
|
109 |
+
if abs(num) < 1024.0:
|
110 |
+
return "%3.1f%s%s" % (num, unit, suffix)
|
111 |
+
num /= 1024.0
|
112 |
+
return "%.1f%s%s" % (num, 'Yi', suffix)
|
113 |
+
|
114 |
+
|
115 |
+
@contextmanager
|
116 |
+
def temp_filenames(count: int, delete=True):
|
117 |
+
names = []
|
118 |
+
try:
|
119 |
+
for _ in range(count):
|
120 |
+
names.append(tempfile.NamedTemporaryFile(delete=False).name)
|
121 |
+
yield names
|
122 |
+
finally:
|
123 |
+
if delete:
|
124 |
+
for name in names:
|
125 |
+
os.unlink(name)
|
126 |
+
|
127 |
+
def average_metric(metric, count=1.):
|
128 |
+
"""
|
129 |
+
Average `metric` which should be a float across all hosts. `count` should be
|
130 |
+
the weight for this particular host (i.e. number of examples).
|
131 |
+
"""
|
132 |
+
metric = th.tensor([count, count * metric], dtype=th.float32, device='cuda')
|
133 |
+
distributed.all_reduce(metric, op=distributed.ReduceOp.SUM)
|
134 |
+
return metric[1].item() / metric[0].item()
|
135 |
+
|
136 |
+
|
137 |
+
def free_port(host='', low=20000, high=40000):
|
138 |
+
"""
|
139 |
+
Return a port number that is most likely free.
|
140 |
+
This could suffer from a race condition although
|
141 |
+
it should be quite rare.
|
142 |
+
"""
|
143 |
+
sock = socket.socket()
|
144 |
+
while True:
|
145 |
+
port = random.randint(low, high)
|
146 |
+
try:
|
147 |
+
sock.bind((host, port))
|
148 |
+
except OSError as error:
|
149 |
+
if error.errno == errno.EADDRINUSE:
|
150 |
+
continue
|
151 |
+
raise
|
152 |
+
return port
|
153 |
+
|
154 |
+
|
155 |
+
def sizeof_fmt(num, suffix='B'):
|
156 |
+
"""
|
157 |
+
Given `num` bytes, return human readable size.
|
158 |
+
Taken from https://stackoverflow.com/a/1094933
|
159 |
+
"""
|
160 |
+
for unit in ['', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi']:
|
161 |
+
if abs(num) < 1024.0:
|
162 |
+
return "%3.1f%s%s" % (num, unit, suffix)
|
163 |
+
num /= 1024.0
|
164 |
+
return "%.1f%s%s" % (num, 'Yi', suffix)
|
165 |
+
|
166 |
+
|
167 |
+
def human_seconds(seconds, display='.2f'):
|
168 |
+
"""
|
169 |
+
Given `seconds` seconds, return human readable duration.
|
170 |
+
"""
|
171 |
+
value = seconds * 1e6
|
172 |
+
ratios = [1e3, 1e3, 60, 60, 24]
|
173 |
+
names = ['us', 'ms', 's', 'min', 'hrs', 'days']
|
174 |
+
last = names.pop(0)
|
175 |
+
for name, ratio in zip(names, ratios):
|
176 |
+
if value / ratio < 0.3:
|
177 |
+
break
|
178 |
+
value /= ratio
|
179 |
+
last = name
|
180 |
+
return f"{format(value, display)} {last}"
|
181 |
+
|
182 |
+
|
183 |
+
class TensorChunk:
|
184 |
+
def __init__(self, tensor, offset=0, length=None):
|
185 |
+
total_length = tensor.shape[-1]
|
186 |
+
assert offset >= 0
|
187 |
+
assert offset < total_length
|
188 |
+
|
189 |
+
if length is None:
|
190 |
+
length = total_length - offset
|
191 |
+
else:
|
192 |
+
length = min(total_length - offset, length)
|
193 |
+
|
194 |
+
self.tensor = tensor
|
195 |
+
self.offset = offset
|
196 |
+
self.length = length
|
197 |
+
self.device = tensor.device
|
198 |
+
|
199 |
+
@property
|
200 |
+
def shape(self):
|
201 |
+
shape = list(self.tensor.shape)
|
202 |
+
shape[-1] = self.length
|
203 |
+
return shape
|
204 |
+
|
205 |
+
def padded(self, target_length):
|
206 |
+
delta = target_length - self.length
|
207 |
+
total_length = self.tensor.shape[-1]
|
208 |
+
assert delta >= 0
|
209 |
+
|
210 |
+
start = self.offset - delta // 2
|
211 |
+
end = start + target_length
|
212 |
+
|
213 |
+
correct_start = max(0, start)
|
214 |
+
correct_end = min(total_length, end)
|
215 |
+
|
216 |
+
pad_left = correct_start - start
|
217 |
+
pad_right = end - correct_end
|
218 |
+
|
219 |
+
out = F.pad(self.tensor[..., correct_start:correct_end], (pad_left, pad_right))
|
220 |
+
assert out.shape[-1] == target_length
|
221 |
+
return out
|
222 |
+
|
223 |
+
|
224 |
+
def tensor_chunk(tensor_or_chunk):
|
225 |
+
if isinstance(tensor_or_chunk, TensorChunk):
|
226 |
+
return tensor_or_chunk
|
227 |
+
else:
|
228 |
+
assert isinstance(tensor_or_chunk, th.Tensor)
|
229 |
+
return TensorChunk(tensor_or_chunk)
|
230 |
+
|
231 |
+
|
232 |
+
def apply_model_v1(model, mix, shifts=None, split=False, progress=False, set_progress_bar=None):
|
233 |
+
"""
|
234 |
+
Apply model to a given mixture.
|
235 |
+
|
236 |
+
Args:
|
237 |
+
shifts (int): if > 0, will shift in time `mix` by a random amount between 0 and 0.5 sec
|
238 |
+
and apply the opposite shift to the output. This is repeated `shifts` times and
|
239 |
+
all predictions are averaged. This effectively makes the model time equivariant
|
240 |
+
and improves SDR by up to 0.2 points.
|
241 |
+
split (bool): if True, the input will be broken down into 8-second extracts
|
242 |
+
and predictions will be performed individually on each and concatenated.
|
243 |
+
Useful for models with a large memory footprint like Tasnet.
|
244 |
+
progress (bool): if True, show a progress bar (requires split=True)
|
245 |
+
"""
|
246 |
+
|
247 |
+
channels, length = mix.size()
|
248 |
+
device = mix.device
|
249 |
+
progress_value = 0
|
250 |
+
|
251 |
+
if split:
|
252 |
+
out = th.zeros(4, channels, length, device=device)
|
253 |
+
shift = model.samplerate * 10
|
254 |
+
offsets = range(0, length, shift)
|
255 |
+
scale = 10
|
256 |
+
if progress:
|
257 |
+
offsets = tqdm.tqdm(offsets, unit_scale=scale, ncols=120, unit='seconds')
|
258 |
+
for offset in offsets:
|
259 |
+
chunk = mix[..., offset:offset + shift]
|
260 |
+
if set_progress_bar:
|
261 |
+
progress_value += 1
|
262 |
+
set_progress_bar(0.1, (0.8/len(offsets)*progress_value))
|
263 |
+
chunk_out = apply_model_v1(model, chunk, shifts=shifts, set_progress_bar=set_progress_bar)
|
264 |
+
else:
|
265 |
+
chunk_out = apply_model_v1(model, chunk, shifts=shifts)
|
266 |
+
out[..., offset:offset + shift] = chunk_out
|
267 |
+
offset += shift
|
268 |
+
return out
|
269 |
+
elif shifts:
|
270 |
+
max_shift = int(model.samplerate / 2)
|
271 |
+
mix = F.pad(mix, (max_shift, max_shift))
|
272 |
+
offsets = list(range(max_shift))
|
273 |
+
random.shuffle(offsets)
|
274 |
+
out = 0
|
275 |
+
for offset in offsets[:shifts]:
|
276 |
+
shifted = mix[..., offset:offset + length + max_shift]
|
277 |
+
if set_progress_bar:
|
278 |
+
shifted_out = apply_model_v1(model, shifted, set_progress_bar=set_progress_bar)
|
279 |
+
else:
|
280 |
+
shifted_out = apply_model_v1(model, shifted)
|
281 |
+
out += shifted_out[..., max_shift - offset:max_shift - offset + length]
|
282 |
+
out /= shifts
|
283 |
+
return out
|
284 |
+
else:
|
285 |
+
valid_length = model.valid_length(length)
|
286 |
+
delta = valid_length - length
|
287 |
+
padded = F.pad(mix, (delta // 2, delta - delta // 2))
|
288 |
+
with th.no_grad():
|
289 |
+
out = model(padded.unsqueeze(0))[0]
|
290 |
+
return center_trim(out, mix)
|
291 |
+
|
292 |
+
def apply_model_v2(model, mix, shifts=None, split=False,
|
293 |
+
overlap=0.25, transition_power=1., progress=False, set_progress_bar=None):
|
294 |
+
"""
|
295 |
+
Apply model to a given mixture.
|
296 |
+
|
297 |
+
Args:
|
298 |
+
shifts (int): if > 0, will shift in time `mix` by a random amount between 0 and 0.5 sec
|
299 |
+
and apply the opposite shift to the output. This is repeated `shifts` times and
|
300 |
+
all predictions are averaged. This effectively makes the model time equivariant
|
301 |
+
and improves SDR by up to 0.2 points.
|
302 |
+
split (bool): if True, the input will be broken down into 8-second extracts
|
303 |
+
and predictions will be performed individually on each and concatenated.
|
304 |
+
Useful for models with a large memory footprint like Tasnet.
|
305 |
+
progress (bool): if True, show a progress bar (requires split=True)
|
306 |
+
"""
|
307 |
+
|
308 |
+
assert transition_power >= 1, "transition_power < 1 leads to weird behavior."
|
309 |
+
device = mix.device
|
310 |
+
channels, length = mix.shape
|
311 |
+
progress_value = 0
|
312 |
+
|
313 |
+
if split:
|
314 |
+
out = th.zeros(len(model.sources), channels, length, device=device)
|
315 |
+
sum_weight = th.zeros(length, device=device)
|
316 |
+
segment = model.segment_length
|
317 |
+
stride = int((1 - overlap) * segment)
|
318 |
+
offsets = range(0, length, stride)
|
319 |
+
scale = stride / model.samplerate
|
320 |
+
if progress:
|
321 |
+
offsets = tqdm.tqdm(offsets, unit_scale=scale, ncols=120, unit='seconds')
|
322 |
+
# We start from a triangle shaped weight, with maximal weight in the middle
|
323 |
+
# of the segment. Then we normalize and take to the power `transition_power`.
|
324 |
+
# Large values of transition power will lead to sharper transitions.
|
325 |
+
weight = th.cat([th.arange(1, segment // 2 + 1),
|
326 |
+
th.arange(segment - segment // 2, 0, -1)]).to(device)
|
327 |
+
assert len(weight) == segment
|
328 |
+
# If the overlap < 50%, this will translate to linear transition when
|
329 |
+
# transition_power is 1.
|
330 |
+
weight = (weight / weight.max())**transition_power
|
331 |
+
for offset in offsets:
|
332 |
+
chunk = TensorChunk(mix, offset, segment)
|
333 |
+
if set_progress_bar:
|
334 |
+
progress_value += 1
|
335 |
+
set_progress_bar(0.1, (0.8/len(offsets)*progress_value))
|
336 |
+
chunk_out = apply_model_v2(model, chunk, shifts=shifts, set_progress_bar=set_progress_bar)
|
337 |
+
else:
|
338 |
+
chunk_out = apply_model_v2(model, chunk, shifts=shifts)
|
339 |
+
chunk_length = chunk_out.shape[-1]
|
340 |
+
out[..., offset:offset + segment] += weight[:chunk_length] * chunk_out
|
341 |
+
sum_weight[offset:offset + segment] += weight[:chunk_length]
|
342 |
+
offset += segment
|
343 |
+
assert sum_weight.min() > 0
|
344 |
+
out /= sum_weight
|
345 |
+
return out
|
346 |
+
elif shifts:
|
347 |
+
max_shift = int(0.5 * model.samplerate)
|
348 |
+
mix = tensor_chunk(mix)
|
349 |
+
padded_mix = mix.padded(length + 2 * max_shift)
|
350 |
+
out = 0
|
351 |
+
for _ in range(shifts):
|
352 |
+
offset = random.randint(0, max_shift)
|
353 |
+
shifted = TensorChunk(padded_mix, offset, length + max_shift - offset)
|
354 |
+
|
355 |
+
if set_progress_bar:
|
356 |
+
progress_value += 1
|
357 |
+
shifted_out = apply_model_v2(model, shifted, set_progress_bar=set_progress_bar)
|
358 |
+
else:
|
359 |
+
shifted_out = apply_model_v2(model, shifted)
|
360 |
+
out += shifted_out[..., max_shift - offset:]
|
361 |
+
out /= shifts
|
362 |
+
return out
|
363 |
+
else:
|
364 |
+
valid_length = model.valid_length(length)
|
365 |
+
mix = tensor_chunk(mix)
|
366 |
+
padded_mix = mix.padded(valid_length)
|
367 |
+
with th.no_grad():
|
368 |
+
out = model(padded_mix.unsqueeze(0))[0]
|
369 |
+
return center_trim(out, length)
|
370 |
+
|
371 |
+
|
372 |
+
@contextmanager
|
373 |
+
def temp_filenames(count, delete=True):
|
374 |
+
names = []
|
375 |
+
try:
|
376 |
+
for _ in range(count):
|
377 |
+
names.append(tempfile.NamedTemporaryFile(delete=False).name)
|
378 |
+
yield names
|
379 |
+
finally:
|
380 |
+
if delete:
|
381 |
+
for name in names:
|
382 |
+
os.unlink(name)
|
383 |
+
|
384 |
+
|
385 |
+
def get_quantizer(model, args, optimizer=None):
|
386 |
+
quantizer = None
|
387 |
+
if args.diffq:
|
388 |
+
quantizer = DiffQuantizer(
|
389 |
+
model, min_size=args.q_min_size, group_size=8)
|
390 |
+
if optimizer is not None:
|
391 |
+
quantizer.setup_optimizer(optimizer)
|
392 |
+
elif args.qat:
|
393 |
+
quantizer = UniformQuantizer(
|
394 |
+
model, bits=args.qat, min_size=args.q_min_size)
|
395 |
+
return quantizer
|
396 |
+
|
397 |
+
|
398 |
+
def load_model(path, strict=False):
|
399 |
+
with warnings.catch_warnings():
|
400 |
+
warnings.simplefilter("ignore")
|
401 |
+
load_from = path
|
402 |
+
package = th.load(load_from, 'cpu')
|
403 |
+
|
404 |
+
klass = package["klass"]
|
405 |
+
args = package["args"]
|
406 |
+
kwargs = package["kwargs"]
|
407 |
+
|
408 |
+
if strict:
|
409 |
+
model = klass(*args, **kwargs)
|
410 |
+
else:
|
411 |
+
sig = inspect.signature(klass)
|
412 |
+
for key in list(kwargs):
|
413 |
+
if key not in sig.parameters:
|
414 |
+
warnings.warn("Dropping inexistant parameter " + key)
|
415 |
+
del kwargs[key]
|
416 |
+
model = klass(*args, **kwargs)
|
417 |
+
|
418 |
+
state = package["state"]
|
419 |
+
training_args = package["training_args"]
|
420 |
+
quantizer = get_quantizer(model, training_args)
|
421 |
+
|
422 |
+
set_state(model, quantizer, state)
|
423 |
+
return model
|
424 |
+
|
425 |
+
|
426 |
+
def get_state(model, quantizer):
|
427 |
+
if quantizer is None:
|
428 |
+
state = {k: p.data.to('cpu') for k, p in model.state_dict().items()}
|
429 |
+
else:
|
430 |
+
state = quantizer.get_quantized_state()
|
431 |
+
buf = io.BytesIO()
|
432 |
+
th.save(state, buf)
|
433 |
+
state = {'compressed': zlib.compress(buf.getvalue())}
|
434 |
+
return state
|
435 |
+
|
436 |
+
|
437 |
+
def set_state(model, quantizer, state):
|
438 |
+
if quantizer is None:
|
439 |
+
model.load_state_dict(state)
|
440 |
+
else:
|
441 |
+
buf = io.BytesIO(zlib.decompress(state["compressed"]))
|
442 |
+
state = th.load(buf, "cpu")
|
443 |
+
quantizer.restore_quantized_state(state)
|
444 |
+
|
445 |
+
return state
|
446 |
+
|
447 |
+
|
448 |
+
def save_state(state, path):
|
449 |
+
buf = io.BytesIO()
|
450 |
+
th.save(state, buf)
|
451 |
+
sig = hashlib.sha256(buf.getvalue()).hexdigest()[:8]
|
452 |
+
|
453 |
+
path = path.parent / (path.stem + "-" + sig + path.suffix)
|
454 |
+
path.write_bytes(buf.getvalue())
|
455 |
+
|
456 |
+
|
457 |
+
def save_model(model, quantizer, training_args, path):
|
458 |
+
args, kwargs = model._init_args_kwargs
|
459 |
+
klass = model.__class__
|
460 |
+
|
461 |
+
state = get_state(model, quantizer)
|
462 |
+
|
463 |
+
save_to = path
|
464 |
+
package = {
|
465 |
+
'klass': klass,
|
466 |
+
'args': args,
|
467 |
+
'kwargs': kwargs,
|
468 |
+
'state': state,
|
469 |
+
'training_args': training_args,
|
470 |
+
}
|
471 |
+
th.save(package, save_to)
|
472 |
+
|
473 |
+
|
474 |
+
def capture_init(init):
|
475 |
+
@functools.wraps(init)
|
476 |
+
def __init__(self, *args, **kwargs):
|
477 |
+
self._init_args_kwargs = (args, kwargs)
|
478 |
+
init(self, *args, **kwargs)
|
479 |
+
|
480 |
+
return __init__
|
481 |
+
|
482 |
+
class DummyPoolExecutor:
|
483 |
+
class DummyResult:
|
484 |
+
def __init__(self, func, *args, **kwargs):
|
485 |
+
self.func = func
|
486 |
+
self.args = args
|
487 |
+
self.kwargs = kwargs
|
488 |
+
|
489 |
+
def result(self):
|
490 |
+
return self.func(*self.args, **self.kwargs)
|
491 |
+
|
492 |
+
def __init__(self, workers=0):
|
493 |
+
pass
|
494 |
+
|
495 |
+
def submit(self, func, *args, **kwargs):
|
496 |
+
return DummyPoolExecutor.DummyResult(func, *args, **kwargs)
|
497 |
+
|
498 |
+
def __enter__(self):
|
499 |
+
return self
|
500 |
+
|
501 |
+
def __exit__(self, exc_type, exc_value, exc_tb):
|
502 |
+
return
|
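
`center_trim` and `unfold` above are small standalone helpers that the `apply_model_v1`/`apply_model_v2` functions build on. The sketch below is added for illustration and is not part of the commit; note that importing `demucs.utils` pulls in its module-level dependencies (`torch`, `diffq`, `tqdm`, `tkinter`), so those must be installed.

```python
# Illustration only: the trimming/framing helpers from demucs/utils.py.
import torch
from demucs.utils import center_trim, unfold

x = torch.arange(10.0).view(1, 1, 10)         # any (*, T) tensor, here T = 10
trimmed = center_trim(x, 6)                   # trim the last dim symmetrically to length 6
print(trimmed.shape)                          # torch.Size([1, 1, 6])

a = torch.arange(10.0)                        # contiguous 1-D signal
frames = unfold(a, kernel_size=4, stride=4)   # zero-pads to 12 samples, then frames
print(frames.shape)                           # torch.Size([3, 4])
```

`apply_model_v2` works on the same principle at song scale: `TensorChunk.padded` extracts overlapping segments, each segment's prediction is blended back with a triangular weight raised to `transition_power`, and `center_trim` aligns the model output with the requested length.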
gui_data/__pycache__/app_size_values.cpython-310.pyc
ADDED
Binary file (8.52 kB)
gui_data/__pycache__/constants.cpython-310.pyc
ADDED
Binary file (63.6 kB)
gui_data/__pycache__/error_handling.cpython-310.pyc
ADDED
Binary file (5.67 kB)
gui_data/__pycache__/old_data_check.cpython-310.pyc
ADDED
Binary file (1.01 kB)
gui_data/app_size_values.py
ADDED
@@ -0,0 +1,371 @@
1 |
+
import os
|
2 |
+
import platform
|
3 |
+
from screeninfo import get_monitors
|
4 |
+
from PIL import Image
|
5 |
+
from PIL import ImageTk
|
6 |
+
|
7 |
+
OPERATING_SYSTEM = platform.system()
|
8 |
+
|
9 |
+
def get_screen_height():
|
10 |
+
monitors = get_monitors()
|
11 |
+
if len(monitors) == 0:
|
12 |
+
raise Exception("Failed to get screen height")
|
13 |
+
return monitors[0].height, monitors[0].width
|
14 |
+
|
15 |
+
def scale_values(value):
|
16 |
+
if not SCALE_WIN_SIZE == 1920:
|
17 |
+
ratio = SCALE_WIN_SIZE/1920 # Approx. 1.3333 for 2K
|
18 |
+
return value * ratio
|
19 |
+
else:
|
20 |
+
return value
|
21 |
+
|
22 |
+
SCREEN_HIGHT, SCREEN_WIDTH = get_screen_height()
|
23 |
+
SCALE_WIN_SIZE = 1920
|
24 |
+
|
25 |
+
SCREEN_SIZE_VALUES = {
|
26 |
+
"normal": {
|
27 |
+
"credits_img":(100, 100),
|
28 |
+
## App Size
|
29 |
+
'IMAGE_HEIGHT': 140,
|
30 |
+
'FILEPATHS_HEIGHT': 75,
|
31 |
+
'OPTIONS_HEIGHT': 262,
|
32 |
+
'CONVERSIONBUTTON_HEIGHT': 30,
|
33 |
+
'COMMAND_HEIGHT': 141,
|
34 |
+
'PROGRESS_HEIGHT': 25,
|
35 |
+
'PADDING': 7,
|
36 |
+
'WIDTH': 680
|
37 |
+
},
|
38 |
+
"small": {
|
39 |
+
"credits_img":(50, 50),
|
40 |
+
'IMAGE_HEIGHT': 140,
|
41 |
+
'FILEPATHS_HEIGHT': 75,
|
42 |
+
'OPTIONS_HEIGHT': 262,
|
43 |
+
'CONVERSIONBUTTON_HEIGHT': 30,
|
44 |
+
'COMMAND_HEIGHT': 80,
|
45 |
+
'PROGRESS_HEIGHT': 25,
|
46 |
+
'PADDING': 5,
|
47 |
+
'WIDTH': 680
|
48 |
+
},
|
49 |
+
"medium": {
|
50 |
+
"credits_img":(50, 50),
|
51 |
+
## App Size
|
52 |
+
'IMAGE_HEIGHT': 140,
|
53 |
+
'FILEPATHS_HEIGHT': 75,
|
54 |
+
'OPTIONS_HEIGHT': 262,
|
55 |
+
'CONVERSIONBUTTON_HEIGHT': 30,
|
56 |
+
'COMMAND_HEIGHT': 115,
|
57 |
+
'PROGRESS_HEIGHT': 25,
|
58 |
+
'PADDING': 7,
|
59 |
+
'WIDTH': 680
|
60 |
+
},
|
61 |
+
}
|
62 |
+
|
63 |
+
try:
|
64 |
+
if SCREEN_HIGHT >= 900:
|
65 |
+
determined_size = SCREEN_SIZE_VALUES["normal"]
|
66 |
+
elif SCREEN_HIGHT <= 720:
|
67 |
+
determined_size = SCREEN_SIZE_VALUES["small"]
|
68 |
+
else:
|
69 |
+
determined_size = SCREEN_SIZE_VALUES["medium"]
|
70 |
+
except:
|
71 |
+
determined_size = SCREEN_SIZE_VALUES["normal"]
|
72 |
+
|
73 |
+
image_scale_1, image_scale_2 = 20, 30
|
74 |
+
|
75 |
+
class ImagePath():
|
76 |
+
def __init__(self, base_path):
|
77 |
+
img_path = os.path.join(base_path, 'gui_data', 'img')
|
78 |
+
credits_path = os.path.join(img_path, 'credits.png')
|
79 |
+
donate_path = os.path.join(img_path, 'donate.png')
|
80 |
+
download_path = os.path.join(img_path, 'download.png')
|
81 |
+
efile_path = os.path.join(img_path, 'File.png')
|
82 |
+
help_path = os.path.join(img_path, 'help.png')
|
83 |
+
key_path = os.path.join(img_path, 'key.png')
|
84 |
+
stop_path = os.path.join(img_path, 'stop.png')
|
85 |
+
play_path = os.path.join(img_path, 'play.png')
|
86 |
+
pause_path = os.path.join(img_path, 'pause.png')
|
87 |
+
up_img_path = os.path.join(img_path, "up.png")
|
88 |
+
down_img_path = os.path.join(img_path, "down.png")
|
89 |
+
left_img_path = os.path.join(img_path, "left.png")
|
90 |
+
right_img_path = os.path.join(img_path, "right.png")
|
91 |
+
clear_img_path = os.path.join(img_path, "clear.png")
|
92 |
+
copy_img_path = os.path.join(img_path, "copy.png")
|
93 |
+
self.banner_path = os.path.join(img_path, 'UVR-banner.png')
|
94 |
+
|
95 |
+
self.efile_img = self.open_image(path=efile_path,size=(image_scale_1, image_scale_1))
|
96 |
+
self.stop_img = self.open_image(path=stop_path, size=(image_scale_1, image_scale_1))
|
97 |
+
self.play_img = self.open_image(path=play_path, size=(image_scale_1, image_scale_1))
|
98 |
+
self.pause_img = self.open_image(path=pause_path, size=(image_scale_1, image_scale_1))
|
99 |
+
self.help_img = self.open_image(path=help_path, size=(image_scale_1, image_scale_1))
|
100 |
+
self.download_img = self.open_image(path=download_path, size=(image_scale_2, image_scale_2))
|
101 |
+
self.donate_img = self.open_image(path=donate_path, size=(image_scale_2, image_scale_2))
|
102 |
+
self.key_img = self.open_image(path=key_path, size=(image_scale_2, image_scale_2))
|
103 |
+
self.up_img = self.open_image(path=up_img_path, size=(image_scale_2, image_scale_2))
|
104 |
+
self.down_img = self.open_image(path=down_img_path, size=(image_scale_2, image_scale_2))
|
105 |
+
self.left_img = self.open_image(path=left_img_path, size=(image_scale_2, image_scale_2))
|
106 |
+
self.right_img = self.open_image(path=right_img_path, size=(image_scale_2, image_scale_2))
|
107 |
+
self.clear_img = self.open_image(path=clear_img_path, size=(image_scale_2, image_scale_2))
|
108 |
+
self.copy_img = self.open_image(path=copy_img_path, size=(image_scale_2, image_scale_2))
|
109 |
+
self.credits_img = self.open_image(path=credits_path, size=determined_size["credits_img"])
|
110 |
+
|
111 |
+
def open_image(self, path: str, size: tuple = None, keep_aspect: bool = True, rotate: int = 0) -> ImageTk.PhotoImage:
|
112 |
+
"""
|
113 |
+
Open the image on the path and apply given settings\n
|
114 |
+
Paramaters:
|
115 |
+
path(str):
|
116 |
+
Absolute path of the image
|
117 |
+
size(tuple):
|
118 |
+
first value - width
|
119 |
+
second value - height
|
120 |
+
keep_aspect(bool):
|
121 |
+
keep aspect ratio of image and resize
|
122 |
+
to maximum possible width and height
|
123 |
+
(maxima are given by size)
|
124 |
+
rotate(int):
|
125 |
+
clockwise rotation of image
|
126 |
+
Returns(ImageTk.PhotoImage):
|
127 |
+
Image of path
|
128 |
+
"""
|
129 |
+
img = Image.open(path).convert(mode='RGBA')
|
130 |
+
ratio = img.height/img.width
|
131 |
+
img = img.rotate(angle=-rotate)
|
132 |
+
if size is not None:
|
133 |
+
size = (int(size[0]), int(size[1]))
|
134 |
+
if keep_aspect:
|
135 |
+
img = img.resize((size[0], int(size[0] * ratio)), Image.ANTIALIAS)
|
136 |
+
else:
|
137 |
+
img = img.resize(size, Image.ANTIALIAS)
|
138 |
+
|
139 |
+
return ImageTk.PhotoImage(img)
|
140 |
+
|
141 |
+
#All Sizes Below Calibrated to 1080p!
|
142 |
+
|
143 |
+
if OPERATING_SYSTEM=="Darwin":
|
144 |
+
FONT_SIZE_F1 = 13
|
145 |
+
FONT_SIZE_F2 = 11
|
146 |
+
FONT_SIZE_F3 = 12
|
147 |
+
FONT_SIZE_0 = 9
|
148 |
+
FONT_SIZE_1 = 11
|
149 |
+
FONT_SIZE_2 = 12
|
150 |
+
FONT_SIZE_3 = 13
|
151 |
+
FONT_SIZE_4 = 14
|
152 |
+
FONT_SIZE_5 = 15
|
153 |
+
FONT_SIZE_6 = 17
|
154 |
+
HELP_HINT_CHECKBOX_WIDTH = 13
|
155 |
+
MDX_CHECKBOXS_WIDTH = 14
|
156 |
+
VR_CHECKBOXS_WIDTH = 14
|
157 |
+
ENSEMBLE_CHECKBOXS_WIDTH = 18
|
158 |
+
DEMUCS_CHECKBOXS_WIDTH = 14
|
159 |
+
DEMUCS_PRE_CHECKBOXS_WIDTH = 20
|
160 |
+
GEN_SETTINGS_WIDTH = 17
|
161 |
+
MENU_COMBOBOX_WIDTH = 16
|
162 |
+
MENU_OPTION_WIDTH = 12
|
163 |
+
READ_ONLY_COMBO_WIDTH = 35
|
164 |
+
SETTINGS_BUT_WIDTH = 19
|
165 |
+
VR_BUT_WIDTH = 16
|
166 |
+
SET_MENUS_CHECK_WIDTH = 12
|
167 |
+
COMBO_WIDTH = 14
|
168 |
+
SET_VOC_SPLIT_CHECK_WIDTH = 21
|
169 |
+
elif OPERATING_SYSTEM=="Linux":
|
170 |
+
HELP_HINT_CHECKBOX_WIDTH = 15
|
171 |
+
MDX_CHECKBOXS_WIDTH = 16
|
172 |
+
VR_CHECKBOXS_WIDTH = 16
|
173 |
+
ENSEMBLE_CHECKBOXS_WIDTH = 20
|
174 |
+
DEMUCS_CHECKBOXS_WIDTH = 16
|
175 |
+
DEMUCS_PRE_CHECKBOXS_WIDTH = 24
|
176 |
+
GEN_SETTINGS_WIDTH = 20
|
177 |
+
MENU_COMBOBOX_WIDTH = 18
|
178 |
+
MENU_OPTION_WIDTH = 12
|
179 |
+
READ_ONLY_COMBO_WIDTH = 40
|
180 |
+
SETTINGS_BUT_WIDTH = 23
|
181 |
+
VR_BUT_WIDTH = 18
|
182 |
+
SET_MENUS_CHECK_WIDTH = 13
|
183 |
+
COMBO_WIDTH = 16
|
184 |
+
SET_VOC_SPLIT_CHECK_WIDTH = 25
|
185 |
+
FONT_SIZE_F1 = 10
|
186 |
+
FONT_SIZE_F2 = 8
|
187 |
+
FONT_SIZE_F3 = 9
|
188 |
+
FONT_SIZE_0 = 7
|
189 |
+
FONT_SIZE_1 = 8
|
190 |
+
FONT_SIZE_2 = 9
|
191 |
+
FONT_SIZE_3 = 10
|
192 |
+
FONT_SIZE_4 = 11
|
193 |
+
FONT_SIZE_5 = 13
|
194 |
+
FONT_SIZE_6 = 15
|
195 |
+
elif OPERATING_SYSTEM=="Windows":
|
196 |
+
HELP_HINT_CHECKBOX_WIDTH = 15
|
197 |
+
MDX_CHECKBOXS_WIDTH = 14
|
198 |
+
VR_CHECKBOXS_WIDTH = 14
|
199 |
+
ENSEMBLE_CHECKBOXS_WIDTH = 20
|
200 |
+
DEMUCS_CHECKBOXS_WIDTH = 14
|
201 |
+
DEMUCS_PRE_CHECKBOXS_WIDTH = 20
|
202 |
+
GEN_SETTINGS_WIDTH = 18
|
203 |
+
MENU_COMBOBOX_WIDTH = 16
|
204 |
+
MENU_OPTION_WIDTH = 12
|
205 |
+
READ_ONLY_COMBO_WIDTH = 35
|
206 |
+
SETTINGS_BUT_WIDTH = 20
|
207 |
+
VR_BUT_WIDTH = 16
|
208 |
+
SET_MENUS_CHECK_WIDTH = 13
|
209 |
+
COMBO_WIDTH = 14
|
210 |
+
SET_VOC_SPLIT_CHECK_WIDTH = 23
|
211 |
+
FONT_SIZE_F1 = 10
|
212 |
+
FONT_SIZE_F2 = 8
|
213 |
+
FONT_SIZE_F3 = 9
|
214 |
+
FONT_SIZE_0 = 7
|
215 |
+
FONT_SIZE_1 = 8
|
216 |
+
FONT_SIZE_2 = 9
|
217 |
+
FONT_SIZE_3 = 10
|
218 |
+
FONT_SIZE_4 = 11
|
219 |
+
FONT_SIZE_5 = 13
|
220 |
+
FONT_SIZE_6 = 15
|
221 |
+
|
222 |
+
#Main Size Values:
|
223 |
+
IMAGE_HEIGHT = determined_size["IMAGE_HEIGHT"]
|
224 |
+
FILEPATHS_HEIGHT = determined_size["FILEPATHS_HEIGHT"]
|
225 |
+
OPTIONS_HEIGHT = determined_size["OPTIONS_HEIGHT"]
|
226 |
+
CONVERSIONBUTTON_HEIGHT = determined_size["CONVERSIONBUTTON_HEIGHT"]
|
227 |
+
COMMAND_HEIGHT = determined_size["COMMAND_HEIGHT"]
|
228 |
+
PROGRESS_HEIGHT = determined_size["PROGRESS_HEIGHT"]
|
229 |
+
PADDING = determined_size["PADDING"]
|
230 |
+
WIDTH = determined_size["WIDTH"]
|
231 |
+
|
232 |
+
# IMAGE_HEIGHT = 140
|
233 |
+
# FILEPATHS_HEIGHT = 75
|
234 |
+
# OPTIONS_HEIGHT = 262
|
235 |
+
# CONVERSIONBUTTON_HEIGHT = 30
|
236 |
+
# COMMAND_HEIGHT = 141
|
237 |
+
# PROGRESS_HEIGHT = 25
|
238 |
+
# PADDING = 7
|
239 |
+
# WIDTH = 680
|
240 |
+
|
241 |
+
MENU_PADDING_1 = 3
|
242 |
+
MENU_PADDING_2 = 10
|
243 |
+
MENU_PADDING_3 = 15
|
244 |
+
MENU_PADDING_4 = 3
|
245 |
+
|
246 |
+
#Main Frame Sizes
|
247 |
+
X_CONVERSION_BUTTON_1080P = 50
|
248 |
+
WIDTH_CONVERSION_BUTTON_1080P = -100
|
249 |
+
HEIGHT_GENERIC_BUTTON_1080P = 35
|
250 |
+
X_STOP_BUTTON_1080P = -10 - 35
|
251 |
+
X_SETTINGS_BUTTON_1080P = -670
|
252 |
+
X_PROGRESSBAR_1080P = 25
|
253 |
+
WIDTH_PROGRESSBAR_1080P = -50
|
254 |
+
X_CONSOLE_FRAME_1080P = 15
|
255 |
+
WIDTH_CONSOLE_FRAME_1080P = -30
|
256 |
+
HO_S = 7
|
257 |
+
|
258 |
+
#File Frame Sizes
|
259 |
+
FILEPATHS_FRAME_X = 10
|
260 |
+
FILEPATHS_FRAME_Y = 155
|
261 |
+
FILEPATHS_FRAME_WIDTH = -20
|
262 |
+
MUSICFILE_BUTTON_X = 0
|
263 |
+
MUSICFILE_BUTTON_Y = 5
|
264 |
+
MUSICFILE_BUTTON_WIDTH = 0
|
265 |
+
MUSICFILE_BUTTON_HEIGHT = -5
|
266 |
+
MUSICFILE_ENTRY_X = 7.5
|
267 |
+
MUSICFILE_ENTRY_WIDTH = -50
|
268 |
+
MUSICFILE_ENTRY_HEIGHT = -5
|
269 |
+
MUSICFILE_OPEN_X = -45
|
270 |
+
MUSICFILE_OPEN_Y = 160
|
271 |
+
MUSICFILE_OPEN_WIDTH = 35
|
272 |
+
MUSICFILE_OPEN_HEIGHT = 33
|
273 |
+
SAVETO_BUTTON_X = 0
|
274 |
+
SAVETO_BUTTON_Y = 5
|
275 |
+
SAVETO_BUTTON_WIDTH = 0
|
276 |
+
SAVETO_BUTTON_HEIGHT = -5
|
277 |
+
SAVETO_ENTRY_X = 7.5
|
278 |
+
OPEN_BUTTON_X = 427.1
|
279 |
+
OPEN_BUTTON_WIDTH = -427.4
|
280 |
+
SAVETO_ENTRY_WIDTH = -50
|
281 |
+
SAVETO_ENTRY_HEIGHT = -5
|
282 |
+
SAVETO_OPEN_X = -45
|
283 |
+
SAVETO_OPEN_Y = 197.5
|
284 |
+
SAVETO_OPEN_WIDTH = 35
|
285 |
+
SAVETO_OPEN_HEIGHT = 32
|
286 |
+
|
287 |
+
#Main Option menu
|
288 |
+
OPTIONS_FRAME_X = 10
|
289 |
+
OPTIONS_FRAME_Y = 250
|
290 |
+
OPTIONS_FRAME_WIDTH = -20
|
291 |
+
FILEONE_LABEL_X = -28
|
292 |
+
FILEONE_LABEL_WIDTH = -38
|
293 |
+
FILETWO_LABEL_X = -32
|
294 |
+
FILETWO_LABEL_WIDTH = -20
|
295 |
+
TIME_WINDOW_LABEL_X = -43
|
296 |
+
TIME_WINDOW_LABEL_WIDTH = 0
|
297 |
+
INTRO_ANALYSIS_LABEL_X = -83
|
298 |
+
INTRO_ANALYSIS_LABEL_WIDTH = -50
|
299 |
+
INTRO_ANALYSIS_OPTION_X = -68
|
300 |
+
DB_ANALYSIS_LABEL_X = 62
|
301 |
+
DB_ANALYSIS_LABEL_WIDTH = -34
|
302 |
+
DB_ANALYSIS_OPTION_X = 86
|
303 |
+
WAV_TYPE_SET_LABEL_X = -43
|
304 |
+
WAV_TYPE_SET_LABEL_WIDTH = 0
|
305 |
+
ENTRY_WIDTH = 222
|
306 |
+
|
307 |
+
# Constants for the ensemble_listbox_Frame
|
308 |
+
ENSEMBLE_LISTBOX_FRAME_X = -25
|
309 |
+
ENSEMBLE_LISTBOX_FRAME_Y = -20
|
310 |
+
ENSEMBLE_LISTBOX_FRAME_WIDTH = 0
|
311 |
+
ENSEMBLE_LISTBOX_FRAME_HEIGHT = 67
|
312 |
+
|
313 |
+
# Constants for the ensemble_listbox_scroll
|
314 |
+
ENSEMBLE_LISTBOX_SCROLL_X = 195
|
315 |
+
ENSEMBLE_LISTBOX_SCROLL_Y = -20
|
316 |
+
ENSEMBLE_LISTBOX_SCROLL_WIDTH = -48
|
317 |
+
ENSEMBLE_LISTBOX_SCROLL_HEIGHT = 69
|
318 |
+
|
319 |
+
# Constants for Radio Buttons
|
320 |
+
RADIOBUTTON_X_WAV = 457
|
321 |
+
RADIOBUTTON_X_FLAC = 300
|
322 |
+
RADIOBUTTON_X_MP3 = 143
|
323 |
+
RADIOBUTTON_Y = -5
|
324 |
+
RADIOBUTTON_WIDTH = 0
|
325 |
+
RADIOBUTTON_HEIGHT = 6
|
326 |
+
MAIN_ROW_Y_1 = -15
|
327 |
+
MAIN_ROW_Y_2 = -17
|
328 |
+
MAIN_ROW_X_1 = -4
|
329 |
+
MAIN_ROW_X_2 = 21
|
330 |
+
MAIN_ROW_2_Y_1 = -15
|
331 |
+
MAIN_ROW_2_Y_2 = -17
|
332 |
+
MAIN_ROW_2_X_1 = -28
|
333 |
+
MAIN_ROW_2_X_2 = 1
|
334 |
+
LOW_MENU_Y_1 = 18
|
335 |
+
LOW_MENU_Y_2 = 16
|
336 |
+
SUB_ENT_ROW_X = -2
|
337 |
+
MAIN_ROW_WIDTH = -53
|
338 |
+
MAIN_ROW_ALIGN_WIDTH = -86
|
339 |
+
CHECK_BOX_Y = 0
|
340 |
+
CHECK_BOX_X = 20
|
341 |
+
CHECK_BOX_WIDTH = -49
|
342 |
+
CHECK_BOX_HEIGHT = 2
|
343 |
+
LEFT_ROW_WIDTH = -10
|
344 |
+
LABEL_HEIGHT = -5
|
345 |
+
OPTION_HEIGHT = 8
|
346 |
+
LABEL_X_OFFSET = -28
|
347 |
+
LABEL_WIDTH = -38
|
348 |
+
ENTRY_WIDTH = 179.5
|
349 |
+
ENTRY_OPEN_BUTT_WIDTH = -185
|
350 |
+
ENTRY_OPEN_BUTT_X_OFF = 405
|
351 |
+
UPDATE_LABEL_WIDTH = 35 if OPERATING_SYSTEM == 'Linux' else 32
|
352 |
+
|
353 |
+
HEIGHT_CONSOLE_FRAME_1080P = COMMAND_HEIGHT + HO_S
|
354 |
+
LOW_MENU_Y = LOW_MENU_Y_1, LOW_MENU_Y_2
|
355 |
+
MAIN_ROW_Y = MAIN_ROW_Y_1, MAIN_ROW_Y_2
|
356 |
+
MAIN_ROW_X = MAIN_ROW_X_1, MAIN_ROW_X_2
|
357 |
+
MAIN_ROW_2_Y = MAIN_ROW_2_Y_1, MAIN_ROW_2_Y_2
|
358 |
+
MAIN_ROW_2_X = MAIN_ROW_2_X_1, MAIN_ROW_2_X_2
|
359 |
+
|
360 |
+
LABEL_Y = MAIN_ROW_Y[0]
|
361 |
+
ENTRY_Y = MAIN_ROW_Y[1]
|
362 |
+
|
363 |
+
BUTTON_Y_1080P = IMAGE_HEIGHT + FILEPATHS_HEIGHT + OPTIONS_HEIGHT - 8 + PADDING*2
|
364 |
+
HEIGHT_PROGRESSBAR_1080P = PROGRESS_HEIGHT
|
365 |
+
Y_OFFSET_PROGRESS_BAR_1080P = IMAGE_HEIGHT + FILEPATHS_HEIGHT + OPTIONS_HEIGHT + CONVERSIONBUTTON_HEIGHT + COMMAND_HEIGHT + PADDING*4
|
366 |
+
Y_OFFSET_CONSOLE_FRAME_1080P = IMAGE_HEIGHT + FILEPATHS_HEIGHT + OPTIONS_HEIGHT + CONVERSIONBUTTON_HEIGHT + PADDING + X_PROGRESSBAR_1080P
|
367 |
+
|
368 |
+
LABEL_Y_OFFSET = MAIN_ROW_Y[0]
|
369 |
+
ENTRY_X_OFFSET = SUB_ENT_ROW_X
|
370 |
+
ENTRY_Y_OFFSET = MAIN_ROW_Y[1]
|
371 |
+
OPTION_WIDTH = MAIN_ROW_ALIGN_WIDTH
|
gui_data/change_log.txt
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Most Recent Changes:
|
2 |
+
|
3 |
+
~ Fixed Download Center model list issue.
|
4 |
+
~ Fixed audio clip in ensemble mode.
|
5 |
+
~ Fixed output model name issue in ensemble mode.
|
6 |
+
~ Added "Batch Mode" for MDX-Net to increase performance.
|
7 |
+
~ Batch Mode is more memory efficient.
|
8 |
+
~ Batch Mode produces the best output, regardless of batch size.
|
9 |
+
~ Added Batch Mode for VR Architecture.
|
10 |
+
~ Added Mixer Mode for Demucs.
|
11 |
+
~ This option may improve separation for some 4-stem models.
|
12 |
+
|
13 |
+
Fixes & Changes going from UVR v5.4 to v5.5:
|
14 |
+
|
15 |
+
~ The progress bar is now fully synced up with every process in the application.
|
16 |
+
~ Fixed low-resolution icon.
|
17 |
+
~ Added the ability to download models manually if the application can't connect
|
18 |
+
to the internet.
|
19 |
+
~ Drag-n-drop is functional across all os platforms.
|
20 |
+
~ Resolved mp3 tag issue in MacOS version.
|
21 |
+
|
22 |
+
Performance:
|
23 |
+
|
24 |
+
~ Model load times are faster.
|
25 |
+
~ Importing/exporting audio files is faster.
|
26 |
+
|
27 |
+
MacOS M1 Notes:
|
28 |
+
|
29 |
+
~ The GPU Conversion checkbox will enable MPS for GPU acceleration. However,
|
30 |
+
only the VR Architecture models are currently compatible with it.
|
31 |
+
|
32 |
+
New Options:
|
33 |
+
|
34 |
+
~ Select Saved Settings option - Allows the user to save the current settings
|
35 |
+
of the whole application. You can also load a saved setting or reset them to
|
36 |
+
the default.
|
37 |
+
~ Right-click menu - Allows for quick access to important options.
|
38 |
+
~ Help Hints option - When enabled, users can hover over options to see a pop-up
|
39 |
+
text that describes that option. The right-clicking option also allows copying
|
40 |
+
the "Help Hint" text.
|
41 |
+
~ Secondary Model Mode - This option is an expanded version of the "Demucs Model"
|
42 |
+
option that was only available to MDX-Net. Except now, this option is available
|
43 |
+
in all three AI Networks and for any stem. Any model can now be Secondary, and
|
44 |
+
the user can choose the amount of influence it has on the final result.
|
45 |
+
~ Robust caching for ensemble mode, allowing for much faster processing times.
|
46 |
+
~ Clicking the "Input" field will pop up a window allowing the user to review the selected audio inputs. Within this menu, users can:
|
47 |
+
~ Remove inputs.
|
48 |
+
~ Verify inputs.
|
49 |
+
~ Create samples of chosen inputs.
|
50 |
+
~ "Sample Mode" option - Allows the user to process only part of a track to sample
|
51 |
+
settings or a model without running a full conversion.
|
52 |
+
~ The number in the parentheses is the current number of seconds the generated
|
53 |
+
sample will be.
|
54 |
+
~ You can choose the number of seconds to extract from the track in the "Additional
|
55 |
+
Settings" menu.
|
56 |
+
|
57 |
+
VR Architecture:
|
58 |
+
|
59 |
+
~ Ability to toggle "High-End Processing."
|
60 |
+
~ Ability to change the post-processing threshold.
|
61 |
+
~ Support for the latest VR architecture
|
62 |
+
~ Crop Size and Batch Size are specifically for models using the latest
|
63 |
+
architecture only.
|
64 |
+
|
65 |
+
MDX-NET:
|
66 |
+
|
67 |
+
~ Denoise Output option results in cleaner results,
|
68 |
+
but the processing time will be longer. This option has replaced Noise Reduction.
|
69 |
+
~ Spectral Inversion option uses spectral inversion techniques for a
|
70 |
+
cleaner secondary stem result. This option may slow down the audio export process.
|
71 |
+
~ Secondary stem now has the same frequency cut-off as the main stem.
|
72 |
+
|
73 |
+
Demucs:
|
74 |
+
|
75 |
+
~ Demucs v4 models are now supported, including the 6-stem model.
|
76 |
+
~ Ability to combine remaining stems instead of inverting selected stem with the
|
77 |
+
mixture only when a user does not select "All Stems".
|
78 |
+
~ A Pre-process model that allows the user to run an inference through a robust
|
79 |
+
vocal or instrumental model and separate the remaining stems from its generated
|
80 |
+
instrumental mix. This option can significantly reduce vocal bleed in other
|
81 |
+
Demucs-generated non-vocal stems.
|
82 |
+
~ The Pre-process model is intended for Demucs separations for all stems except
|
83 |
+
vocals and instrumentals.
|
84 |
+
|
85 |
+
Ensemble Mode:
|
86 |
+
|
87 |
+
~ Ensemble Mode has been extended to include the following:
|
88 |
+
~ Averaging is a new algorithm that averages the final results.
|
89 |
+
~ Unlimited models in the ensemble.
|
90 |
+
~ Ability to save different ensembles.
|
91 |
+
~ Ability to ensemble outputs for all individual stem types.
|
92 |
+
~ Ability to choose unique ensemble algorithms.
|
93 |
+
~ Ability to ensemble all 4 Demucs stems at once.
|
gui_data/complete_chime.wav
ADDED
Binary file (357 kB). View file
|
|
gui_data/constants.py
ADDED
@@ -0,0 +1,1584 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import platform
|
2 |
+
|
3 |
+
#Platform Details
|
4 |
+
OPERATING_SYSTEM = platform.system()
|
5 |
+
SYSTEM_ARCH = platform.platform()
|
6 |
+
SYSTEM_PROC = platform.processor()
|
7 |
+
ARM = 'arm'
|
8 |
+
|
9 |
+
is_macos = False
|
10 |
+
|
11 |
+
CPU = 'cpu'
|
12 |
+
CUDA_DEVICE = 'cuda'
|
13 |
+
DIRECTML_DEVICE = "privateuseone"
|
14 |
+
|
15 |
+
#MAIN_FONT_NAME = "Century Gothic"
|
16 |
+
OPT_SEPARATOR_SAVE = '─'*25
|
17 |
+
BG_COLOR = '#0e0e0f'
|
18 |
+
FG_COLOR = '#13849f'
|
19 |
+
|
20 |
+
#Model Types
|
21 |
+
VR_ARCH_TYPE = 'VR Arc'
|
22 |
+
MDX_ARCH_TYPE = 'MDX-Net'
|
23 |
+
DEMUCS_ARCH_TYPE = 'Demucs'
|
24 |
+
VR_ARCH_PM = 'VR Architecture'
|
25 |
+
ENSEMBLE_MODE = 'Ensemble Mode'
|
26 |
+
ENSEMBLE_STEM_CHECK = 'Ensemble Stem'
|
27 |
+
SECONDARY_MODEL = 'Secondary Model'
|
28 |
+
DEMUCS_6_STEM_MODEL = 'htdemucs_6s'
|
29 |
+
DEFAULT = "Default"
|
30 |
+
ALIGNMENT_TOOL = 'Alignment Tool Options'
|
31 |
+
|
32 |
+
SINGLE_FILE = 'SINGLE_FILE'
|
33 |
+
MULTIPLE_FILE = 'MULTI_FILE'
|
34 |
+
MAIN_MULTIPLE_FILE = 'MAIN_MULTI_FILE'
|
35 |
+
CHOOSE_EXPORT_FIR = 'CHOOSE_EXPORT_FIR'
|
36 |
+
|
37 |
+
DUAL = "dual"
|
38 |
+
FOUR_STEM = "fourstem"
|
39 |
+
ANY_STEM = "Any Stem"
|
40 |
+
|
41 |
+
DEMUCS_V3_ARCH_TYPE = 'Demucs v3'
|
42 |
+
DEMUCS_V4_ARCH_TYPE = 'Demucs v4'
|
43 |
+
DEMUCS_NEWER_ARCH_TYPES = [DEMUCS_V3_ARCH_TYPE, DEMUCS_V4_ARCH_TYPE]
|
44 |
+
|
45 |
+
DEMUCS_V1 = 'v1'
|
46 |
+
DEMUCS_V2 = 'v2'
|
47 |
+
DEMUCS_V3 = 'v3'
|
48 |
+
DEMUCS_V4 = 'v4'
|
49 |
+
|
50 |
+
DEMUCS_V1_TAG = 'v1 | '
|
51 |
+
DEMUCS_V2_TAG = 'v2 | '
|
52 |
+
DEMUCS_V3_TAG = 'v3 | '
|
53 |
+
DEMUCS_V4_TAG = 'v4 | '
|
54 |
+
DEMUCS_NEWER_TAGS = [DEMUCS_V3_TAG, DEMUCS_V4_TAG]
|
55 |
+
|
56 |
+
DEMUCS_VERSION_MAPPER = {
|
57 |
+
DEMUCS_V1:DEMUCS_V1_TAG,
|
58 |
+
DEMUCS_V2:DEMUCS_V2_TAG,
|
59 |
+
DEMUCS_V3:DEMUCS_V3_TAG,
|
60 |
+
DEMUCS_V4:DEMUCS_V4_TAG}
|
61 |
+
|
62 |
+
#Download Center
|
63 |
+
DOWNLOAD_FAILED = 'Download Failed'
|
64 |
+
DOWNLOAD_STOPPED = 'Download Stopped'
|
65 |
+
DOWNLOAD_COMPLETE = 'Download Complete'
|
66 |
+
DOWNLOAD_UPDATE_COMPLETE = 'Update Download Complete'
|
67 |
+
SETTINGS_MENU_EXIT = 'exit'
|
68 |
+
NO_CONNECTION = 'No Internet Connection'
|
69 |
+
VIP_SELECTION = 'VIP:'
|
70 |
+
DEVELOPER_SELECTION = 'VIP:'
|
71 |
+
NO_NEW_MODELS = 'All Available Models Downloaded'
|
72 |
+
ENSEMBLE_PARTITION = ': '
|
73 |
+
NO_MODEL = 'No Model Selected'
|
74 |
+
CHOOSE_MODEL = 'Choose Model'
|
75 |
+
SINGLE_DOWNLOAD = 'Downloading Item 1/1...'
|
76 |
+
DOWNLOADING_ITEM = 'Downloading Item'
|
77 |
+
FILE_EXISTS = 'File already exists!'
|
78 |
+
DOWNLOADING_UPDATE = 'Downloading Update...'
|
79 |
+
DOWNLOAD_MORE = 'Download More Models'
|
80 |
+
IS_KARAOKEE = "is_karaoke"
|
81 |
+
IS_BV_MODEL = "is_bv_model"
|
82 |
+
IS_BV_MODEL_REBAL = "is_bv_model_rebalanced"
|
83 |
+
INPUT_STEM_NAME = 'Input Stem Name'
|
84 |
+
|
85 |
+
#Menu Options
|
86 |
+
|
87 |
+
AUTO_SELECT = 'Auto'
|
88 |
+
|
89 |
+
#LINKS
|
90 |
+
DOWNLOAD_CHECKS = "https://raw.githubusercontent.com/TRvlvr/application_data/main/filelists/download_checks.json"
|
91 |
+
MDX_MODEL_DATA_LINK = "https://raw.githubusercontent.com/TRvlvr/application_data/main/mdx_model_data/model_data_new.json"
|
92 |
+
VR_MODEL_DATA_LINK = "https://raw.githubusercontent.com/TRvlvr/application_data/main/vr_model_data/model_data_new.json"
|
93 |
+
MDX23_CONFIG_CHECKS = "https://raw.githubusercontent.com/TRvlvr/application_data/main/mdx_model_data/mdx_c_configs/"
|
94 |
+
BULLETIN_CHECK = "https://raw.githubusercontent.com/TRvlvr/application_data/main/bulletin.txt"
|
95 |
+
|
96 |
+
DEMUCS_MODEL_NAME_DATA_LINK = "https://raw.githubusercontent.com/TRvlvr/application_data/main/demucs_model_data/model_name_mapper.json"
|
97 |
+
MDX_MODEL_NAME_DATA_LINK = "https://raw.githubusercontent.com/TRvlvr/application_data/main/mdx_model_data/model_name_mapper.json"
|
98 |
+
|
99 |
+
DONATE_LINK_BMAC = "https://www.buymeacoffee.com/uvr5"
|
100 |
+
DONATE_LINK_PATREON = "https://www.patreon.com/uvr"
|
101 |
+
|
102 |
+
#DOWNLOAD REPOS
|
103 |
+
NORMAL_REPO = "https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/"
|
104 |
+
UPDATE_REPO = "https://github.com/TRvlvr/model_repo/releases/download/uvr_update_patches/"
|
105 |
+
|
106 |
+
UPDATE_MAC_ARM_REPO = "https://github.com/Anjok07/ultimatevocalremovergui/releases/download/v5.6/Ultimate_Vocal_Remover_v5_6_MacOS_arm64.dmg"
|
107 |
+
UPDATE_MAC_X86_64_REPO = "https://github.com/Anjok07/ultimatevocalremovergui/releases/download/v5.6/Ultimate_Vocal_Remover_v5_6_MacOS_x86_64.dmg"
|
108 |
+
UPDATE_LINUX_REPO = "https://github.com/Anjok07/ultimatevocalremovergui#linux-installation"
|
109 |
+
|
110 |
+
ISSUE_LINK = 'https://github.com/Anjok07/ultimatevocalremovergui/issues/new'
|
111 |
+
VIP_REPO = b'\xf3\xc2W\x19\x1foI)\xc2\xa9\xcc\xb67(Z\xf5',\
|
112 |
+
b'gAAAAABjQAIQ-NpNMMxMedpKHHb7ze_nqB05hw0YhbOy3pFzuzDrfqumn8_qvraxEoUpZC5ZXC0gGvfDxFMqyq9VWbYKlA67SUFI_wZB6QoVyGI581vs7kaGfUqlXHIdDS6tQ_U-BfjbEAK9EU_74-R2zXjz8Xzekw=='
|
113 |
+
NO_CODE = 'incorrect_code'
|
114 |
+
|
115 |
+
#Extensions
|
116 |
+
ONNX = '.onnx'
|
117 |
+
CKPT = '.ckpt'
|
118 |
+
CKPT_C = '.ckptc'
|
119 |
+
YAML = '.yaml'
|
120 |
+
PTH = '.pth'
|
121 |
+
TH_EXT = '.th'
|
122 |
+
JSON = '.json'
|
123 |
+
|
124 |
+
#GUI Buttons
|
125 |
+
START_PROCESSING = 'Start Processing'
|
126 |
+
WAIT_PROCESSING = 'Please wait...'
|
127 |
+
STOP_PROCESSING = 'Halting process, please wait...'
|
128 |
+
LOADING_MODELS = 'Loading models...'
|
129 |
+
|
130 |
+
#---Messages and Logs----
|
131 |
+
|
132 |
+
MISSING_MODEL = 'missing'
|
133 |
+
MODEL_PRESENT = 'present'
|
134 |
+
|
135 |
+
ALL_STEMS = 'All Stems'
|
136 |
+
VOCAL_STEM = 'Vocals'
|
137 |
+
INST_STEM = 'Instrumental'
|
138 |
+
OTHER_STEM = 'Other'
|
139 |
+
BASS_STEM = 'Bass'
|
140 |
+
DRUM_STEM = 'Drums'
|
141 |
+
GUITAR_STEM = 'Guitar'
|
142 |
+
PIANO_STEM = 'Piano'
|
143 |
+
SYNTH_STEM = 'Synthesizer'
|
144 |
+
STRINGS_STEM = 'Strings'
|
145 |
+
WOODWINDS_STEM = 'Woodwinds'
|
146 |
+
BRASS_STEM = 'Brass'
|
147 |
+
WIND_INST_STEM = 'Wind Inst'
|
148 |
+
NO_OTHER_STEM = 'No Other'
|
149 |
+
NO_BASS_STEM = 'No Bass'
|
150 |
+
NO_DRUM_STEM = 'No Drums'
|
151 |
+
NO_GUITAR_STEM = 'No Guitar'
|
152 |
+
NO_PIANO_STEM = 'No Piano'
|
153 |
+
NO_SYNTH_STEM = 'No Synthesizer'
|
154 |
+
NO_STRINGS_STEM = 'No Strings'
|
155 |
+
NO_WOODWINDS_STEM = 'No Woodwinds'
|
156 |
+
NO_WIND_INST_STEM = 'No Wind Inst'
|
157 |
+
NO_BRASS_STEM = 'No Brass'
|
158 |
+
PRIMARY_STEM = 'Primary Stem'
|
159 |
+
SECONDARY_STEM = 'Secondary Stem'
|
160 |
+
LEAD_VOCAL_STEM = 'lead_only'
|
161 |
+
BV_VOCAL_STEM = 'backing_only'
|
162 |
+
LEAD_VOCAL_STEM_I = 'with_lead_vocals'
|
163 |
+
BV_VOCAL_STEM_I = 'with_backing_vocals'
|
164 |
+
LEAD_VOCAL_STEM_LABEL = 'Lead Vocals'
|
165 |
+
BV_VOCAL_STEM_LABEL = 'Backing Vocals'
|
166 |
+
|
167 |
+
VOCAL_STEM_ONLY = f'{VOCAL_STEM} Only'
|
168 |
+
INST_STEM_ONLY = f'{INST_STEM} Only'
|
169 |
+
PRIMARY_STEM_ONLY = f'{PRIMARY_STEM} Only'
|
170 |
+
|
171 |
+
IS_SAVE_INST_ONLY = f'save_only_inst'
|
172 |
+
IS_SAVE_VOC_ONLY = f'save_only_voc'
|
173 |
+
|
174 |
+
DEVERB_MAPPER = {'Main Vocals Only':VOCAL_STEM,
|
175 |
+
'Lead Vocals Only':LEAD_VOCAL_STEM_LABEL,
|
176 |
+
'Backing Vocals Only':BV_VOCAL_STEM_LABEL,
|
177 |
+
'All Vocal Types':'ALL'}
|
178 |
+
|
179 |
+
BALANCE_VALUES = [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
|
180 |
+
|
181 |
+
#Other Constants
|
182 |
+
DEMUCS_2_SOURCE = ["instrumental", "vocals"]
|
183 |
+
DEMUCS_4_SOURCE = ["drums", "bass", "other", "vocals"]
|
184 |
+
|
185 |
+
DEMUCS_2_SOURCE_MAPPER = {
|
186 |
+
INST_STEM: 0,
|
187 |
+
VOCAL_STEM: 1}
|
188 |
+
|
189 |
+
DEMUCS_4_SOURCE_MAPPER = {
|
190 |
+
BASS_STEM: 0,
|
191 |
+
DRUM_STEM: 1,
|
192 |
+
OTHER_STEM: 2,
|
193 |
+
VOCAL_STEM: 3}
|
194 |
+
|
195 |
+
DEMUCS_6_SOURCE_MAPPER = {
|
196 |
+
BASS_STEM:0,
|
197 |
+
DRUM_STEM:1,
|
198 |
+
OTHER_STEM:2,
|
199 |
+
VOCAL_STEM:3,
|
200 |
+
GUITAR_STEM:4,
|
201 |
+
PIANO_STEM:5}
|
202 |
+
|
203 |
+
DEMUCS_4_SOURCE_LIST = [BASS_STEM, DRUM_STEM, OTHER_STEM, VOCAL_STEM]
|
204 |
+
DEMUCS_6_SOURCE_LIST = [BASS_STEM, DRUM_STEM, OTHER_STEM, VOCAL_STEM, GUITAR_STEM, PIANO_STEM]
|
205 |
+
|
206 |
+
DEMUCS_UVR_MODEL = 'UVR_Model'
|
207 |
+
|
208 |
+
CHOOSE_STEM_PAIR = 'Choose Stem Pair'
|
209 |
+
|
210 |
+
STEM_SET_MENU = (VOCAL_STEM,
|
211 |
+
INST_STEM,
|
212 |
+
OTHER_STEM,
|
213 |
+
BASS_STEM,
|
214 |
+
DRUM_STEM,
|
215 |
+
GUITAR_STEM,
|
216 |
+
PIANO_STEM,
|
217 |
+
SYNTH_STEM,
|
218 |
+
STRINGS_STEM,
|
219 |
+
WOODWINDS_STEM,
|
220 |
+
BRASS_STEM,
|
221 |
+
WIND_INST_STEM)
|
222 |
+
|
223 |
+
STEM_SET_MENU_ONLY = list(STEM_SET_MENU) + [OPT_SEPARATOR_SAVE, INPUT_STEM_NAME]
|
224 |
+
|
225 |
+
STEM_SET_MENU_2 = (
|
226 |
+
OTHER_STEM,
|
227 |
+
BASS_STEM,
|
228 |
+
DRUM_STEM,
|
229 |
+
GUITAR_STEM,
|
230 |
+
PIANO_STEM,
|
231 |
+
SYNTH_STEM,
|
232 |
+
STRINGS_STEM,
|
233 |
+
WOODWINDS_STEM,
|
234 |
+
BRASS_STEM,
|
235 |
+
WIND_INST_STEM,
|
236 |
+
"Noise",
|
237 |
+
"Reverb")
|
238 |
+
|
239 |
+
STEM_PAIR_MAPPER = {
|
240 |
+
VOCAL_STEM: INST_STEM,
|
241 |
+
INST_STEM: VOCAL_STEM,
|
242 |
+
LEAD_VOCAL_STEM: BV_VOCAL_STEM,
|
243 |
+
BV_VOCAL_STEM: LEAD_VOCAL_STEM,
|
244 |
+
PRIMARY_STEM: SECONDARY_STEM}
|
245 |
+
|
246 |
+
STEM_PAIR_MAPPER_FULL = {
|
247 |
+
VOCAL_STEM: INST_STEM,
|
248 |
+
INST_STEM: VOCAL_STEM,
|
249 |
+
OTHER_STEM: NO_OTHER_STEM,
|
250 |
+
BASS_STEM: NO_BASS_STEM,
|
251 |
+
DRUM_STEM: NO_DRUM_STEM,
|
252 |
+
GUITAR_STEM: NO_GUITAR_STEM,
|
253 |
+
PIANO_STEM: NO_PIANO_STEM,
|
254 |
+
SYNTH_STEM: NO_SYNTH_STEM,
|
255 |
+
STRINGS_STEM: NO_STRINGS_STEM,
|
256 |
+
WOODWINDS_STEM: NO_WOODWINDS_STEM,
|
257 |
+
BRASS_STEM: NO_BRASS_STEM,
|
258 |
+
WIND_INST_STEM: NO_WIND_INST_STEM,
|
259 |
+
NO_OTHER_STEM: OTHER_STEM,
|
260 |
+
NO_BASS_STEM: BASS_STEM,
|
261 |
+
NO_DRUM_STEM: DRUM_STEM,
|
262 |
+
NO_GUITAR_STEM: GUITAR_STEM,
|
263 |
+
NO_PIANO_STEM: PIANO_STEM,
|
264 |
+
NO_SYNTH_STEM: SYNTH_STEM,
|
265 |
+
NO_STRINGS_STEM: STRINGS_STEM,
|
266 |
+
NO_WOODWINDS_STEM: WOODWINDS_STEM,
|
267 |
+
NO_BRASS_STEM: BRASS_STEM,
|
268 |
+
NO_WIND_INST_STEM: WIND_INST_STEM,
|
269 |
+
PRIMARY_STEM: SECONDARY_STEM}
|
270 |
+
|
271 |
+
NO_STEM = "No "
|
272 |
+
|
273 |
+
NON_ACCOM_STEMS = (
|
274 |
+
VOCAL_STEM,
|
275 |
+
OTHER_STEM,
|
276 |
+
BASS_STEM,
|
277 |
+
DRUM_STEM,
|
278 |
+
GUITAR_STEM,
|
279 |
+
PIANO_STEM,
|
280 |
+
SYNTH_STEM,
|
281 |
+
STRINGS_STEM,
|
282 |
+
WOODWINDS_STEM,
|
283 |
+
BRASS_STEM,
|
284 |
+
WIND_INST_STEM)
|
285 |
+
|
286 |
+
MDX_NET_FREQ_CUT = [VOCAL_STEM, INST_STEM]
|
287 |
+
|
288 |
+
DEMUCS_4_STEM_OPTIONS = (ALL_STEMS, VOCAL_STEM, OTHER_STEM, BASS_STEM, DRUM_STEM)
|
289 |
+
DEMUCS_6_STEM_OPTIONS = (ALL_STEMS, VOCAL_STEM, OTHER_STEM, BASS_STEM, DRUM_STEM, GUITAR_STEM, PIANO_STEM)
|
290 |
+
DEMUCS_2_STEM_OPTIONS = (VOCAL_STEM, INST_STEM)
|
291 |
+
DEMUCS_4_STEM_CHECK = (OTHER_STEM, BASS_STEM, DRUM_STEM)
|
292 |
+
|
293 |
+
#Menu Dropdowns
|
294 |
+
|
295 |
+
VOCAL_PAIR = f'{VOCAL_STEM}/{INST_STEM}'
|
296 |
+
INST_PAIR = f'{INST_STEM}/{VOCAL_STEM}'
|
297 |
+
OTHER_PAIR = f'{OTHER_STEM}/{NO_OTHER_STEM}'
|
298 |
+
DRUM_PAIR = f'{DRUM_STEM}/{NO_DRUM_STEM}'
|
299 |
+
BASS_PAIR = f'{BASS_STEM}/{NO_BASS_STEM}'
|
300 |
+
FOUR_STEM_ENSEMBLE = '4 Stem Ensemble'
|
301 |
+
MULTI_STEM_ENSEMBLE = 'Multi-stem Ensemble'
|
302 |
+
|
303 |
+
ENSEMBLE_MAIN_STEM = (CHOOSE_STEM_PAIR, VOCAL_PAIR, OTHER_PAIR, DRUM_PAIR, BASS_PAIR, FOUR_STEM_ENSEMBLE, MULTI_STEM_ENSEMBLE)
|
304 |
+
|
305 |
+
MIN_SPEC = 'Min Spec'
|
306 |
+
MAX_SPEC = 'Max Spec'
|
307 |
+
AUDIO_AVERAGE = 'Average'
|
308 |
+
|
309 |
+
MAX_MIN = f'{MAX_SPEC}/{MIN_SPEC}'
|
310 |
+
MAX_MAX = f'{MAX_SPEC}/{MAX_SPEC}'
|
311 |
+
MAX_AVE = f'{MAX_SPEC}/{AUDIO_AVERAGE}'
|
312 |
+
MIN_MAX = f'{MIN_SPEC}/{MAX_SPEC}'
|
313 |
+
MIN_MIX = f'{MIN_SPEC}/{MIN_SPEC}'
|
314 |
+
MIN_AVE = f'{MIN_SPEC}/{AUDIO_AVERAGE}'
|
315 |
+
AVE_MAX = f'{AUDIO_AVERAGE}/{MAX_SPEC}'
|
316 |
+
AVE_MIN = f'{AUDIO_AVERAGE}/{MIN_SPEC}'
|
317 |
+
AVE_AVE = f'{AUDIO_AVERAGE}/{AUDIO_AVERAGE}'
|
318 |
+
|
319 |
+
ENSEMBLE_TYPE = (MAX_MIN, MAX_MAX, MAX_AVE, MIN_MAX, MIN_MIX, MIN_AVE, AVE_MAX, AVE_MIN, AVE_AVE)
|
320 |
+
ENSEMBLE_TYPE_4_STEM = (MAX_SPEC, MIN_SPEC, AUDIO_AVERAGE)
|
321 |
+
|
322 |
+
BATCH_MODE = 'Batch Mode'
|
323 |
+
BETA_VERSION = 'BETA'
|
324 |
+
DEF_OPT = 'Default'
|
325 |
+
USER_INPUT = "User Input"
|
326 |
+
OPT_SEPARATOR = '─'*65
|
327 |
+
|
328 |
+
CHUNKS = (AUTO_SELECT, '1', '5', '10', '15', '20',
|
329 |
+
'25', '30', '35', '40', '45', '50',
|
330 |
+
'55', '60', '65', '70', '75', '80',
|
331 |
+
'85', '90', '95', 'Full')
|
332 |
+
|
333 |
+
BATCH_SIZE = (DEF_OPT, '2', '3', '4', '5',
|
334 |
+
'6', '7', '8', '9', '10')
|
335 |
+
|
336 |
+
VOL_COMPENSATION = (AUTO_SELECT, '1.035', '1.08')
|
337 |
+
|
338 |
+
MARGIN_SIZE = ('44100', '22050', '11025')
|
339 |
+
|
340 |
+
AUDIO_TOOLS = 'Audio Tools'
|
341 |
+
|
342 |
+
MANUAL_ENSEMBLE = 'Manual Ensemble'
|
343 |
+
TIME_STRETCH = 'Time Stretch'
|
344 |
+
CHANGE_PITCH = 'Change Pitch'
|
345 |
+
ALIGN_INPUTS = 'Align Inputs'
|
346 |
+
MATCH_INPUTS = 'Matchering'
|
347 |
+
COMBINE_INPUTS = 'Combine Inputs'
|
348 |
+
|
349 |
+
if OPERATING_SYSTEM == 'Windows' or OPERATING_SYSTEM == 'Darwin':
|
350 |
+
AUDIO_TOOL_OPTIONS = (MANUAL_ENSEMBLE, TIME_STRETCH, CHANGE_PITCH, ALIGN_INPUTS, MATCH_INPUTS)
|
351 |
+
else:
|
352 |
+
AUDIO_TOOL_OPTIONS = (MANUAL_ENSEMBLE, ALIGN_INPUTS, MATCH_INPUTS)
|
353 |
+
|
354 |
+
MANUAL_ENSEMBLE_OPTIONS = (MIN_SPEC, MAX_SPEC, AUDIO_AVERAGE, COMBINE_INPUTS)
|
355 |
+
|
356 |
+
PROCESS_METHODS = (VR_ARCH_PM, MDX_ARCH_TYPE, DEMUCS_ARCH_TYPE, ENSEMBLE_MODE, AUDIO_TOOLS)
|
357 |
+
|
358 |
+
DEMUCS_SEGMENTS = (DEF_OPT, '1', '5', '10', '15', '20',
|
359 |
+
'25', '30', '35', '40', '45', '50',
|
360 |
+
'55', '60', '65', '70', '75', '80',
|
361 |
+
'85', '90', '95', '100')
|
362 |
+
|
363 |
+
DEMUCS_SHIFTS = (0, 1, 2, 3, 4, 5,
|
364 |
+
6, 7, 8, 9, 10, 11,
|
365 |
+
12, 13, 14, 15, 16, 17,
|
366 |
+
18, 19, 20)
|
367 |
+
SEMI_DEF = ['0']
|
368 |
+
SEMITONE_SEL = (-12,-11,-10,-9,-8,-7,-6,-5,-4,-3,-2,-1,0,1,2,3,4,5,6,7,8,9,10,11,12)
|
369 |
+
|
370 |
+
NOUT_SEL = (8, 16, 32, 48, 64)
|
371 |
+
NOUT_LSTM_SEL = (64, 128)
|
372 |
+
|
373 |
+
DEMUCS_OVERLAP = (0.25, 0.50, 0.75, 0.99)
|
374 |
+
MDX_OVERLAP = (DEF_OPT, 0.25, 0.50, 0.75, 0.99)
|
375 |
+
MDX23_OVERLAP = range(2, 51)
|
376 |
+
VR_AGGRESSION = range(0, 51)
|
377 |
+
|
378 |
+
TIME_WINDOW_MAPPER = {
|
379 |
+
"None": None,
|
380 |
+
"1": [0.0625],
|
381 |
+
"2": [0.125],
|
382 |
+
"3": [0.25],
|
383 |
+
"4": [0.5],
|
384 |
+
"5": [0.75],
|
385 |
+
"6": [1],
|
386 |
+
"7": [2],
|
387 |
+
"Shifts: Low": [0.0625, 0.5],
|
388 |
+
"Shifts: Medium": [0.0625, 0.125, 0.5],
|
389 |
+
"Shifts: High": [0.0625, 0.125, 0.25, 0.5]
|
390 |
+
#"Shifts: Very High": [0.0625, 0.125, 0.25, 0.5, 0.75, 1],
|
391 |
+
}
|
392 |
+
|
393 |
+
INTRO_MAPPER = {
|
394 |
+
"Default": [10],
|
395 |
+
"1": [8],
|
396 |
+
"2": [6],
|
397 |
+
"3": [4],
|
398 |
+
"4": [2],
|
399 |
+
"Shifts: Low": [1, 10],
|
400 |
+
"Shifts: Medium": [1, 10, 8],
|
401 |
+
"Shifts: High": [1, 10, 8, 6, 4]
|
402 |
+
}
|
403 |
+
|
404 |
+
VOLUME_MAPPER = {
|
405 |
+
"None": (0, [0]),
|
406 |
+
"Low": (-4, range(0, 8)),
|
407 |
+
"Medium": (-6, range(0, 12)),
|
408 |
+
"High": (-6, [x * 0.5 for x in range(0, 25)]),
|
409 |
+
"Very High": (-10, [x * 0.5 for x in range(0, 41)])}
|
410 |
+
#"Max": (-10, [x * 0.3 for x in range(0, int(20 / 0.3) + 1)])}
|
411 |
+
|
412 |
+
PHASE_MAPPER = {
|
413 |
+
"None": [0],
|
414 |
+
"Shifts Low": [0, 180],
|
415 |
+
"Shifts Medium": [0],
|
416 |
+
"Shifts High": [0],
|
417 |
+
"Shifts Very High": [0],}
|
418 |
+
|
419 |
+
NONE_P = "None"
|
420 |
+
VLOW_P = "Shifts: Very Low"
|
421 |
+
LOW_P = "Shifts: Low"
|
422 |
+
MED_P = "Shifts: Medium"
|
423 |
+
HIGH_P = "Shifts: High"
|
424 |
+
VHIGH_P = "Shifts: Very High"
|
425 |
+
VMAX_P = "Shifts: Maximum"
|
426 |
+
|
427 |
+
PHASE_SHIFTS_OPT = {
|
428 |
+
NONE_P:190,
|
429 |
+
VLOW_P:180,
|
430 |
+
LOW_P:90,
|
431 |
+
MED_P:45,
|
432 |
+
HIGH_P:20,
|
433 |
+
VHIGH_P:10,
|
434 |
+
VMAX_P:1,}
|
435 |
+
|
436 |
+
VR_WINDOW = ('320', '512','1024')
|
437 |
+
VR_CROP = ('256', '512', '1024')
|
438 |
+
POST_PROCESSES_THREASHOLD_VALUES = ('0.1', '0.2', '0.3')
|
439 |
+
|
440 |
+
MDX_POP_PRO = ('MDX-NET_Noise_Profile_14_kHz', 'MDX-NET_Noise_Profile_17_kHz', 'MDX-NET_Noise_Profile_Full_Band')
|
441 |
+
MDX_POP_STEMS = ('Vocals', 'Instrumental', 'Other', 'Drums', 'Bass')
|
442 |
+
MDX_POP_NFFT = ('4096', '5120', '6144', '7680', '8192', '16384')
|
443 |
+
MDX_POP_DIMF = ('2048', '3072', '4096')
|
444 |
+
DENOISE_NONE, DENOISE_S, DENOISE_M = 'None', 'Standard', 'Denoise Model'
|
445 |
+
MDX_DENOISE_OPTION = [DENOISE_NONE, DENOISE_S, DENOISE_M]
|
446 |
+
MDX_SEGMENTS = list(range(32, 4000+1, 32))
|
447 |
+
|
448 |
+
SAVE_ENSEMBLE = 'Save Ensemble'
|
449 |
+
CLEAR_ENSEMBLE = 'Clear Selection(s)'
|
450 |
+
MENU_SEPARATOR = 35*'•'
|
451 |
+
CHOOSE_ENSEMBLE_OPTION = 'Choose Option'
|
452 |
+
ALL_TYPES = 'ALL'
|
453 |
+
INVALID_ENTRY = 'Invalid Input, Please Try Again'
|
454 |
+
ENSEMBLE_INPUT_RULE = '1. Only letters, numbers, spaces, and dashes allowed.\n2. No dashes or spaces at the start or end of input.'
|
455 |
+
STEM_INPUT_RULE = '1. Only words with no spaces are allowed.\n2. No spaces, numbers, or special characters.'
|
456 |
+
|
457 |
+
ENSEMBLE_OPTIONS = [OPT_SEPARATOR_SAVE, SAVE_ENSEMBLE, CLEAR_ENSEMBLE]
|
458 |
+
ENSEMBLE_CHECK = 'ensemble check'
|
459 |
+
KARAOKEE_CHECK = 'kara check'
|
460 |
+
|
461 |
+
AUTO_PHASE = "Automatic"
|
462 |
+
POSITIVE_PHASE = "Positive Phase"
|
463 |
+
NEGATIVE_PHASE = "Negative Phase"
|
464 |
+
OFF_PHASE = "Native Phase"
|
465 |
+
|
466 |
+
ALIGN_PHASE_OPTIONS = [AUTO_PHASE, POSITIVE_PHASE, NEGATIVE_PHASE, OFF_PHASE]
|
467 |
+
|
468 |
+
SELECT_SAVED_ENSEMBLE = 'Select Saved Ensemble'
|
469 |
+
SELECT_SAVED_SETTING = 'Select Saved Setting'
|
470 |
+
ENSEMBLE_OPTION = "Ensemble Customization Options"
|
471 |
+
MDX_OPTION = "Advanced MDX-Net Options"
|
472 |
+
DEMUCS_OPTION = "Advanced Demucs Options"
|
473 |
+
VR_OPTION = "Advanced VR Options"
|
474 |
+
HELP_OPTION = "Open Information Guide"
|
475 |
+
ERROR_OPTION = "Open Error Log"
|
476 |
+
VERIFY_BEGIN = 'Verifying file '
|
477 |
+
SAMPLE_BEGIN = 'Creating Sample '
|
478 |
+
MODEL_MISSING_CHECK = 'Model Missing:'
|
479 |
+
OPTION_LIST = [VR_OPTION, MDX_OPTION, DEMUCS_OPTION, ENSEMBLE_OPTION, ALIGNMENT_TOOL, HELP_OPTION, ERROR_OPTION]
|
480 |
+
|
481 |
+
#Menu Strings
|
482 |
+
VR_MENU ='VR Menu'
|
483 |
+
DEMUCS_MENU ='Demucs Menu'
|
484 |
+
MDX_MENU ='MDX-Net Menu'
|
485 |
+
ENSEMBLE_MENU ='Ensemble Menu'
|
486 |
+
HELP_MENU ='Help Menu'
|
487 |
+
ERROR_MENU ='Error Log'
|
488 |
+
INPUTS_MENU ='Inputs Menu'
|
489 |
+
ALIGN_MENU ='Align Menu'
|
490 |
+
|
491 |
+
# Audio Player
|
492 |
+
PLAYING_SONG = ": Playing"
|
493 |
+
PAUSE_SONG = ": Paused"
|
494 |
+
STOP_SONG = ": Stopped"
|
495 |
+
|
496 |
+
SELECTED_VER = 'Selected'
|
497 |
+
DETECTED_VER = 'Detected'
|
498 |
+
|
499 |
+
SAMPLE_MODE_CHECKBOX = lambda v:f'Sample Mode ({v}s)'
|
500 |
+
REMOVED_FILES = lambda r, e:f'Audio Input Verification Report:\n\nRemoved Files:\n\n{r}\n\nError Details:\n\n{e}'
|
501 |
+
ADVANCED_SETTINGS = (ENSEMBLE_OPTION, MDX_OPTION, DEMUCS_OPTION, VR_OPTION, HELP_OPTION, ERROR_OPTION)
|
502 |
+
|
503 |
+
WAV = 'WAV'
|
504 |
+
FLAC = 'FLAC'
|
505 |
+
MP3 = 'MP3'
|
506 |
+
|
507 |
+
MP3_BIT_RATES = ('96k', '128k', '160k', '224k', '256k', '320k')
|
508 |
+
WAV_TYPE = ('PCM_U8', 'PCM_16', 'PCM_24', 'PCM_32', '32-bit Float', '64-bit Float')
|
509 |
+
GPU_DEVICE_NUM_OPTS = (DEFAULT, '0', '1', '2', '3', '4', '5', '6', '7', '8')
|
510 |
+
|
511 |
+
SELECT_SAVED_SET = 'Choose Option'
|
512 |
+
SAVE_SETTINGS = 'Save Current Settings'
|
513 |
+
RESET_TO_DEFAULT = 'Reset to Default'
|
514 |
+
RESET_FULL_TO_DEFAULT = 'Reset to Default'
|
515 |
+
RESET_PM_TO_DEFAULT = 'Reset All Application Settings to Default'
|
516 |
+
|
517 |
+
SAVE_SET_OPTIONS = [OPT_SEPARATOR_SAVE, SAVE_SETTINGS, RESET_TO_DEFAULT]
|
518 |
+
|
519 |
+
TIME_PITCH = ('1.0', '2.0', '3.0', '4.0')
|
520 |
+
TIME_TEXT = '_time_stretched'
|
521 |
+
PITCH_TEXT = '_pitch_shifted'
|
522 |
+
|
523 |
+
#RegEx Input Validation
|
524 |
+
REG_PITCH = r'^[-+]?(1[0]|[0-9]([.][0-9]*)?)$'
|
525 |
+
REG_TIME = r'^[+]?(1[0]|[0-9]([.][0-9]*)?)$'
|
526 |
+
REG_COMPENSATION = r'\b^(1[0]|[0-9]([.][0-9]*)?|Auto|None)$\b'
|
527 |
+
REG_THES_POSTPORCESS = r'\b^([0]([.][0-9]{0,6})?)$\b'
|
528 |
+
REG_CHUNKS = r'\b^(200|1[0-9][0-9]|[1-9][0-9]?|Auto|Full)$\b'
|
529 |
+
REG_CHUNKS_DEMUCS = r'\b^(200|1[0-9][0-9]|[1-9][0-9]?|Auto|Full)$\b'
|
530 |
+
REG_MARGIN = r'\b^[0-9]*$\b'
|
531 |
+
REG_SEGMENTS = r'\b^(200|1[0-9][0-9]|[1-9][0-9]?|Default)$\b'
|
532 |
+
REG_SAVE_INPUT = r'\b^([a-zA-Z0-9 -]{0,25})$\b'
|
533 |
+
REG_INPUT_STEM_NAME = r'^(Wind Inst|[a-zA-Z]{1,25})$'
|
534 |
+
REG_SEMITONES = r'^-?(20\.00|[01]?\d(\.\d{1,2})?|20)$'
|
535 |
+
REG_AGGRESSION = r'^[-+]?[0-9]\d*?$'
|
536 |
+
REG_WINDOW = r'\b^[0-9]{0,4}$\b'
|
537 |
+
REG_SHIFTS = r'\b^[0-9]*$\b'
|
538 |
+
REG_BATCHES = r'\b^([0-9]*?|Default)$\b'
|
539 |
+
REG_OVERLAP = r'\b^([0]([.][0-9]{0,6})?|Default)$\b'#r"(Default|[0-9]+(\.[0-9]+)?)"#
|
540 |
+
REG_OVERLAP23 = r'\b^([1][0-9]|[2-9][0-9]*|Default)$\b'#r'\b^([2-9][0-9]*?|Default)$\b'
|
541 |
+
REG_MDX_SEG = r'\b(?:' + '|'.join([str(num) for num in range(32, 1000001, 32)]) + r')\b'
|
542 |
+
REG_ALIGN = r'^[-+]?[0-9]\d*?$'
|
543 |
+
REG_VOL_COMP = r'^\d+\.\d{1,9}$'
|
544 |
+
|
545 |
+
# Sub Menu
|
546 |
+
VR_ARCH_SETTING_LOAD = 'Load for VR Arch'
|
547 |
+
MDX_SETTING_LOAD = 'Load for MDX-Net'
|
548 |
+
DEMUCS_SETTING_LOAD = 'Load for Demucs'
|
549 |
+
ALL_ARCH_SETTING_LOAD = 'Load for Full Application'
|
550 |
+
|
551 |
+
# Mappers
|
552 |
+
|
553 |
+
DEFAULT_DATA = {
|
554 |
+
'chosen_process_method': MDX_ARCH_TYPE,
|
555 |
+
'vr_model': CHOOSE_MODEL,
|
556 |
+
'aggression_setting': 5,
|
557 |
+
'window_size': 512,
|
558 |
+
'mdx_segment_size': 256,
|
559 |
+
'batch_size': DEF_OPT,
|
560 |
+
'crop_size': 256,
|
561 |
+
'is_tta': False,
|
562 |
+
'is_output_image': False,
|
563 |
+
'is_post_process': False,
|
564 |
+
'is_high_end_process': False,
|
565 |
+
'post_process_threshold': 0.2,
|
566 |
+
'vr_voc_inst_secondary_model': NO_MODEL,
|
567 |
+
'vr_other_secondary_model': NO_MODEL,
|
568 |
+
'vr_bass_secondary_model': NO_MODEL,
|
569 |
+
'vr_drums_secondary_model': NO_MODEL,
|
570 |
+
'vr_is_secondary_model_activate': False,
|
571 |
+
'vr_voc_inst_secondary_model_scale': 0.9,
|
572 |
+
'vr_other_secondary_model_scale': 0.7,
|
573 |
+
'vr_bass_secondary_model_scale': 0.5,
|
574 |
+
'vr_drums_secondary_model_scale': 0.5,
|
575 |
+
'demucs_model': CHOOSE_MODEL,
|
576 |
+
'segment': DEMUCS_SEGMENTS[0],
|
577 |
+
'overlap': DEMUCS_OVERLAP[0],
|
578 |
+
'overlap_mdx': MDX_OVERLAP[0],
|
579 |
+
'overlap_mdx23': '8',
|
580 |
+
'shifts': 2,
|
581 |
+
'chunks_demucs': CHUNKS[0],
|
582 |
+
'margin_demucs': 44100,
|
583 |
+
'is_chunk_demucs': False,
|
584 |
+
'is_chunk_mdxnet': False,
|
585 |
+
'is_primary_stem_only_Demucs': False,
|
586 |
+
'is_secondary_stem_only_Demucs': False,
|
587 |
+
'is_split_mode': True,
|
588 |
+
'is_demucs_combine_stems': True,#
|
589 |
+
'is_mdx23_combine_stems': True,#
|
590 |
+
'demucs_voc_inst_secondary_model': NO_MODEL,
|
591 |
+
'demucs_other_secondary_model': NO_MODEL,
|
592 |
+
'demucs_bass_secondary_model': NO_MODEL,
|
593 |
+
'demucs_drums_secondary_model': NO_MODEL,
|
594 |
+
'demucs_is_secondary_model_activate': False,
|
595 |
+
'demucs_voc_inst_secondary_model_scale': 0.9,
|
596 |
+
'demucs_other_secondary_model_scale': 0.7,
|
597 |
+
'demucs_bass_secondary_model_scale': 0.5,
|
598 |
+
'demucs_drums_secondary_model_scale': 0.5,
|
599 |
+
'demucs_stems': ALL_STEMS,
|
600 |
+
'demucs_pre_proc_model': NO_MODEL,
|
601 |
+
'is_demucs_pre_proc_model_activate': False,
|
602 |
+
'is_demucs_pre_proc_model_inst_mix': False,
|
603 |
+
'mdx_net_model': CHOOSE_MODEL,
|
604 |
+
'chunks': CHUNKS[0],
|
605 |
+
'margin': 44100,
|
606 |
+
'compensate': AUTO_SELECT,
|
607 |
+
'is_denoise': False,#
|
608 |
+
'denoise_option': 'None',#
|
609 |
+
'phase_option': AUTO_PHASE,
|
610 |
+
'phase_shifts': NONE_P,#
|
611 |
+
'is_save_align': False,#,
|
612 |
+
'is_match_frequency_pitch': True,#
|
613 |
+
'is_match_silence': True,#
|
614 |
+
'is_spec_match': False,#
|
615 |
+
'is_mdx_c_seg_def': False,
|
616 |
+
'is_invert_spec': False, #
|
617 |
+
'is_deverb_vocals': False, #
|
618 |
+
'deverb_vocal_opt': 'Main Vocals Only', #
|
619 |
+
'voc_split_save_opt': 'Lead Only', #
|
620 |
+
'is_mixer_mode': False,
|
621 |
+
'mdx_batch_size': DEF_OPT,
|
622 |
+
'mdx_voc_inst_secondary_model': NO_MODEL,
|
623 |
+
'mdx_other_secondary_model': NO_MODEL,
|
624 |
+
'mdx_bass_secondary_model': NO_MODEL,
|
625 |
+
'mdx_drums_secondary_model': NO_MODEL,
|
626 |
+
'mdx_is_secondary_model_activate': False,
|
627 |
+
'mdx_voc_inst_secondary_model_scale': 0.9,
|
628 |
+
'mdx_other_secondary_model_scale': 0.7,
|
629 |
+
'mdx_bass_secondary_model_scale': 0.5,
|
630 |
+
'mdx_drums_secondary_model_scale': 0.5,
|
631 |
+
'mdx_stems': ALL_STEMS,
|
632 |
+
'is_save_all_outputs_ensemble': True,
|
633 |
+
'is_append_ensemble_name': False,
|
634 |
+
'chosen_audio_tool': AUDIO_TOOL_OPTIONS[0],
|
635 |
+
'choose_algorithm': MANUAL_ENSEMBLE_OPTIONS[0],
|
636 |
+
'time_stretch_rate': 2.0,
|
637 |
+
'pitch_rate': 2.0,
|
638 |
+
'is_time_correction': True,
|
639 |
+
'is_gpu_conversion': False,
|
640 |
+
'is_primary_stem_only': False,
|
641 |
+
'is_secondary_stem_only': False,
|
642 |
+
'is_testing_audio': False,#
|
643 |
+
'is_auto_update_model_params': True,#
|
644 |
+
'is_add_model_name': False,
|
645 |
+
'is_accept_any_input': False,
|
646 |
+
'is_task_complete': False,
|
647 |
+
'is_normalization': False,
|
648 |
+
'is_use_opencl': False,
|
649 |
+
'is_wav_ensemble': False,
|
650 |
+
'is_create_model_folder': False,
|
651 |
+
'mp3_bit_set': '320k',#
|
652 |
+
'semitone_shift': '0',#
|
653 |
+
'save_format': WAV,
|
654 |
+
'wav_type_set': 'PCM_16',
|
655 |
+
'device_set': DEFAULT,
|
656 |
+
'user_code': '',
|
657 |
+
'export_path': '',
|
658 |
+
'input_paths': [],
|
659 |
+
'lastDir': None,
|
660 |
+
'time_window': "3",
|
661 |
+
'intro_analysis': DEFAULT,
|
662 |
+
'db_analysis': "Medium",
|
663 |
+
'fileOneEntry': '',
|
664 |
+
'fileOneEntry_Full': '',
|
665 |
+
'fileTwoEntry': '',
|
666 |
+
'fileTwoEntry_Full': '',
|
667 |
+
'DualBatch_inputPaths': [],
|
668 |
+
'model_hash_table': {},
|
669 |
+
'help_hints_var': True,
|
670 |
+
'set_vocal_splitter': NO_MODEL,
|
671 |
+
'is_set_vocal_splitter': False,#
|
672 |
+
'is_save_inst_set_vocal_splitter': False,#
|
673 |
+
'model_sample_mode': False,
|
674 |
+
'model_sample_mode_duration': 30
|
675 |
+
}
|
676 |
+
|
677 |
+
SETTING_CHECK = ('vr_model',
|
678 |
+
'aggression_setting',
|
679 |
+
'window_size',
|
680 |
+
'mdx_segment_size',
|
681 |
+
'batch_size',
|
682 |
+
'crop_size',
|
683 |
+
'is_tta',
|
684 |
+
'is_output_image',
|
685 |
+
'is_post_process',
|
686 |
+
'is_high_end_process',
|
687 |
+
'post_process_threshold',
|
688 |
+
'vr_voc_inst_secondary_model',
|
689 |
+
'vr_other_secondary_model',
|
690 |
+
'vr_bass_secondary_model',
|
691 |
+
'vr_drums_secondary_model',
|
692 |
+
'vr_is_secondary_model_activate',
|
693 |
+
'vr_voc_inst_secondary_model_scale',
|
694 |
+
'vr_other_secondary_model_scale',
|
695 |
+
'vr_bass_secondary_model_scale',
|
696 |
+
'vr_drums_secondary_model_scale',
|
697 |
+
'demucs_model',
|
698 |
+
'segment',
|
699 |
+
'overlap',
|
700 |
+
'overlap_mdx',
|
701 |
+
'shifts',
|
702 |
+
'chunks_demucs',
|
703 |
+
'margin_demucs',
|
704 |
+
'is_chunk_demucs',
|
705 |
+
'is_primary_stem_only_Demucs',
|
706 |
+
'is_secondary_stem_only_Demucs',
|
707 |
+
'is_split_mode',
|
708 |
+
'is_demucs_combine_stems',#
|
709 |
+
'is_mdx23_combine_stems',#
|
710 |
+
'demucs_voc_inst_secondary_model',
|
711 |
+
'demucs_other_secondary_model',
|
712 |
+
'demucs_bass_secondary_model',
|
713 |
+
'demucs_drums_secondary_model',
|
714 |
+
'demucs_is_secondary_model_activate',
|
715 |
+
'demucs_voc_inst_secondary_model_scale',
|
716 |
+
'demucs_other_secondary_model_scale',
|
717 |
+
'demucs_bass_secondary_model_scale',
|
718 |
+
'demucs_drums_secondary_model_scale',
|
719 |
+
'demucs_stems',
|
720 |
+
'mdx_net_model',
|
721 |
+
'chunks',
|
722 |
+
'margin',
|
723 |
+
'compensate',
|
724 |
+
'is_denoise',#
|
725 |
+
'denoise_option',#
|
726 |
+
'phase_option',#
|
727 |
+
'phase_shifts',#
|
728 |
+
'is_save_align',#,
|
729 |
+
'is_match_silence',
|
730 |
+
'is_spec_match',#,
|
731 |
+
'is_match_frequency_pitch',#
|
732 |
+
'is_mdx_c_seg_def',
|
733 |
+
'is_invert_spec',#
|
734 |
+
'is_deverb_vocals',#
|
735 |
+
'deverb_vocal_opt',#
|
736 |
+
'voc_split_save_opt',#
|
737 |
+
'mdx_batch_size',
|
738 |
+
'mdx_voc_inst_secondary_model',
|
739 |
+
'mdx_other_secondary_model',
|
740 |
+
'mdx_bass_secondary_model',
|
741 |
+
'mdx_drums_secondary_model',
|
742 |
+
'mdx_is_secondary_model_activate',
|
743 |
+
'mdx_voc_inst_secondary_model_scale',
|
744 |
+
'mdx_other_secondary_model_scale',
|
745 |
+
'mdx_bass_secondary_model_scale',
|
746 |
+
'mdx_drums_secondary_model_scale',
|
747 |
+
'is_save_all_outputs_ensemble',
|
748 |
+
'is_append_ensemble_name',
|
749 |
+
'chosen_audio_tool',
|
750 |
+
'choose_algorithm',
|
751 |
+
'time_stretch_rate',
|
752 |
+
'pitch_rate',
|
753 |
+
'is_time_correction',
|
754 |
+
'is_primary_stem_only',
|
755 |
+
'is_secondary_stem_only',
|
756 |
+
'is_testing_audio',#
|
757 |
+
'is_auto_update_model_params',#
|
758 |
+
'is_add_model_name',
|
759 |
+
"is_accept_any_input",
|
760 |
+
'is_task_complete',
|
761 |
+
'is_create_model_folder',
|
762 |
+
'mp3_bit_set',#
|
763 |
+
'semitone_shift',#
|
764 |
+
'save_format',
|
765 |
+
'wav_type_set',
|
766 |
+
'device_set',
|
767 |
+
'user_code',
|
768 |
+
'is_gpu_conversion',
|
769 |
+
'is_normalization',
|
770 |
+
'is_use_opencl',
|
771 |
+
'is_wav_ensemble',
|
772 |
+
'help_hints_var',
|
773 |
+
'set_vocal_splitter',
|
774 |
+
'is_set_vocal_splitter',#
|
775 |
+
'is_save_inst_set_vocal_splitter',#
|
776 |
+
'model_sample_mode',
|
777 |
+
'model_sample_mode_duration',
|
778 |
+
'time_window',
|
779 |
+
'intro_analysis',
|
780 |
+
'db_analysis',
|
781 |
+
'fileOneEntry',
|
782 |
+
'fileOneEntry_Full',
|
783 |
+
'fileTwoEntry',
|
784 |
+
'fileTwoEntry_Full',
|
785 |
+
'DualBatch_inputPaths'
|
786 |
+
)
|
787 |
+
|
788 |
+
NEW_LINES = "\n\n"
|
789 |
+
NEW_LINE = "\n"
|
790 |
+
NO_LINE = ''
|
791 |
+
|
792 |
+
FFMPEG_EXT = (".aac", ".aiff", ".alac" ,".flac", ".FLAC", ".mov", ".mp4", ".MP4",
|
793 |
+
".m4a", ".M4A", ".mp2", ".mp3", "MP3", ".mpc", ".mpc8",
|
794 |
+
".mpeg", ".ogg", ".OGG", ".tta", ".wav", ".wave", ".WAV", ".WAVE", ".wma", ".webm", ".eac3", ".mkv", ".opus", ".OPUS")
|
795 |
+
|
796 |
+
FFMPEG_MORE_EXT = (".aa", ".aac", ".ac3", ".aiff", ".alac", ".avi", ".f4v",".flac", ".flic", ".flv",
|
797 |
+
".m4v",".mlv", ".mov", ".mp4", ".m4a", ".mp2", ".mp3", ".mp4", ".mpc", ".mpc8",
|
798 |
+
".mpeg", ".ogg", ".tta", ".tty", ".vcd", ".wav", ".wma")
|
799 |
+
ANY_EXT = ""
|
800 |
+
|
801 |
+
# Secondary Menu Constants
|
802 |
+
|
803 |
+
VOCAL_PAIR_PLACEMENT = 1, 2, 3, 4
|
804 |
+
OTHER_PAIR_PLACEMENT = 5, 6, 7, 8
|
805 |
+
BASS_PAIR_PLACEMENT = 9, 10, 11, 12
|
806 |
+
DRUMS_PAIR_PLACEMENT = 13, 14, 15, 16
|
807 |
+
|
808 |
+
# Drag n Drop String Checks
|
809 |
+
|
810 |
+
DOUBLE_BRACKET = "} {"
|
811 |
+
RIGHT_BRACKET = "}"
|
812 |
+
LEFT_BRACKET = "{"
|
813 |
+
#DND CONSTS
|
814 |
+
|
815 |
+
MAC_DND_CHECK = ('/Users/',
|
816 |
+
'/Applications/',
|
817 |
+
'/Library/',
|
818 |
+
'/System/')
|
819 |
+
LINUX_DND_CHECK = ('/home/',
|
820 |
+
'/usr/')
|
821 |
+
WINDOWS_DND_CHECK = ('A:', 'B:', 'C:', 'D:', 'E:', 'F:', 'G:', 'H:', 'I:', 'J:', 'K:', 'L:', 'M:', 'N:', 'O:', 'P:', 'Q:', 'R:', 'S:', 'T:', 'U:', 'V:', 'W:', 'X:', 'Y:', 'Z:')
|
822 |
+
|
823 |
+
WOOD_INST_MODEL_HASH = '0ec76fd9e65f81d8b4fbd13af4826ed8'
|
824 |
+
WOOD_INST_PARAMS = {
|
825 |
+
"vr_model_param": "4band_v3",
|
826 |
+
"primary_stem": NO_WIND_INST_STEM
|
827 |
+
}
|
828 |
+
|
829 |
+
READ_ONLY = 'readonly'
|
830 |
+
|
831 |
+
FILE_1 = 'file1'
|
832 |
+
FILE_2 = 'file2'
|
833 |
+
|
834 |
+
FILE_1_LB = 'file1_lb'
|
835 |
+
FILE_2_LB = 'file1_2b'
|
836 |
+
BATCH_MODE_DUAL = " : Batch Mode"
|
837 |
+
|
838 |
+
CODEC_DICT = {
|
839 |
+
'PCM_U8': {"sample_width": 1, "codec": None}, # 8-bit unsigned PCM
|
840 |
+
'PCM_16': {"sample_width": 2, "codec": None}, # 16-bit signed PCM
|
841 |
+
'PCM_24': {"sample_width": 3, "codec": None}, # 24-bit signed PCM
|
842 |
+
'PCM_32': {"sample_width": 4, "codec": None}, # 32-bit signed PCM
|
843 |
+
'FLOAT32': {"sample_width": None, "codec": "pcm_f32le"}, # 32-bit float
|
844 |
+
+    'FLOAT64': {"sample_width": None, "codec": "pcm_f64le"} # 64-bit float
+}
+
+
+# Manual Downloads
+VR_PLACEMENT_TEXT = 'Place models in \"models/VR_Models\" directory.'
+MDX_PLACEMENT_TEXT = 'Place models in \"models/MDX_Net_Models\" directory.'
+DEMUCS_PLACEMENT_TEXT = 'Place models in \"models/Demucs_Models\" directory.'
+DEMUCS_V3_V4_PLACEMENT_TEXT = 'Place items in \"models/Demucs_Models/v3_v4_repo\" directory.'
+MDX_23_NAME = "MDX23C Model"
+
+# License info
+if OPERATING_SYSTEM=="Darwin":
+    is_macos = True
+    LICENSE_OS_SPECIFIC_TEXT = '• This application is intended for those running macOS Catalina and above.\n' +\
+        '• Application functionality for systems running macOS Mojave or lower is not guaranteed.\n' +\
+        '• Application functionality for older or budget Mac systems is not guaranteed.\n\n'
+elif OPERATING_SYSTEM=="Linux":
+    LICENSE_OS_SPECIFIC_TEXT = '• This application is intended for those running Linux Ubuntu 18.04+.\n' +\
+        '• Application functionality for systems running other Linux platforms is not guaranteed.\n' +\
+        '• Application functionality for older or budget systems is not guaranteed.\n\n'
+elif OPERATING_SYSTEM=="Windows":
+    LICENSE_OS_SPECIFIC_TEXT = '• This application is intended for those running Windows 10 or higher.\n' +\
+        '• Application functionality for systems running Windows 7 or lower is not guaranteed.\n' +\
+        '• Application functionality for Intel Pentium & Celeron CPUs systems is not guaranteed.\n\n'
+
+LICENSE_TEXT = lambda a, p:f'Current Application Version: Ultimate Vocal Remover {a}\n' +\
+    f'Current Patch Version: {p}\n\n' +\
+    'Copyright (c) 2022 Ultimate Vocal Remover\n\n' +\
+    'UVR is free and open-source, but MIT licensed. Please credit us if you use our\n' +\
+    f'models or code for projects unrelated to UVR.\n\n{LICENSE_OS_SPECIFIC_TEXT}' +\
+    'This bundle contains the UVR interface, Python, PyTorch, and other\n' +\
+    'dependencies needed to run the application effectively.\n\n' +\
+    'Website Links: This application, System or Service(s) may contain links to\n' +\
+    'other websites and downloads, and they are solely provided to you as an\n' +\
+    'additional convenience. You understand and acknowledge that by clicking\n' +\
+    'or activating such links you are accessing a site or service outside of\n' +\
+    'this application, and that we do not screen, review, approve, or otherwise\n' +\
+    'endorse any content or information contained in these linked websites.\n' +\
+    'You acknowledge and agree that we, our affiliates and partners are not\n' +\
+    'responsible for the contents of any of these linked websites, including\n' +\
+    'the accuracy or availability of information provided by the linked websites,\n' +\
+    'and we make no representations or warranties regarding your use of\n' +\
+    'the linked websites.\n\n' +\
+    'This application is MIT Licensed\n\n' +\
+    'Permission is hereby granted, free of charge, to any person obtaining a copy\n' +\
+    'of this software and associated documentation files (the "Software"), to deal\n' +\
+    'in the Software without restriction, including without limitation the rights\n' +\
+    'to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n' +\
+    'copies of the Software, and to permit persons to whom the Software is\n' +\
+    'furnished to do so, subject to the following conditions:\n\n' +\
+    'The above copyright notice and this permission notice shall be included in all\n' +\
+    'copies or substantial portions of the Software.\n\n' +\
+    'THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n' +\
+    'IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n' +\
+    'FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n' +\
+    'AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n' +\
+    'LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n' +\
+    'OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\n' +\
+    'SOFTWARE.'
+
+# Message Box Text
+INVALID_INPUT = 'Invalid Input', 'The input is invalid.\n\nPlease verify the input still exists or is valid and try again.'
+INVALID_EXPORT = 'Invalid Export Directory', 'You have selected an invalid export directory.\n\nPlease make sure the selected directory still exists.'
+INVALID_ENSEMBLE = 'Not Enough Models', 'You must select 2 or more models to run ensemble.'
+INVALID_MODEL = 'No Model Chosen', 'You must select a model to continue.'
+MISSING_MODEL = 'Model Missing', 'The selected model is missing or not valid.'
+ERROR_OCCURED = 'Error Occurred', '\n\nWould you like to open the error log for more details?\n'
+PROCESS_COMPLETE = '\nProcess complete\n'
+PROCESS_COMPLETE_2 = 'Process complete\n'
+
+# GUI Text Constants
+BACK_TO_MAIN_MENU = 'Back to Main Menu'
+
+# Help Hint Text
+INTERNAL_MODEL_ATT = 'This is an internal model setting. \n\n***Avoid changing it unless you\'re certain about it!***'
+STOP_HELP = 'Stops ongoing tasks.\n• A confirmation pop-up will appear before stopping.'
+SETTINGS_HELP = 'Accesses the main settings and the "Download Center."'
+COMMAND_TEXT_HELP = 'Shows the status and progress of ongoing tasks.'
+SAVE_CURRENT_SETTINGS_HELP = 'Load or save the app\'s settings.'
+PITCH_SHIFT_HELP = ('Choose the pitch for processing tracks:\n\n'
+    '• Whole numbers indicate semitones.\n'
+    '• Using higher pitches may cut the upper bandwidth, even in high-quality models.\n'
+    '• Upping the pitch can be better for tracks with deeper vocals.\n'
+    '• Dropping the pitch may take more processing time but works well for tracks with high-pitched vocals.')
+AGGRESSION_SETTING_HELP = ('Adjust the intensity of primary stem extraction:\n\n'
+    '• It ranges from -100 - 100.\n'
+    '• Bigger values mean deeper extractions.\n'
+    '• Typically, it\'s set to 5 for vocals & instrumentals. \n'
+    '• Values beyond 5 might muddy the sound for non-vocal models.')
+WINDOW_SIZE_HELP = ('Select window size to balance quality and speed:\n\n'
+    '• 1024 - Quick but lesser quality.\n'
+    '• 512 - Medium speed and quality.\n'
+    '• 320 - Takes longer but may offer better quality.')
+MDX_SEGMENT_SIZE_HELP = ('Pick a segment size to balance speed, resource use, and quality:\n'
+    '• Smaller sizes consume less resources.\n'
+    '• Bigger sizes consume more resources, but may provide better results.\n'
+    '• Default size is 256. Quality can change based on your pick.')
+DEMUCS_STEMS_HELP = ('Select a stem for extraction with the chosen model:\n\n'
+    '• All Stems - Extracts all available stems.\n'
+    '• Vocals - Only the "vocals" stem.\n'
+    '• Other - Only the "other" stem.\n'
+    '• Bass - Only the "bass" stem.\n'
+    '• Drums - Only the "drums" stem.')
+SEGMENT_HELP = ('Adjust segments to manage RAM or V-RAM usage:\n\n'
+    '• Smaller sizes consume less resources.\n'
+    '• Bigger sizes consume more resources, but may provide better results.\n'
+    '• "Default" picks the optimal size.')
+
+ENSEMBLE_MAIN_STEM_HELP = (
+    'Select the stem type for ensembling:\n\n'
+
+    f'• {VOCAL_PAIR}:\n'
+    ' - Primary Stem: Vocals\n'
+    ' - Secondary Stem: Instrumental (mixture minus vocals)\n\n'
+
+    f'• {OTHER_PAIR}:\n'
+    ' - Primary Stem: Other\n'
+    ' - Secondary Stem: No Other (mixture minus "other")\n\n'
+
+    f'• {BASS_PAIR}:\n'
+    ' - Primary Stem: Bass\n'
+    ' - Secondary Stem: No Bass (mixture minus bass)\n\n'
+
+    f'• {DRUM_PAIR}:\n'
+    ' - Primary Stem: Drums\n'
+    ' - Secondary Stem: No Drums (mixture minus drums)\n\n'
+
+    f'• {FOUR_STEM_ENSEMBLE}:\n'
+    ' - Gathers all 4-stem Demucs models and ensembles all outputs.\n\n'
+
+    f'• {MULTI_STEM_ENSEMBLE}:\n'
+    ' - The "Jungle Ensemble" gathers all models and ensembles any related outputs.'
+)
+
+ENSEMBLE_TYPE_HELP = (
+    'Choose the ensemble algorithm for generating the final output:\n\n'
+
+    f'• {MAX_MIN}:\n'
+    ' - Primary stem processed with "Max Spec" algorithm.\n'
+    ' - Secondary stem processed with "Min Spec" algorithm.\n\n'
+
+    'Note: For the "4 Stem Ensemble" option, only one algorithm will be displayed.\n\n'
+
+    'Algorithm Details:\n'
+
+    f'• {MAX_SPEC}:\n'
+    ' - Produces the highest possible output.\n'
+    ' - Ideal for vocal stems for a fuller sound, but might introduce unwanted artifacts.\n'
+    ' - Works well with instrumental stems, but avoid using VR Arch models in the ensemble.\n\n'
+
+    f'• {MIN_SPEC}:\n'
+    ' - Produces the lowest possible output.\n'
+    ' - Ideal for instrumental stems for a cleaner result. Might result in a "muddy" sound.\n\n'
+
+    f'• {AUDIO_AVERAGE}:\n'
+    ' - Averages all results together for the final output.'
+)
+
+ENSEMBLE_LISTBOX_HELP = (
+    'Displays all available models for the chosen main stem pair.'
+)
+
+if OPERATING_SYSTEM == 'darwin':
+    IS_GPU_CONVERSION_HELP = (
+        '• Use GPU for Processing (if available):\n'
+        ' - If checked, the application will attempt to use your GPU for faster processing.\n'
+        ' - If a GPU is not detected, it will default to CPU processing.\n'
+        ' - GPU processing for MacOS only works with VR Arch models.\n\n'
+        '• Please Note:\n'
+        ' - CPU processing is significantly slower than GPU processing.\n'
+        ' - Only Macs with M1 chips can be used for GPU processing.'
+    )
+else:
+    IS_GPU_CONVERSION_HELP = (
+        '• Use GPU for Processing (if available):\n'
+        ' - If checked, the application will attempt to use your GPU for faster processing.\n'
+        ' - If a GPU is not detected, it will default to CPU processing.\n\n'
+        '• Please Note:\n'
+        ' - CPU processing is significantly slower than GPU processing.\n'
+        ' - Only Nvidia GPUs can be used for GPU processing.'
+    )
+
+IS_TIME_CORRECTION_HELP = ('When checked, the output will retain the original BPM of the input.')
+SAVE_STEM_ONLY_HELP = 'Allows the user to save only the selected stem.'
+IS_NORMALIZATION_HELP = 'Normalizes output to prevent clipping.'
+IS_CUDA_SELECT_HELP = "If you have more than one GPU, you can pick which one to use for processing."
+CROP_SIZE_HELP = '**Only compatible with select models!**\n\n Setting should match training crop-size value. Leave as is if unsure.'
+IS_TTA_HELP = ('This option performs Test-Time-Augmentation to improve the separation quality.\n\n'
+    'Note: Having this selected will increase the time it takes to complete a conversion')
+IS_POST_PROCESS_HELP = ('This option can potentially identify leftover instrumental artifacts within the vocal outputs. \nThis option may improve the separation of some songs.\n\n' +\
+    'Note: Selecting this option can adversely affect the conversion process, depending on the track. Because of this, it is only recommended as a last resort.')
+IS_HIGH_END_PROCESS_HELP = 'The application will mirror the missing frequency range of the output.'
+SHIFTS_HELP = ('Performs multiple predictions with random shifts of the input and averages them.\n\n'
+    '• The higher number of shifts, the longer the prediction will take. \n- Not recommended unless you have a GPU.')
+OVERLAP_HELP = ('• This option controls the amount of overlap between prediction windows.\n'
+    ' - Higher values can provide better results, but will lead to longer processing times.\n'
+    ' - You can choose between 0.001-0.999')
+MDX_OVERLAP_HELP = ('• This option controls the amount of overlap between prediction windows.\n'
+    ' - Higher values can provide better results, but will lead to longer processing times.\n'
+    ' - For Non-MDX23C models: You can choose between 0.001-0.999')
+OVERLAP_23_HELP = ('• This option controls the amount of overlap between prediction windows.\n'
+    ' - Higher values can provide better results, but will lead to longer processing times.')
+IS_SEGMENT_DEFAULT_HELP = '• The segment size is set based on the value provided in a chosen model\'s associated \nconfig file (yaml).'
+IS_SPLIT_MODE_HELP = '• Enables \"Segments\". \n• Deselecting this option is only recommended for those with powerful PCs.'
+IS_DEMUCS_COMBINE_STEMS_HELP = 'The application will create the secondary stem by combining the remaining stems \ninstead of inverting the primary stem with the mixture.'
+COMPENSATE_HELP = 'Compensates the audio of the primary stems to allow for a better secondary stem.'
+IS_DENOISE_HELP = ('• Standard: This setting reduces the noise created by MDX-Net models.\n'
+    ' - This option only reduces noise in non-MDX23 models.\n'
+    '• Denoise Model: This setting employs a special denoise model to eliminate noise produced by any MDX-Net model.\n'
+    ' - This option works on all MDX-Net models.\n'
+    ' - You must have the "UVR-DeNoise-Lite" VR Arch model installed to use this option.\n'
+    '• Please Note: Both options will increase separation time.')
+
+VOC_SPLIT_MODEL_SELECT_HELP = '• Select a model from the list of lead and backing vocal models to run through vocal stems automatically.'
+IS_VOC_SPLIT_INST_SAVE_SELECT_HELP = '• When activated, you will receive extra instrumental outputs that include: one with just the lead vocals and another with only the backing vocals.'
+IS_VOC_SPLIT_MODEL_SELECT_HELP = ('• When activated, this option auto-processes generated vocal stems, using either a karaoke model to remove lead vocals or another to remove backing vocals.\n'
+    ' - This option splits the vocal track into two separate parts: lead vocals and backing vocals, providing two extra vocal outputs.\n'
+    ' - The results will be organized in the same way, whether you use a karaoke model or a background vocal model.\n'
+    ' - This option does not work in ensemble mode at this time.')
+IS_DEVERB_OPT_HELP = ('• Select the vocal type you wish to deverb automatically.\n'
+    ' - Example: Choosing "Lead Vocals Only" will only remove reverb from a lead vocal stem.')
+IS_DEVERB_VOC_HELP = ('• This option removes reverb from a vocal stem.\n'
+    ' - You must have the "UVR-DeEcho-DeReverb" VR Arch model installed to use this option.\n'
+    ' - This option does not work in ensemble mode at this time.')
+IS_FREQUENCY_MATCH_HELP = 'Matches the frequency cut-off of the primary stem to that of the secondary stem.'
+CLEAR_CACHE_HELP = 'Clears settings for unrecognized models chosen by the user.'
+IS_SAVE_ALL_OUTPUTS_ENSEMBLE_HELP = 'If enabled, all individual ensemble-generated outputs are retained.'
+IS_APPEND_ENSEMBLE_NAME_HELP = 'When enabled, the ensemble name is added to the final output.'
+IS_WAV_ENSEMBLE_HELP = (
+    'Processes ensemble algorithms with waveforms instead of spectrograms when activated:\n'
+    '• Might lead to increased distortion.\n'
+    '• Waveform ensembling is faster than spectrogram ensembling.'
+)
+DONATE_HELP = 'Opens official UVR "Buy Me a Coffee" external link for project donations!'
+IS_INVERT_SPEC_HELP = (
+    'Potentially enhances the secondary stem quality:\n'
+    '• Inverts primary stem using spectrograms, instead of waveforms.\n'
+    '• Slightly slower inversion method.'
+)
+IS_TESTING_AUDIO_HELP = 'Appends a 10-digit number to saved files to avoid accidental overwrites.'
+IS_MODEL_TESTING_AUDIO_HELP = 'Appends the model name to outputs for comparison across different models.'
+IS_ACCEPT_ANY_INPUT_HELP = (
+    'Allows all types of inputs when enabled, even non-audio formats.\n'
+    'For experimental use only. Not recommended for regular use.'
+)
+IS_TASK_COMPLETE_HELP = 'Plays a chime upon process completion or failure when activated.'
+DELETE_YOUR_SETTINGS_HELP = (
+    'Contains your saved settings. Confirmation will be requested before deleting a selected setting.'
+)
+SET_STEM_NAME_HELP = 'Select the primary stem for the given model.'
+IS_CREATE_MODEL_FOLDER_HELP = ('Two new directories will be generated for the outputs in the export directory after each conversion.\n\n'
+    '• Example: \n'
+    '─ Export Directory\n'
+    ' └── First Directory (Named after the model)\n'
+    ' └── Second Directory (Named after the track)\n'
+    ' └── Output File(s)')
+MDX_DIM_T_SET_HELP = INTERNAL_MODEL_ATT
+MDX_DIM_F_SET_HELP = INTERNAL_MODEL_ATT
+
+MDX_N_FFT_SCALE_SET_HELP = 'Specify the N_FFT size used during model training.'
+POPUP_COMPENSATE_HELP = (
+    f'Select the appropriate volume compensation for the chosen model.\n'
+    f'Reminder: {COMPENSATE_HELP}'
+)
+VR_MODEL_PARAM_HELP = 'Select the required parameters to run the chosen model.'
+CHOSEN_ENSEMBLE_HELP = (
+    'Default Ensemble Selections:\n'
+    '• Save the current ensemble configuration.\n'
+    '• Clear all selected models.\n'
+    'Note: You can also select previously saved ensembles.'
+)
+CHOSEN_PROCESS_METHOD_HELP = (
+    'Choose a Processing Method:\n'
+    'Select from various AI networks and algorithms to process your track:\n'
+    '\n'
+    '• VR Architecture: Uses magnitude spectrograms for source separation.\n'
+    '• MDX-Net: Employs a Hybrid Spectrogram network for source separation.\n'
+    '• Demucs v3: Also utilizes a Hybrid Spectrogram network for source separation.\n'
+    '• Ensemble Mode: Combine results from multiple models and networks for optimal results.\n'
+    '• Audio Tools: Additional utilities for added convenience.'
+)
+
+INPUT_FOLDER_ENTRY_HELP = (
+    'Select Input:\n'
+    'Choose the audio file(s) you want to process.'
+)
+INPUT_FOLDER_ENTRY_HELP_2 = (
+    'Input Option Menu:\n'
+    'Click to access the input option menu.'
+)
+OUTPUT_FOLDER_ENTRY_HELP = (
+    'Select Output:\n'
+    'Choose the directory where the processed files will be saved.'
+)
+INPUT_FOLDER_BUTTON_HELP = (
+    'Open Input Folder Button:\n'
+    'Open the directory containing the selected input audio file(s).'
+)
+OUTPUT_FOLDER_BUTTON_HELP = (
+    'Open Output Folder Button:\n'
+    'Open the selected output folder.'
+)
+CHOOSE_MODEL_HELP = (
+    'Each processing method has its own set of options and models.\n'
+    'Choose the model associated with the selected processing method here.'
+)
+FORMAT_SETTING_HELP = 'Save Outputs As: '
+SECONDARY_MODEL_ACTIVATE_HELP = (
+    'When enabled, the application will perform an additional inference using the selected model(s) above.'
+)
+SECONDARY_MODEL_HELP = (
+    'Choose the Secondary Model:\n'
+    'Select the secondary model associated with the stem you want to process with the current method.'
+)
+
+INPUT_SEC_FIELDS_HELP = (
+    'Right click here to choose your inputs!'
+)
+
+SECONDARY_MODEL_SCALE_HELP = ('The scale determines how the final audio outputs will be averaged between the primary and secondary models.\n\nFor example:\n\n'
+    '• 10% - 10 percent of the main model result will be factored into the final result.\n'
+    '• 50% - The results from the main and secondary models will be averaged evenly.\n'
+    '• 90% - 90 percent of the main model result will be factored into the final result.')
+PRE_PROC_MODEL_ACTIVATE_HELP = (
+    'When enabled, the application will use the selected model to isolate the instrumental stem.\n'
+    'Subsequently, all non-vocal stems will be extracted from this generated instrumental.\n'
+    '\n'
+    'Key Points:\n'
+    '• This feature can significantly reduce vocal bleed in non-vocal stems.\n'
+    '• Available exclusively in the Demucs tool.\n'
+    '• Compatible only with non-vocal and non-instrumental stem outputs.\n'
+    '• Expect an increase in total processing time.\n'
+    '• Only the VR or MDX-Net Vocal Instrumental/Vocals models can be chosen for this process.'
+)
+
+AUDIO_TOOLS_HELP = (
+    'Select from various audio tools to process your track:\n'
+    '\n'
+    '• Manual Ensemble: Requires 2 or more selected files as inputs. This allows tracks to be processed using the algorithms from Ensemble Mode.\n'
+    '• Time Stretch: Adjust the playback speed of the selected inputs to be faster or slower.\n'
+    '• Change Pitch: Modify the pitch of the selected inputs.\n'
+    '• Align Inputs: Choose 2 audio files and the application will align them and provide the difference in alignment.\n'
+    ' - This tool provides similar functionality to "Utagoe."\n'
+    ' - Primary Audio: This is usually a mixture.\n'
+    ' - Secondary Audio: This is usually an instrumental.\n'
+    '• Matchering: Choose 2 audio files. The matchering algorithm will master the target audio to have the same RMS, FR, peak amplitude, and stereo width as the reference audio.'
+)
+
+PRE_PROC_MODEL_INST_MIX_HELP = 'When enabled, the application will generate a third output without the selected stem and vocals.'
+MODEL_SAMPLE_MODE_HELP = ('Allows the user to process only part of a track to sample settings or a model without running a full conversion.\n\nNotes:\n\n'
+    '• The number in the parentheses is the current number of seconds the generated sample will be.\n'
+    '• You can choose the number of seconds to extract from the track in the \"Additional Settings\" menu.')
+
+POST_PROCESS_THREASHOLD_HELP = ('Allows the user to control the intensity of the Post_process option.\n\nNotes:\n\n'
+    '• Higher values potentially remove more artifacts. However, bleed might increase.\n'
+    '• Lower values limit artifact removal.')
+
+BATCH_SIZE_HELP = ('Specify the number of batches to be processed at a time.\n\nNotes:\n\n'
+    '• Higher values mean more RAM usage but slightly faster processing times.\n'
+    '• Lower values mean less RAM usage but slightly longer processing times.\n'
+    '• Batch size value has no effect on output quality.')
+
+VR_MODEL_NOUT_HELP = ""
+VR_MODEL_NOUT_LSTM_HELP = ""
+
+IS_PHASE_HELP = 'Select the phase for the secondary audio.\n• Note: Using the "Automatic" option is strongly recommended.'
+IS_ALIGN_TRACK_HELP = 'Enable this to save the secondary track once aligned.'
+IS_MATCH_SILENCE_HELP = (
+    'Aligns the initial silence of the secondary audio with the primary audio.\n'
+    '• Note: Avoid using this option if the primary audio begins solely with vocals.'
+)
+IS_MATCH_SPEC_HELP = 'Align the secondary audio based on the primary audio\'s spectrogram.\n• Note: This may enhance alignment in specific cases.'
+
+TIME_WINDOW_ALIGN_HELP = (
+    'This setting determines the window size for alignment analysis, especially for pairs with minor timing variations:\n'
+    '\n'
+    '• None: Disables time window analysis.\n'
+    '• 1: Analyzes pair by 0.0625-second windows.\n'
+    '• 2: Analyzes pair by 0.125-second windows.\n'
+    '• 3: Analyzes pair by 0.25-second windows.\n'
+    '• 4: Analyzes pair by 0.50-second windows.\n'
+    '• 5: Analyzes pair by 0.75-second windows.\n'
+    '• 6: Analyzes pair by 1-second windows.\n'
+    '• 7: Analyzes pair by 2-second windows.\n'
+    '\n'
+    'Shifts Options:\n'
+    '• Low: Cycles through 0.0625 and 0.5-second windows to find an optimal match.\n'
+    '• Medium: Cycles through 0.0625, 0.125, and 0.5-second windows to find an optimal match.\n'
+    '• High: Cycles through 0.0625, 0.125, 0.25, and 0.5-second windows to find an optimal match.\n'
+    '\n'
+    'Important Points to Consider:\n'
+    ' - Using the "Shifts" option may require more processing time and might not guarantee better results.\n'
+    ' - Opting for smaller analysis windows can increase processing times.\n'
+    ' - The best settings are likely to vary based on the specific tracks being processed.'
+)
+INTRO_ANALYSIS_ALIGN_HELP = (
+    'This setting determines the portion of the audio input to be analyzed for initial alignment.\n'
+    '\n'
+    '• Default: Analyzes 10% (or 1/10th) of the audio\'s total length.\n'
+    '• 1: Analyzes 12.5% (or 1/8th) of the audio\'s total length.\n'
+    '• 2: Analyzes 16.67% (or 1/6th) of the audio\'s total length.\n'
+    '• 3: Analyzes 25% (or 1/4th) of the audio\'s total length.\n'
+    '• 4: Analyzes 50% (or half) of the audio\'s total length.\n'
+    '\n'
+    'Shifts Options:\n'
+    '• Low: Cycles through 2 intro analysis values.\n'
+    '• Medium: Cycles through 3 intro analysis values.\n'
+    '• High: Cycles through 5 intro analysis values.\n'
+    '\n'
+    'Important Points to Consider:\n'
+    ' - Using the "Shifts" option will require more processing time and might not guarantee better results.\n'
+    ' - Optimal settings may vary depending on the specific tracks being processed.'
+)
+
+VOLUME_ANALYSIS_ALIGN_HELP = (
+    'This setting specifies the volume adjustments to be made on the secondary input:\n'
+    '\n'
+    '• None: No volume adjustments are made.\n'
+    '• Low: Analyzes the audio within a 4dB range, adjusting in 1dB increments.\n'
+    '• Medium: Analyzes the audio within a 6dB range, adjusting in 1dB increments.\n'
+    '• High: Analyzes the audio within a 6dB range, adjusting in 0.5dB increments.\n'
+    '• Very High: Analyzes the audio within a 10dB range, adjusting in 0.5dB increments.\n'
+    '\n'
+    'Important Points to Consider:\n'
+    ' - Selecting more extensive analysis options (e.g., High, Very High) will lead to longer processing times.\n'
+    ' - Optimal settings might vary based on the specific tracks being processed.'
+)
+
+PHASE_SHIFTS_ALIGN_HELP = (
+    'This setting specifies the phase adjustments to be made on the secondary input:\n'
+    '\n'
+    'Shifts Options:\n'
+    '• None: No phase adjustments are made.\n'
+    '• Very Low: Analyzes the audio within range of 2 different phase positions.\n'
+    '• Low: Analyzes the audio within range of 4 different phase positions.\n'
+    '• Medium: Analyzes the audio within range of 8 different phase positions.\n'
+    '• High: Analyzes the audio within range of 18 different phase positions.\n'
+    '• Very High: Analyzes the audio within range of 36 different phase positions.\n'
+    '• Maximum: Analyzes the audio in all 360 phase positions.\n'
+    '\n'
+    'Important Points to Consider:\n'
+    ' - This option only works with time correction.\n'
+    ' - This option can be helpful if one of the inputs was from an analog source.\n'
+    ' - Selecting more extensive analysis options (e.g., High, Very High) will lead to longer processing times.\n'
+    ' - Selecting "Maximum" can take hours to process.\n'
+    ' - Optimal settings might vary based on the specific tracks being processed.'
+)
+
+# Warning Messages
+STORAGE_ERROR = 'Insufficient Storage', 'There is not enough storage on the main drive to continue. Your main drive must have at least 3 GB\'s of storage in order for this application to function properly. \n\nPlease ensure your main drive has at least 3 GB\'s of storage and try again.\n\n'
+STORAGE_WARNING = 'Available Storage Low', 'Your main drive is running low on storage. Your main drive must have at least 3 GB\'s of storage in order for this application to function properly.\n\n'
+CONFIRM_WARNING = '\nAre you sure you wish to continue?'
+PROCESS_FAILED = 'Process failed, please see error log\n'
+EXIT_PROCESS_ERROR = 'Active Process', 'Please stop the active process or wait for it to complete before you exit.'
+EXIT_HALTED_PROCESS_ERROR = 'Halting Process', 'Please wait for the application to finish halting the process before exiting.'
+EXIT_DOWNLOAD_ERROR = 'Active Download', 'Please stop the download or wait for it to complete before you exit.'
+SET_TO_DEFAULT_PROCESS_ERROR = 'Active Process', 'You cannot reset all of the application settings during an active process.'
+SET_TO_ANY_PROCESS_ERROR = 'Active Process', 'You cannot reset the application settings during an active process.'
+RESET_ALL_TO_DEFAULT_WARNING = 'Reset Settings Confirmation', 'All application settings will be set to factory default.\n\nAre you sure you wish to continue?'
+AUDIO_VERIFICATION_CHECK = lambda i, e:f'++++++++++++++++++++++++++++++++++++++++++++++++++++\n\nBroken File Removed: \n\n{i}\n\nError Details:\n\n{e}\n++++++++++++++++++++++++++++++++++++++++++++++++++++'
+INVALID_ONNX_MODEL_ERROR = 'Invalid Model', 'The file selected is not a valid MDX-Net model. Please see the error log for more information.'
+INVALID_PARAM_MODEL_ERROR = 'Select Model Param', 'Please choose a model param or click \'Cancel\'.'
+UNRECOGNIZED_MODEL = 'Unrecognized Model Detected', ' is an unrecognized model.\n\n' + \
+    'Would you like to select the correct parameters before continuing?'
+STOP_PROCESS_CONFIRM = 'Confirmation', 'You are about to stop all active processes.\n\nAre you sure you wish to continue?'
+NO_ENSEMBLE_SELECTED = 'No Models Selected', 'Please select an ensemble and try again.'
+PICKLE_CORRU = 'File Corrupted', 'Unable to load this ensemble.\n\n' + \
+    'Would you like to remove this ensemble from your list?'
+DELETE_ENS_ENTRY = 'Confirm Removal', 'Are you sure you want to remove this entry?'
+
|
1316 |
+
LOADING_MODEL = 'Loading model...'
|
1317 |
+
INFERENCE_STEP_1 = 'Running inference...'
|
1318 |
+
INFERENCE_STEP_1_SEC = 'Running inference (secondary model)...'
|
1319 |
+
INFERENCE_STEP_1_4_STEM = lambda stem:f'Running inference (secondary model for {stem})...'
|
1320 |
+
INFERENCE_STEP_1_PRE = 'Running inference (pre-process model)...'
|
1321 |
+
INFERENCE_STEP_1_VOC_S = 'Splitting vocals...'
|
1322 |
+
INFERENCE_STEP_2_PRE = lambda pm, m:f'Loading pre-process model ({pm}: {m})...'
|
1323 |
+
INFERENCE_STEP_2_SEC = lambda pm, m:f'Loading secondary model ({pm}: {m})...'
|
1324 |
+
INFERENCE_STEP_2_VOC_S = lambda pm, m:f'Loading vocal splitter model ({pm}: {m})...'
|
1325 |
+
INFERENCE_STEP_2_SEC_CACHED_MODOEL = lambda pm, m:f'Secondary model ({pm}: {m}) cache loaded.\n'
|
1326 |
+
INFERENCE_STEP_2_PRE_CACHED_MODOEL = lambda pm, m:f'Pre-process model ({pm}: {m}) cache loaded.\n'
|
1327 |
+
INFERENCE_STEP_2_SEC_CACHED = 'Loading cached secondary model source(s)... Done!\n'
|
1328 |
+
INFERENCE_STEP_2_PRIMARY_CACHED = ' Model cache loaded.\n'
|
1329 |
+
INFERENCE_STEP_2 = 'Inference complete.'
|
1330 |
+
INFERENCE_STEP_DEVERBING = ' Deverbing...'
|
1331 |
+
SAVING_STEM = 'Saving ', ' stem...'
|
1332 |
+
SAVING_ALL_STEMS = 'Saving all stems...'
|
1333 |
+
ENSEMBLING_OUTPUTS = 'Ensembling outputs...'
|
1334 |
+
DONE = ' Done!\n'
|
1335 |
+
ENSEMBLES_SAVED = 'Ensembled outputs saved!\n\n'
|
1336 |
+
|
1337 |
+
#Additional Text
|
1338 |
+
CHOOSE_PROC_METHOD_MAIN_LABEL = 'CHOOSE PROCESS METHOD'
|
1339 |
+
SELECT_SAVED_SETTINGS_MAIN_LABEL = 'SELECT SAVED SETTINGS'
|
1340 |
+
CHOOSE_MDX_MODEL_MAIN_LABEL = 'CHOOSE MDX-NET MODEL'
|
1341 |
+
BATCHES_MDX_MAIN_LABEL = 'BATCH SIZE'
|
1342 |
+
VOL_COMP_MDX_MAIN_LABEL = 'VOLUME COMPENSATION'
|
1343 |
+
SEGMENT_MDX_MAIN_LABEL = 'SEGMENT SIZE'
|
1344 |
+
SELECT_VR_MODEL_MAIN_LABEL = 'CHOOSE VR MODEL'
|
1345 |
+
AGGRESSION_SETTING_MAIN_LABEL = 'AGGRESSION SETTING'
|
1346 |
+
WINDOW_SIZE_MAIN_LABEL = 'WINDOW SIZE'
|
1347 |
+
CHOOSE_DEMUCS_MODEL_MAIN_LABEL = 'CHOOSE DEMUCS MODEL'
|
1348 |
+
CHOOSE_STEMS_MAIN_LABEL = 'CHOOSE STEM(S)'
|
1349 |
+
CHOOSE_SEGMENT_MAIN_LABEL = 'SEGMENT'
|
1350 |
+
ENSEMBLE_OPTIONS_MAIN_LABEL = 'ENSEMBLE OPTIONS'
|
1351 |
+
CHOOSE_MAIN_PAIR_MAIN_LABEL = 'MAIN STEM PAIR'
|
1352 |
+
CHOOSE_ENSEMBLE_ALGORITHM_MAIN_LABEL = 'ENSEMBLE ALGORITHM'
|
1353 |
+
AVAILABLE_MODELS_MAIN_LABEL = 'AVAILABLE MODELS'
|
1354 |
+
CHOOSE_AUDIO_TOOLS_MAIN_LABEL = 'CHOOSE AUDIO TOOL'
|
1355 |
+
CHOOSE_MANUAL_ALGORITHM_MAIN_LABEL = 'CHOOSE ALGORITHM'
|
1356 |
+
CHOOSE_RATE_MAIN_LABEL = 'RATE'
|
1357 |
+
CHOOSE_SEMITONES_MAIN_LABEL = 'SEMITONES'
|
1358 |
+
GPU_CONVERSION_MAIN_LABEL = 'GPU Conversion'
|
1359 |
+
CHANGE_LOG_HEADER = lambda patch:f"Patch Version:\n\n{patch}"
|
1360 |
+
INVALID_INPUT_E = ' Invalid input! '
|
1361 |
+
LB_UP = "Move Selection Up"
|
1362 |
+
LB_DOWN = "Move Selection Down"
|
1363 |
+
LB_CLEAR = "Clear Box"
|
1364 |
+
LB_MOVE_OVER_P = "Move Selection to Secondary List"
|
1365 |
+
LB_MOVE_OVER_S = "Move Selection to Primary List"
|
1366 |
+
FILE_ONE_MAIN_LABEL = "PRIMARY AUDIO"
|
1367 |
+
FILE_TWO_MAIN_LABEL = "SECONDARY AUDIO"
|
1368 |
+
FILE_ONE_MATCH_MAIN_LABEL = "TARGET AUDIO"
|
1369 |
+
FILE_TWO_MATCH_MAIN_LABEL = "REFERENCE AUDIO"
|
1370 |
+
TIME_WINDOW_MAIN_LABEL = "TIME ADJUSTMENT"
|
1371 |
+
INTRO_ANALYSIS_MAIN_LABEL = "INTRO ANALYSIS"
|
1372 |
+
VOLUME_ADJUSTMENT_MAIN_LABEL = "VOLUME ADJUSTMENT"
|
1373 |
+
SELECT_INPUTS = "Select Input(s)"
|
1374 |
+
SELECTED_INPUTS = 'Selected Inputs'
|
1375 |
+
WIDEN_BOX = 'Widen Box'
|
1376 |
+
CONFIRM_ENTRIES = 'Confirm Entries'
|
1377 |
+
CLOSE_WINDOW = 'Close Window'
|
1378 |
+
DUAL_AUDIO_PROCESSING = 'Dual Audio Batch Processing'
|
1379 |
+
CANCEL_TEXT = "Cancel"
|
1380 |
+
CONFIRM_TEXT = "Confirm"
|
1381 |
+
SELECT_MODEL_TEXT = 'Select Model'
|
1382 |
+
NONE_SELECTED = 'None Selected'
|
1383 |
+
SAVE_TEXT = 'Save'
|
1384 |
+
OVERLAP_TEXT = 'Overlap'
|
1385 |
+
ACCEPT_ANY_INPUT_TEXT = 'Accept Any Input'
|
1386 |
+
ACTIVATE_PRE_PROCESS_MODEL_TEXT = 'Activate Pre-process Model'
|
1387 |
+
ACTIVATE_SECONDARY_MODEL_TEXT = 'Activate Secondary Model'
|
1388 |
+
ADDITIONAL_MENUS_INFORMATION_TEXT = 'Additional Menus & Information'
|
1389 |
+
ADDITIONAL_SETTINGS_TEXT = 'Additional Settings'
|
1390 |
+
ADVANCED_ALIGN_TOOL_OPTIONS_TEXT = 'Advanced Align Tool Options'
|
1391 |
+
ADVANCED_DEMUCS_OPTIONS_TEXT = 'Advanced Demucs Options'
|
1392 |
+
ADVANCED_ENSEMBLE_OPTIONS_TEXT = 'Advanced Ensemble Options'
|
1393 |
+
ADVANCED_MDXNET23_OPTIONS_TEXT = 'Advanced MDX-NET23 Options'
|
1394 |
+
ADVANCED_MDXNET_OPTIONS_TEXT = 'Advanced MDX-Net Options'
|
1395 |
+
ADVANCED_OPTION_MENU_TEXT = 'Advanced Option Menu'
|
1396 |
+
ADVANCED_VR_OPTIONS_TEXT = 'Advanced VR Options'
|
1397 |
+
AGGRESSION_SETTING_TEXT = 'Aggression Setting'
|
1398 |
+
APPEND_ENSEMBLE_NAME_TEXT = 'Append Ensemble Name'
|
1399 |
+
APPLICATION_DOWNLOAD_CENTER_TEXT = 'Application Download Center'
|
1400 |
+
APPLICATION_UPDATES_TEXT = 'Application Updates'
|
1401 |
+
AUDIO_FORMAT_SETTINGS_TEXT = 'Audio Format Settings'
|
1402 |
+
BALANCE_VALUE_TEXT = 'Balance Value'
|
1403 |
+
BATCH_SIZE_TEXT = 'Batch Size'
|
1404 |
+
BV_MODEL_TEXT = 'BV Model'
|
1405 |
+
CHANGE_MODEL_DEFAULT_TEXT = 'Change Model Default'
|
1406 |
+
CHANGE_MODEL_DEFAULTS_TEXT = 'Change Model Defaults'
|
1407 |
+
CHANGE_PARAMETERS_TEXT = 'Change Parameters'
|
1408 |
+
CHOOSE_ADVANCED_MENU_TEXT = 'Choose Advanced Menu'
|
1409 |
+
CHOOSE_MODEL_PARAM_TEXT = 'Choose Model Param'
|
1410 |
+
CLEAR_AUTOSET_CACHE_TEXT = 'Clear Auto-Set Cache'
|
1411 |
+
COMBINE_STEMS_TEXT = 'Combine Stems'
|
1412 |
+
CONFIRM_UPDATE_TEXT = 'Confirm Update'
|
1413 |
+
COPIED_TEXT = 'Copied!'
|
1414 |
+
COPY_ALL_TEXT_TEXT = 'Copy All Text'
|
1415 |
+
DEFINED_PARAMETERS_DELETED_TEXT = 'Defined Parameters Deleted'
|
1416 |
+
DELETE_PARAMETERS_TEXT = 'Delete Parameters'
|
1417 |
+
DELETE_USER_SAVED_SETTING_TEXT = 'Delete User Saved Setting'
|
1418 |
+
DEMUCS_TEXT = 'Demucs'
|
1419 |
+
DENOISE_OUTPUT_TEXT = 'Denoise Output'
|
1420 |
+
DEVERB_VOCALS_TEXT = 'Deverb Vocals'
|
1421 |
+
DONE_TEXT = 'Done'
|
1422 |
+
DOWNLOAD_CENTER_TEXT = 'Download Center'
|
1423 |
+
DOWNLOAD_CODE_TEXT = 'Download Code'
|
1424 |
+
DOWNLOAD_LINKS_TEXT = 'Download Link(s)'
|
1425 |
+
DOWNLOAD_UPDATE_IN_APPLICATION_TEXT = 'Download Update in Application'
|
1426 |
+
ENABLE_HELP_HINTS_TEXT = 'Enable Help Hints'
|
1427 |
+
ENABLE_TTA_TEXT = 'Enable TTA'
|
1428 |
+
ENABLE_VOCAL_SPLIT_MODE_TEXT = 'Enable Vocal Split Mode'
|
1429 |
+
ENSEMBLE_NAME_TEXT = 'Ensemble Name'
|
1430 |
+
ENSEMBLE_WAVFORMS_TEXT = 'Ensemble Wavforms'
|
1431 |
+
ERROR_CONSOLE_TEXT = 'Error Console'
|
1432 |
+
GENERAL_MENU_TEXT = 'General Menu'
|
1433 |
+
GENERAL_PROCESS_SETTINGS_TEXT = 'General Process Settings'
|
1434 |
+
GENERATE_MODEL_FOLDER_TEXT = 'Generate Model Folder'
|
1435 |
+
HIGHEND_PROCESS_TEXT = 'High-End Process'
|
1436 |
+
INPUT_CODE_TEXT = 'Input Code'
|
1437 |
+
INPUT_STEM_NAME_TEXT = 'Input Stem Name'
|
1438 |
+
INPUT_UNIQUE_STEM_NAME_TEXT = 'Input Unique Stem Name'
|
1439 |
+
IS_INVERSE_STEM_TEXT = 'Is Inverse Stem'
|
1440 |
+
KARAOKE_MODEL_TEXT = 'Karaoke Model'
|
1441 |
+
MANUAL_DOWNLOADS_TEXT = 'Manual Downloads'
|
1442 |
+
MATCH_FREQ_CUTOFF_TEXT = 'Match Freq Cut-off'
|
1443 |
+
MDXNET_C_MODEL_PARAMETERS_TEXT = 'MDX-Net C Model Parameters'
|
1444 |
+
MDXNET_MODEL_SETTINGS_TEXT = 'MDX-Net Model Settings'
|
1445 |
+
MDXNET_TEXT = 'MDX-Net'
|
1446 |
+
MODEL_PARAMETERS_CHANGED_TEXT = 'Model Parameters Changed'
|
1447 |
+
MODEL_SAMPLE_MODE_SETTINGS_TEXT = 'Model Sample Mode Settings'
|
1448 |
+
MODEL_TEST_MODE_TEXT = 'Model Test Mode'
|
1449 |
+
MP3_BITRATE_TEXT = 'Mp3 Bitrate'
|
1450 |
+
NAME_SETTINGS_TEXT = 'Name Settings'
|
1451 |
+
NO_DEFINED_PARAMETERS_FOUND_TEXT = 'No Defined Parameters Found'
|
1452 |
+
NO_TEXT = 'No'
|
1453 |
+
NORMALIZE_OUTPUT_TEXT = 'Normalize Output'
|
1454 |
+
USE_OPENCL_TEXT = 'Use OpenCL'
|
1455 |
+
NOT_ENOUGH_MODELS_TEXT = 'Not Enough Models'
|
1456 |
+
NOTIFICATION_CHIMES_TEXT = 'Notification Chimes'
|
1457 |
+
OPEN_APPLICATION_DIRECTORY_TEXT = 'Open Application Directory'
|
1458 |
+
OPEN_LINK_TO_MODEL_TEXT = 'Open Link to Model'
|
1459 |
+
OPEN_MODEL_DIRECTORY_TEXT = 'Open Model Directory'
|
1460 |
+
OPEN_MODEL_FOLDER_TEXT = 'Open Model Folder'
|
1461 |
+
OPEN_MODELS_FOLDER_TEXT = 'Open Models Folder'
|
1462 |
+
PHASE_SHIFTS_TEXT = 'Phase Shifts'
|
1463 |
+
POST_PROCESS_TEXT = 'Post-Process'
|
1464 |
+
POST_PROCESS_THRESHOLD_TEXT = 'Post-process Threshold'
|
1465 |
+
PREPROCESS_MODEL_CHOOSE_TEXT = 'Pre-process Model'
|
1466 |
+
PRIMARY_STEM_TEXT = 'Primary Stem'
|
1467 |
+
REFRESH_LIST_TEXT = 'Refresh List'
|
1468 |
+
REMOVE_SAVED_ENSEMBLE_TEXT = 'Remove Saved Ensemble'
|
1469 |
+
REPORT_ISSUE_TEXT = 'Report Issue'
|
1470 |
+
RESET_ALL_SETTINGS_TO_DEFAULT_TEXT = 'Reset All Settings to Default'
|
1471 |
+
RESTART_APPLICATION_TEXT = 'Restart Application'
|
1472 |
+
SAMPLE_CLIP_DURATION_TEXT = 'Sample Clip Duration'
|
1473 |
+
SAVE_ALIGNED_TRACK_TEXT = 'Save Aligned Track'
|
1474 |
+
SAVE_ALL_OUTPUTS_TEXT = 'Save All Outputs'
|
1475 |
+
SAVE_CURRENT_ENSEMBLE_TEXT = 'Save Current Ensemble'
|
1476 |
+
SAVE_CURRENT_SETTINGS_TEXT = 'Save Current Settings'
|
1477 |
+
SAVE_INSTRUMENTAL_MIXTURE_TEXT = 'Save Instrumental Mixture'
|
1478 |
+
SAVE_SPLIT_VOCAL_INSTRUMENTALS_TEXT = 'Save Split Vocal Instrumentals'
|
1479 |
+
SECONDARY_MODEL_TEXT = 'Secondary Model'
|
1480 |
+
SECONDARY_PHASE_TEXT = 'Secondary Phase'
|
1481 |
+
SECONDS_TEXT = 'Seconds'
|
1482 |
+
SEGMENT_DEFAULT_TEXT = 'Segment Default'
|
1483 |
+
SEGMENT_SIZE_TEXT = 'Segment Size'
|
1484 |
+
+SEGMENTS_TEXT = 'Segments'
+SELECT_DOWNLOAD_TEXT = 'Select Download'
+SELECT_MODEL_PARAM_TEXT = 'Select Model Param'
+SELECT_VOCAL_TYPE_TO_DEVERB_TEXT = 'Select Vocal Type to Deverb'
+SELECTED_MODEL_PLACEMENT_PATH_TEXT = 'Selected Model Placement Path'
+SETTINGS_GUIDE_TEXT = 'Settings Guide'
+SETTINGS_TEST_MODE_TEXT = 'Settings Test Mode'
+SHIFT_CONVERSION_PITCH_TEXT = 'Shift Conversion Pitch'
+SHIFTS_TEXT = 'Shifts'
+SILENCE_MATCHING_TEXT = 'Silence Matching'
+SPECIFY_MDX_NET_MODEL_PARAMETERS_TEXT = 'Specify MDX-Net Model Parameters'
+SPECIFY_PARAMETERS_TEXT = 'Specify Parameters'
+SPECIFY_VR_MODEL_PARAMETERS_TEXT = 'Specify VR Model Parameters'
+SPECTRAL_INVERSION_TEXT = 'Spectral Inversion'
+SPECTRAL_MATCHING_TEXT = 'Spectral Matching'
+SPLIT_MODE_TEXT = 'Split Mode'
+STEM_NAME_TEXT = 'Stem Name'
+STOP_DOWNLOAD_TEXT = 'Stop Download'
+SUPPORT_UVR_TEXT = 'Support UVR'
+TRY_MANUAL_DOWNLOAD_TEXT = 'Try Manual Download'
+UPDATE_FOUND_TEXT = 'Update Found'
+USER_DOWNLOAD_CODES_TEXT = 'User Download Codes'
+UVR_BUY_ME_A_COFFEE_LINK_TEXT = 'UVR \'Buy Me a Coffee\' Link'
+UVR_ERROR_LOG_TEXT = 'UVR Error Log'
+UVR_PATREON_LINK_TEXT = 'UVR Patreon Link'
+VOCAL_DEVERB_OPTIONS_TEXT = 'Vocal Deverb Options'
+VOCAL_SPLIT_MODE_OPTIONS_TEXT = 'Vocal Split Mode Options'
+VOCAL_SPLIT_OPTIONS_TEXT = 'Vocal Split Options'
+VOLUME_COMPENSATION_TEXT = 'Volume Compensation'
+VR_51_MODEL_TEXT = 'VR 5.1 Model'
+VR_ARCH_TEXT = 'VR Arch'
+WAV_TYPE_TEXT = 'Wav Type'
+CUDA_NUM_TEXT = 'GPU Device'
+WINDOW_SIZE_TEXT = 'Window Size'
+YES_TEXT = 'Yes'
+VERIFY_INPUTS_TEXT = 'Verify Inputs'
+AUDIO_INPUT_TOTAL_TEXT = 'Audio Input Total'
+MDX23C_ONLY_OPTIONS_TEXT = 'MDXNET23 Only Options'
+PROCESS_STARTING_TEXT = 'Process starting... '
+MISSING_MESS_TEXT = 'is missing or corrupted.'
+SIMILAR_TEXT = "are the same."
+LOADING_VERSION_INFO_TEXT = 'Loading version information...'
+CHECK_FOR_UPDATES_TEXT = 'Check for Updates'
+INFO_UNAVAILABLE_TEXT = "Information unavailable."
+UPDATE_CONFIRMATION_TEXT = 'Are you sure you want to continue?\n\nThe application will need to be restarted.\n'
+BROKEN_OR_INCOM_TEXT = 'Broken or Incompatible File(s) Removed. Check Error Log for details.'
+BMAC_UVR_TEXT = 'UVR \"Buy Me a Coffee\" Link'
+MDX_MENU_WAR_TEXT = '(Leave this setting as is if you are unsure.)'
+NO_FILES_TEXT = 'No Files'
+CHOOSE_INPUT_TEXT = 'Choose Input'
+OPEN_INPUT_DIR_TEXT = 'Open Input Directory'
+BATCH_PROCESS_MENU_TEXT = 'Batch Process Menu'
+TEMP_FILE_DELETION_TEXT = 'Temp File Deletion'
+VOCAL_SPLITTER_OPTIONS_TEXT = 'Vocal Splitter Options'
+WAVEFORM_ENSEMBLE_TEXT = 'Waveform Ensemble'
+SELECT_INPUT_TEXT = 'Select Input'
+SELECT_OUTPUT_TEXT = 'Select Output'
+TIME_CORRECTION_TEXT = 'Time Correction'
+UVR_LIS_INFO_TEXT = 'UVR License Information'
+ADDITIONAL_RES_CREDITS_TEXT = 'Additional Resources & Credits'
+SAVE_INST_MIXTURE_TEXT = 'Save Instrumental Mixture'
+DOWNLOAD_UPDATE_IN_APP_TEXT = 'Download Update in Application'
+WAVE_TYPE_TEXT = 'WAVE TYPE'
+OPEN_LINK_TO_MODEL_TEXT = "Open Link to Model"
+OPEN_MODEL_DIRECTORY = "Open Model Directory"
+SELECTED_MODEL_PLACE_PATH_TEXT = 'Selected Model Placement Path'
+IS_INVERSE_STEM_TEXT = "Is Inverse Stem"
+INPUT_STEM_NAME_TEXT = "Input Stem Name"
+INPUT_UNIQUE_STEM_NAME_TEXT = "Input Unique Stem Name"
+DONE_MENU_TEXT = "Done"
+OK_TEXT = "Ok"
+ENSEMBLE_WARNING_NOT_ENOUGH_SHORT_TEXT = "Not Enough Models"
+ENSEMBLE_WARNING_NOT_ENOUGH_TEXT = "You must select 2 or more models to save an ensemble."
+NOT_ENOUGH_ERROR_TEXT = "Not enough files to process.\n"
+INVALID_FOLDER_ERROR_TEXT = 'Invalid Folder', 'Your given export path is not a valid folder!'
+
+GET_DL_VIP_CODE_TEXT = ("Obtain codes by visiting one of the following links below."
+    "\nFrom there you can donate, pledge, "
+    "or just obtain the code!\n (Donations are not required to obtain VIP code)")
+CONFIRM_RESTART_TEXT = 'Restart Confirmation', 'This will restart the application and halt any running processes. Your current settings will be saved. \n\n Are you sure you wish to continue?'
+ERROR_LOADING_FILE_TEXT = 'Error Loading the Following File', 'Raw Error Details'
+LOADING_MODEL_TEXT = 'Loading model'
+FULL_APP_SET_TEXT = 'Full Application Settings'
+PROCESS_STARTING_TEXT = 'Process starting... '
+PROCESS_STOPPED_BY_USER = '\n\nProcess stopped by user.'
+NEW_UPDATE_FOUND_TEXT = lambda version:f"\n\nNew Update Found: {version}\n\nClick the update button in the \"Settings\" menu to download and install!"
+ROLL_BACK_TEXT = 'Click Here to Roll Back'
+
+def secondary_stem(stem:str):
+    """Determines secondary stem"""
+
+    stem = stem if stem else NO_STEM
+
+    if stem in STEM_PAIR_MAPPER.keys():
+        for key, value in STEM_PAIR_MAPPER.items():
+            if stem in key:
+                secondary_stem = value
+    else:
+        secondary_stem = stem.replace(NO_STEM, "") if NO_STEM in stem else f"{NO_STEM}{stem}"
+
+    return secondary_stem
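A quick, hedged illustration of how secondary_stem above resolves names; the concrete results depend on STEM_PAIR_MAPPER and NO_STEM, which are defined earlier in constants.py, so the values shown here are assumptions rather than facts from this diff:

    # Illustrative only; assumes STEM_PAIR_MAPPER pairs 'Vocals' with 'Instrumental'
    # and that NO_STEM is the "No " prefix used for complement stems.
    secondary_stem('Vocals')     # -> 'Instrumental' (found via STEM_PAIR_MAPPER)
    secondary_stem('Guitar')     # -> 'No Guitar'    (unmapped stem, so the NO_STEM prefix is added)
    secondary_stem('No Guitar')  # -> 'Guitar'       (the NO_STEM prefix is stripped)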
gui_data/cr_text.txt
ADDED
@@ -0,0 +1,104 @@
+Most Recent Changes:
+
+Patch UVR_Patch_10_6_23_4_27:
+~ MPS for MacOS is now compatible with all MDX-Net & VR Arch models!
+~ Fixed memory issue with MDX23C models.
+~ Added the ability to choose a GPU device to process tracks.
+~ Fixed a few graphical bugs.
+
+Other Changes:
+~ Added "Vocal Split Mode," a chain ensemble that utilizes karaoke & BVE
+  (backing vocal extractor) models to split vocals into lead vocal and backing vocal
+  stems.
+~ Updated "Audio Tools" to include "Matchering".
+~ The "Align Tool" has been improved to line up inputs even when there are differences
+  in timing, similar to what Utagoe does.
+~ Integrated a right-click menu for entries in the batch process window specifically
+  for Matchering & Align tool.
+~ Introduced a right-click option for Matchering & Align tool on the main window to
+  directly access input folders.
+~ Addressed the anomaly where models would self-select in dropdowns.
+  ~ This seemed related to having a multitude of models.
+  ~ Revamped the dropdown menus entirely.
+~ Fixed splash screen closing issue.
+~ Demucs Pre-process Model settings properly reset to default.
+~ The MP3 encoder now uses LAME.
+~ Resolved the problem where using the "MDX23C_D1581" model as a pre-process model for
+  Demucs conversions was leading to a "Key Error: 'Primary Stem'" error.
+~ Rectified the "missing file" error that was popping up when trying to convert using
+  the Demucs v2 Demucs model.
+~ Solved the problem in Demucs using the htdemucs_6s multi-stem model: when certain
+  stems were selected and pitch-shift conversion was used, a KeyError prevented
+  the stem from being saved. The stem separation now works correctly, and files are
+  saved to the intended destination.
+~ Adjusted the "Reset All Settings To Default" function to ensure it also resets the
+  text of the Sample Mode slider.
+~ Made corrections so that if you activate Model Test Mode and then restart the app,
+  the "Accept Any Input" option will not be inadvertently enabled.
+~ Addressed the ensemble issue with MDX23C models.
+~ Updated the "Select Stems" menu for 2 stem models under MDX23C.
+~ Revamped the MDXC23 Overlap Menu.
+~ Provided an option to decide on using the VR denoise model in MDX-NET.
+  ~ Introduced a selection menu with choices: "None," "Standard," and "Denoise Model."
+~ Added an option to automatically deverb vocals for Demucs & MDX models. (does not
+  work in ensemble mode at this time)
+~ Updated the Download Center to feature MDX23C models.
+~ Fixed the glitch where ensemble final outputs weren't saved.
+~ Enhanced align/matchering outputs to support MP3 and FLAC.
+~ Refined tooltips for "Overlap" and "Segment Default."
+~ Ensured the download lists refresh properly even when there's no internet.
+~ Many more fixes.
+
+Resources & Credits:
+
+----------------------------------------------------------------------------------------------
+Name:
+ZFTurbo
+
+Contribution:
+~ Created the weights for the new MDX23C models.
+~ Trained the new MDX23C models.
+
+Resources:
+~ These models are available to use online. See link details below:
+  ~ https://mvsep.com
+  ~ Separation type -> MDX23C (vocals, instrumental)
+  ~ Vocal model type -> 8K FFT, Full Band (SDR vocals: 10.17, SDR instrum: 16.48)
+----------------------------------------------------------------------------------------------
+Name:
+Bas Curtiz
+
+Contribution:
+~ Conducted thorough testing and produced detailed tutorials on SDR evaluation.
+~ Implemented systematic SDR quality assessments.
+~ Authored Google Document detailing the optimal models for each genre.
+~ Authored Google Document capturing UVR frequency cutoffs and time-elapsed comparisons between GPU and CPU performances.
+~ Authored a Google Document consolidating best practices, tips, and tricks.
+
+Resources:
+~ SDR Per Genre Breakdown
+~ UVR Model Frequency Cutoff Chart
+~ Tips & Tricks
+~ Link: (See "Discord Resources Link" below)
+----------------------------------------------------------------------------------------------
+Name:
+deton24
+
+Contribution:
+~ SDR quality checks.
+~ Authored a Google Document that serves as an all-in-one guide to source separation.
+
+Resources:
+~ All-in-One Source Separation Guide
+~ Link: (See "Discord Resources Link" below)
+----------------------------------------------------------------------------------------------
+Name:
+dca100fb8
+
+Contribution:
+~ Testing
+~ Bug reporting
+~ Suggestions
+----------------------------------------------------------------------------------------------
+
+Discord Resources Link: https://discord.com/channels/708579735583588363/1153421553694953482
gui_data/error_handling.py
ADDED
@@ -0,0 +1,110 @@
+from datetime import datetime
+import traceback
+
+CUDA_MEMORY_ERROR = "CUDA out of memory"
+CUDA_RUNTIME_ERROR = "CUDNN error executing cudnnSetTensorNdDescriptor"
+DEMUCS_MODEL_MISSING_ERROR = "is neither a single pre-trained model or a bag of models."
+ENSEMBLE_MISSING_MODEL_ERROR = "local variable \'enseExport\' referenced before assignment"
+FFMPEG_MISSING_ERROR = """audioread\__init__.py", line 116, in audio_open"""
+FILE_MISSING_ERROR = "FileNotFoundError"
+MDX_MEMORY_ERROR = "onnxruntime::CudaCall CUDA failure 2: out of memory"
+MDX_MODEL_MISSING = "[ONNXRuntimeError] : 3 : NO_SUCHFILE"
+MDX_MODEL_SETTINGS_ERROR = "Got invalid dimensions for input"
+MDX_RUNTIME_ERROR = "onnxruntime::BFCArena::AllocateRawInternal"
+MODULE_ERROR = "ModuleNotFoundError"
+WINDOW_SIZE_ERROR = "h1_shape[3] must be greater than h2_shape[3]"
+SF_WRITE_ERROR = "sf.write"
+SYSTEM_MEMORY_ERROR = "DefaultCPUAllocator: not enough memory"
+MISSING_MODEL_ERROR = "'NoneType\' object has no attribute \'model_basename\'"
+ARRAY_SIZE_ERROR = "ValueError: \"array is too big; `arr.size * arr.dtype.itemsize` is larger than the maximum possible size.\""
+GPU_INCOMPATIBLE_ERROR = "no kernel image is available for execution on the device"
+SELECT_CORRECT_GPU = "CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect."
+
+CONTACT_DEV = 'If this error persists, please contact the developers with the error details.'
+
+ERROR_MAPPER = {
+    CUDA_MEMORY_ERROR:
+        ('The application was unable to allocate enough GPU memory to use this model. ' +
+         'Please close any GPU intensive applications and try again.\n' +
+         'If the error persists, your GPU might not be supported.'),
+    CUDA_RUNTIME_ERROR:
+        (f'Your PC cannot process this audio file with the segment size selected. Please lower the segment size and try again.\n\n{CONTACT_DEV}'),
+    DEMUCS_MODEL_MISSING_ERROR:
+        ('The selected Demucs model is missing. ' +
+         'Please download the model or make sure it is in the correct directory.'),
+    ENSEMBLE_MISSING_MODEL_ERROR:
+        ('The application was unable to locate a model you selected for this ensemble.\n\n' +
+         'Please do the following to use all compatible models:\n\n1. Navigate to the \"Updates\" tab in the Help Guide.\n2. Download and install the model expansion pack.\n3. Then try again.\n\n' +
+         'If the error persists, please verify all models are present.'),
+    FFMPEG_MISSING_ERROR:
+        ('The input file type is not supported or FFmpeg is missing. Please select a file type supported by FFmpeg and try again. ' +
+         'If FFmpeg is missing or not installed, you will only be able to process \".wav\" files until it is available on this system. ' +
+         f'See the \"More Info\" tab in the Help Guide.\n\n{CONTACT_DEV}'),
+    FILE_MISSING_ERROR:
+        (f'Missing file error raised. Please address the error and try again.\n\n{CONTACT_DEV}'),
+    MDX_MEMORY_ERROR:
+        ('The application was unable to allocate enough GPU memory to use this model.\n\n' +
+         'Please do the following:\n\n1. Close any GPU intensive applications.\n2. Lower the set segment size.\n3. Then try again.\n\n' +
+         'If the error persists, your GPU might not be supported.'),
+    MDX_MODEL_MISSING:
+        ('The application could not detect this MDX-Net model on your system. ' +
+         'Please make sure all the models are present in the correct directory.\n\n' +
+         'If the error persists, please reinstall the application or contact the developers.'),
+    MDX_RUNTIME_ERROR:
+        ('The application was unable to allocate enough GPU memory to use this model.\n\n' +
+         'Please do the following:\n\n1. Close any GPU intensive applications.\n2. Lower the set segment size.\n3. Then try again.\n\n' +
+         'If the error persists, your GPU might not be supported.'),
+    WINDOW_SIZE_ERROR:
+        ('Invalid window size.\n\n' +
+         'The chosen window size is likely not compatible with this model. Please select a different size and try again.'),
+    SF_WRITE_ERROR:
+        ('Could not write audio file.\n\n' +
+         'This could be due to one of the following:\n\n1. Low storage on target device.\n2. The export directory no longer exists.\n3. A system permissions issue.'),
+    SYSTEM_MEMORY_ERROR:
+        ('The application was unable to allocate enough system memory to use this model.\n\n' +
+         'Please do the following:\n\n1. Restart this application.\n2. Ensure any CPU intensive applications are closed.\n3. Then try again.\n\n' +
+         'Please Note: Intel Pentium and Intel Celeron processors do not work well with this application.\n\n' +
+         'If the error persists, the system may not have enough RAM, or your CPU might not be supported.'),
+    MISSING_MODEL_ERROR:
+        ('Model Missing: The application was unable to locate the chosen model.\n\n' +
+         'If the error persists, please verify any selected models are present.'),
+    GPU_INCOMPATIBLE_ERROR:
+        ('This process is not compatible with your GPU.\n\n' +
+         'Please uncheck \"GPU Conversion\" and try again.'),
+    SELECT_CORRECT_GPU:
+        ('Make sure you\'ve chosen the correct GPU.\n\n'
+         'Go to the "Settings Guide", click the "Additional Settings" tab and select the correct GPU device.'),
+    ARRAY_SIZE_ERROR:
+        ('The application was not able to process the given audiofile. Please convert the audiofile to another format and try again.'),
+}
+
+def error_text(process_method, exception):
+
+    traceback_text = ''.join(traceback.format_tb(exception.__traceback__))
+    message = f'{type(exception).__name__}: "{exception}"\nTraceback Error: "\n{traceback_text}"\n'
+    error_message = f'\n\nRaw Error Details:\n\n{message}\nError Time Stamp [{datetime.now().strftime("%Y-%m-%d %H:%M:%S")}]\n'
+    process = f'Last Error Received:\n\nProcess: {process_method}\n\n'
+
+    for error_type, full_text in ERROR_MAPPER.items():
+        if error_type in message:
+            final_message = full_text
+            break
+    else:
+        final_message = (CONTACT_DEV)
+
+    return f"{process}{final_message}{error_message}"
+
+def error_dialouge(exception):
+
+    error_name = f'{type(exception).__name__}'
+    traceback_text = ''.join(traceback.format_tb(exception.__traceback__))
+    message = f'{error_name}: "{exception}"\n{traceback_text}"'
+
+    for error_type, full_text in ERROR_MAPPER.items():
+        if error_type in message:
+            final_message = full_text
+            break
+    else:
+        final_message = (f'An Error Occurred: {error_name}\n\n{CONTACT_DEV}')
+
+    return final_message
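A minimal usage sketch for the helpers above: each function scans the raised exception's text for a known signature in ERROR_MAPPER and falls back to the generic CONTACT_DEV message. The import path follows this file's location in the repository, while run_separation() is a placeholder name for whatever processing call raised the exception:

    from gui_data.error_handling import error_text, error_dialouge

    try:
        run_separation()                      # placeholder for a UVR processing call
    except Exception as e:
        log_entry = error_text('MDX-Net', e)  # detailed text intended for the error log
        popup_text = error_dialouge(e)        # short, user-facing summary for a message box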
gui_data/fail_chime.wav
ADDED
Binary file (385 kB).
gui_data/fonts/Montserrat/Montserrat.ttf
ADDED
Binary file (263 kB).
gui_data/fonts/centurygothic/GOTHIC.ttf
ADDED
Binary file (138 kB).
gui_data/fonts/other/own_font_goes_here.txt
ADDED
@@ -0,0 +1 @@
+0
gui_data/img/File.png
ADDED
gui_data/img/GUI-Icon.ico
ADDED