lorneluo committed on
Commit
54c22e4
1 Parent(s): b3726c0
This view is limited to 50 files because it contains too many changes. See the raw diff for the full changeset.
Files changed (50)
  1. .gitignore +2 -0
  2. README2.md +256 -0
  3. UVR.py +0 -0
  4. __version__.py +4 -0
  5. demucs/__init__.py +5 -0
  6. demucs/__main__.py +272 -0
  7. demucs/__pycache__/__init__.cpython-310.pyc +0 -0
  8. demucs/__pycache__/apply.cpython-310.pyc +0 -0
  9. demucs/__pycache__/demucs.cpython-310.pyc +0 -0
  10. demucs/__pycache__/filtering.cpython-310.pyc +0 -0
  11. demucs/__pycache__/hdemucs.cpython-310.pyc +0 -0
  12. demucs/__pycache__/model.cpython-310.pyc +0 -0
  13. demucs/__pycache__/model_v2.cpython-310.pyc +0 -0
  14. demucs/__pycache__/pretrained.cpython-310.pyc +0 -0
  15. demucs/__pycache__/repo.cpython-310.pyc +0 -0
  16. demucs/__pycache__/spec.cpython-310.pyc +0 -0
  17. demucs/__pycache__/states.cpython-310.pyc +0 -0
  18. demucs/__pycache__/tasnet_v2.cpython-310.pyc +0 -0
  19. demucs/__pycache__/utils.cpython-310.pyc +0 -0
  20. demucs/apply.py +305 -0
  21. demucs/demucs.py +459 -0
  22. demucs/filtering.py +502 -0
  23. demucs/hdemucs.py +796 -0
  24. demucs/htdemucs.py +664 -0
  25. demucs/model.py +218 -0
  26. demucs/model_v2.py +218 -0
  27. demucs/pretrained.py +180 -0
  28. demucs/repo.py +148 -0
  29. demucs/spec.py +53 -0
  30. demucs/states.py +148 -0
  31. demucs/tasnet.py +447 -0
  32. demucs/tasnet_v2.py +452 -0
  33. demucs/transformer.py +839 -0
  34. demucs/utils.py +502 -0
  35. gui_data/__pycache__/app_size_values.cpython-310.pyc +0 -0
  36. gui_data/__pycache__/constants.cpython-310.pyc +0 -0
  37. gui_data/__pycache__/error_handling.cpython-310.pyc +0 -0
  38. gui_data/__pycache__/old_data_check.cpython-310.pyc +0 -0
  39. gui_data/app_size_values.py +371 -0
  40. gui_data/change_log.txt +93 -0
  41. gui_data/complete_chime.wav +0 -0
  42. gui_data/constants.py +1584 -0
  43. gui_data/cr_text.txt +104 -0
  44. gui_data/error_handling.py +110 -0
  45. gui_data/fail_chime.wav +0 -0
  46. gui_data/fonts/Montserrat/Montserrat.ttf +0 -0
  47. gui_data/fonts/centurygothic/GOTHIC.ttf +0 -0
  48. gui_data/fonts/other/own_font_goes_here.txt +1 -0
  49. gui_data/img/File.png +0 -0
  50. gui_data/img/GUI-Icon.ico +0 -0
.gitignore ADDED
@@ -0,0 +1,2 @@
1
+ venv/
2
+ .idea/
README2.md ADDED
@@ -0,0 +1,256 @@
1
+ # Ultimate Vocal Remover GUI v5.6
2
+ <img src="https://raw.githubusercontent.com/Anjok07/ultimatevocalremovergui/master/gui_data/img/UVR_v5.6.png?raw=true" />
3
+
4
+ [![Release](https://img.shields.io/github/release/anjok07/ultimatevocalremovergui.svg)](https://github.com/anjok07/ultimatevocalremovergui/releases/latest)
5
+ [![Downloads](https://img.shields.io/github/downloads/anjok07/ultimatevocalremovergui/total.svg)](https://github.com/anjok07/ultimatevocalremovergui/releases)
6
+
7
+ ## About
8
+
9
+ This application uses state-of-the-art source separation models to remove vocals from audio files. UVR's core developers trained all of the models provided in this package (except for the Demucs v3 and v4 4-stem models).
10
+
11
+ - **Core Developers**
12
+ - [Anjok07](https://github.com/anjok07)
13
+ - [aufr33](https://github.com/aufr33)
14
+
15
+ - **Support the Project**
16
+ - [Donate](https://www.buymeacoffee.com/uvr5)
17
+
18
+ ## Installation
19
+
20
+ These bundles contain the UVR interface, Python, PyTorch, and other dependencies needed to run the application effectively. No prerequisites are required.
21
+
22
+ ### Windows Installation
23
+
24
+ - Please Note:
25
+ - This installer is intended for those running Windows 10 or higher.
26
+ - Application functionality for systems running Windows 7 or lower is not guaranteed.
27
+ - Application functionality for Intel Pentium & Celeron CPU systems is not guaranteed.
28
+ - You must install UVR to the main C:\ drive. Installing UVR to a secondary drive will cause instability.
29
+
30
+ - Download the UVR installer for Windows via the link below:
31
+ - [Main Download Link](https://github.com/Anjok07/ultimatevocalremovergui/releases/download/v5.6/UVR_v5.6.0_setup.exe)
32
+ - [Main Download Link mirror](https://www.mediafire.com/file_premium/jiatpgp0ljou52p/UVR_v5.6.0_setup.exe/file)
33
+ - If you use an **AMD Radeon or Intel Arc graphics card**, you can try the OpenCL version:
34
+ - [OpenCL Version - Main Download Link](https://github.com/Anjok07/ultimatevocalremovergui/releases/download/v5.6/UVR_v5.6.0_setup_opencl.exe)
35
+ - Update Package instructions for those who have UVR already installed:
36
+ - If you already have UVR installed, you can install this package over it, download it straight from the application, or [click here for the patch](https://github.com/Anjok07/ultimatevocalremovergui/releases/download/v5.6/UVR_Patch_10_6_23_4_27.exe).
37
+
38
+ <details id="WindowsManual">
39
+ <summary>Windows Manual Installation</summary>
40
+
41
+ ### Manual Windows Installation
42
+
43
+ - Download and extract the repository [here](https://github.com/Anjok07/ultimatevocalremovergui/archive/refs/heads/master.zip)
44
+ - Download and install Python [here](https://www.python.org/ftp/python/3.9.8/python-3.9.8-amd64.exe)
45
+ - Make sure to check "Add python.exe to PATH" during the install
46
+ - Run the following commands from the extracted repo directory:
47
+
48
+ ```
49
+ python.exe -m pip install -r requirements.txt
50
+ ```
51
+
52
+ If you have a compatible Nvidia GPU, run the following command:
53
+
54
+ ```
55
+ python.exe -m pip install --upgrade torch --extra-index-url https://download.pytorch.org/whl/cu117
56
+ ```
57
+
58
+ If you do not have FFmpeg or Rubber Band installed and want to avoid installing them the long way, follow the instructions below.
59
+
60
+ **FFmpeg Installation**
61
+
62
+ - Download the precompiled build [here](https://www.gyan.dev/ffmpeg/builds/ffmpeg-release-essentials.zip)
63
+ - From the archive, extract the following file to the UVR application directory:
64
+ - ```ffmpeg-5.1.2-essentials_build/bin/ffmpeg.exe```
65
+
66
+ **Rubber Band Installation**
67
+
68
+ In order to use the Time Stretch or Change Pitch tool, you'll need Rubber Band.
69
+
70
+ - Download the precompiled build [here](https://breakfastquay.com/files/releases/rubberband-3.1.2-gpl-executable-windows.zip)
71
+ - From the archive, extract the following files to the UVR application directory:
72
+ - ```rubberband-3.1.2-gpl-executable-windows/rubberband.exe```
73
+ - ```rubberband-3.1.2-gpl-executable-windows/sndfile.dll```
74
+
75
+ </details>
76
+
77
+ ### MacOS Installation
78
+ - Please Note:
79
+ - The MacOS Sonoma mouse clicking issue has been fixed.
80
+ - MPS (GPU) acceleration for Mac M1 has been expanded to work with Demucs v4 and all MDX-Net models.
81
+ - This bundle is intended for those running macOS Big Sur and above.
82
+ - Application functionality for systems running macOS Catalina or lower is not guaranteed.
83
+ - Application functionality for older or budget Mac systems is not guaranteed.
84
+ - Once everything is installed, the application may take 5-10 minutes to start the first time (depending on your MacBook).
85
+
86
+ - Download the UVR dmg for MacOS via one of the links below:
87
+ - Mac M1 (arm64) users:
88
+ - [Main Download Link](https://github.com/Anjok07/ultimatevocalremovergui/releases/download/v5.6/Ultimate_Vocal_Remover_v5_6_MacOS_arm64.dmg)
89
+ - [Main Download Link mirror](https://www.mediafire.com/file_premium/u3rk54wsqadpy93/Ultimate_Vocal_Remover_v5_6_MacOS_arm64.dmg/file)
90
+
91
+ - Mac Intel (x86_64) users:
92
+ - [Main Download Link](https://github.com/Anjok07/ultimatevocalremovergui/releases/download/v5.6/Ultimate_Vocal_Remover_v5_6_MacOS_x86_64.dmg)
93
+ - [Main Download Link mirror](https://www.mediafire.com/file_premium/2gf1werx5ly5ylz/Ultimate_Vocal_Remover_v5_6_MacOS_x86_64.dmg/file)
94
+
95
+ <details id="CannotOpen">
96
+ <summary>MacOS Users: Having Trouble Opening UVR?</summary>
97
+
98
+ > Due to Apple's strict application security, you may need to follow these steps to open UVR.
99
+ >
100
+ > First, run the following command via Terminal.app to allow applications to run from all sources (it's recommended that you re-enable this once UVR opens properly).
101
+ >
102
+ > ```bash
103
+ > sudo spctl --master-disable
104
+ > ```
105
+ >
106
+ > Second, run the following command to bypass Notarization:
107
+ >
108
+ > ```bash
109
+ > sudo xattr -rd com.apple.quarantine /Applications/Ultimate\ Vocal\ Remover.app
110
+ > ```
111
+
112
+ </details>
113
+
114
+ <details id="MacInstall">
115
+ <summary>Manual MacOS Installation</summary>
116
+
117
+ ### Manual MacOS Installation
118
+
119
+ - Download and save this repository [here](https://github.com/Anjok07/ultimatevocalremovergui/archive/refs/heads/master.zip)
120
+ - Download and install Python 3.10 [here](https://www.python.org/ftp/python/3.10.9/python-3.10.9-macos11.pkg)
121
+ - From the saved directory, run the following:
122
+
123
+ ```
124
+ pip3 install -r requirements.txt
125
+ ```
126
+
127
+ - If your Mac has an M1 chip, run the following command next. If not, skip this step:
128
+
129
+ ```
130
+ cp /Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/_soundfile_data/libsndfile_arm64.dylib /Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/_soundfile_data/libsndfile.dylib
131
+ ```
132
+
133
+ **FFmpeg Installation**
134
+
135
+ - Once everything is done installing, download the correct FFmpeg binary for your system [here](http://www.osxexperts.net) and place it into the main application directory.
136
+
137
+ **Rubber Band Installation**
138
+
139
+ In order to use the Time Stretch or Change Pitch tool, you'll need Rubber Band.
140
+
141
+ - Download the precompiled build [here](https://breakfastquay.com/files/releases/rubberband-3.1.2-gpl-executable-macos.zip)
142
+ - From the archive, extract the following files to the UVR/lib_v5 application directory:
143
+ - ```rubberband-3.1.2-gpl-executable-macos/rubberband```
144
+
145
+ This process has been tested on a MacBook Pro 2021 (using M1) and a MacBook Air 2017 and is confirmed to be working on both.
146
+
147
+ </details>
148
+
149
+ ### Linux Installation
150
+
151
+ <details id="LinuxInstall">
152
+ <summary>See Linux Installation Instructions</summary>
153
+
154
+ <br />
155
+
156
+ **These install instructions are for Debian & Arch based Linux systems.**
157
+
158
+ - Download and save this repository [here](https://github.com/Anjok07/ultimatevocalremovergui/archive/refs/heads/master.zip)
159
+ - From the saved directory, run the following commands in this order:
160
+
161
+ **For Debian Based (Ubuntu, Mint, etc.):**
162
+ ```
163
+ sudo apt update && sudo apt upgrade
164
+ sudo apt-get update
165
+ sudo apt install ffmpeg
166
+ sudo apt install python3-pip
167
+ sudo apt-get -y install python3-tk
168
+ pip3 install -r requirements.txt
169
+ python3 UVR.py
170
+ ```
171
+
172
+ **For Arch Based (EndeavourOS):**
173
+ ```
174
+ sudo pacman -Syu
175
+ sudo pacman -Sy
176
+ sudo pacman -S python-pip
177
+ sudo pacman -S --noconfirm tk
178
+ sudo pacman -S ffmpeg
179
+ ```
180
+
181
+ To bypass Python's externally managed environment restriction and proceed with the installation, use:
182
+
183
+ - Take caution; this modifies system files.
184
+
185
+ ```
186
+ sudo rm /usr/lib/python3.11/EXTERNALLY-MANAGED
187
+ ```
188
+
189
+ Then proceed with the following in order:
190
+
191
+ ```
192
+ chmod +x install_packages.sh
193
+ ./install_packages.sh
194
+ python UVR.py
195
+ ```
196
+
197
+ </details>
198
+
199
+ ### Other Application Notes
200
+ - Nvidia GTX 1060 6GB is the minimum requirement for GPU conversions.
201
+ - Nvidia GPUs with at least 8 GB of VRAM are recommended.
202
+ - AMD Radeon GPU support is limited at this time.
203
+ - There is currently a working branch for AMD GPU users [here](https://github.com/Anjok07/ultimatevocalremovergui/tree/v5.6-amd-gpu)
204
+ - This application is only compatible with 64-bit platforms.
205
+ - This application relies on the Rubber Band library for the Time-Stretch and Pitch-Shift options.
206
+ - This application relies on FFmpeg to process non-wav audio files.
207
+ - The application will automatically remember your settings when closed.
208
+ - Conversion times will significantly depend on your hardware.
209
+ - These models are computationally intensive.
210
+
211
+ ### Performance
212
+ - Model load times are faster.
213
+ - Importing/exporting audio files is faster.
214
+
215
+ ## Troubleshooting
216
+
217
+ ### Common Issues
218
+
219
+ - If FFmpeg is not installed, the application will throw an error if the user attempts to convert a non-WAV file.
220
+ - Memory allocation errors can usually be resolved by lowering the "Segment" or "Window" sizes.
221
+
222
+ #### MacOS Sonoma Left-click Bug
223
+ There was a known issue on MacOS Sonoma where left-clicks weren't registering correctly within the app. This affected all applications built with Tkinter on Sonoma and has since been resolved. If you are still experiencing issues, please download the latest version via the following link - [link](https://github.com/Anjok07/ultimatevocalremovergui/releases/tag/v5.6)
224
+
225
+ This issue was being tracked [here](https://github.com/Anjok07/ultimatevocalremovergui/issues/840).
226
+
227
+ ### Issue Reporting
228
+
229
+ Please be as detailed as possible when posting a new issue.
230
+
231
+ If possible, click the "Settings Button" to the left of the "Start Processing" button and click the "Error Log" button for detailed error information that can be provided to us.
232
+
233
+ ## License
234
+
235
+ The **Ultimate Vocal Remover GUI** code is [MIT-licensed](LICENSE).
236
+
237
+ - **Please Note:** For all third-party application developers who wish to use our models, please honor the MIT license by providing credit to UVR and its developers.
238
+
239
+ ## Credits
240
+ - [ZFTurbo](https://github.com/ZFTurbo) - Created & trained the weights for the new MDX23C models.
241
+ - [DilanBoskan](https://github.com/DilanBoskan) - Your contributions at the start of this project were essential to the success of UVR. Thank you!
242
+ - [Bas Curtiz](https://www.youtube.com/user/bascurtiz) - Designed the official UVR logo, icon, banner, and splash screen.
243
+ - [tsurumeso](https://github.com/tsurumeso) - Developed the original VR Architecture code.
244
+ - [Kuielab & Woosung Choi](https://github.com/kuielab) - Developed the original MDX-Net AI code.
245
+ - [Adefossez & Demucs](https://github.com/facebookresearch/demucs) - Developed the original Demucs AI code.
246
+ - [KimberleyJSN](https://github.com/KimberleyJensen) - Advised and aided the implementation of the training scripts for MDX-Net and Demucs. Thank you!
247
+ - [Hv](https://github.com/NaJeongMo/Colab-for-MDX_B) - Helped implement chunks into the MDX-Net AI code. Thank you!
248
+
249
+ ## Contributing
250
+
251
+ - For anyone interested in the ongoing development of **Ultimate Vocal Remover GUI**, please send us a pull request, and we will review it.
252
+ - This project is 100% open-source and free for anyone to use and modify as they wish.
253
+ - We only maintain the development and support for the **Ultimate Vocal Remover GUI** and the models provided.
254
+
255
+ ## References
256
+ - [1] Takahashi et al., "Multi-scale Multi-band DenseNets for Audio Source Separation", https://arxiv.org/pdf/1706.09588.pdf
UVR.py ADDED
The diff for this file is too large to render. See raw diff
 
__version__.py ADDED
@@ -0,0 +1,4 @@
1
+ VERSION = 'v5.6.0'
2
+ PATCH = 'UVR_Patch_9_29_23_1_39'
3
+ PATCH_MAC = 'UVR_Patch_9_29_23_1_39'
4
+ PATCH_LINUX = 'UVR_Patch_9_29_23_1_39'
demucs/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
demucs/__main__.py ADDED
@@ -0,0 +1,272 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import json
8
+ import os
9
+ import sys
10
+ import time
11
+ from dataclasses import dataclass, field
12
+ from fractions import Fraction
13
+
14
+ import torch as th
15
+ from torch import distributed, nn
16
+ from torch.nn.parallel.distributed import DistributedDataParallel
17
+
18
+ from .augment import FlipChannels, FlipSign, Remix, Shift
19
+ from .compressed import StemsSet, build_musdb_metadata, get_musdb_tracks
20
+ from .model import Demucs
21
+ from .parser import get_name, get_parser
22
+ from .raw import Rawset
23
+ from .tasnet import ConvTasNet
24
+ from .test import evaluate
25
+ from .train import train_model, validate_model
26
+ from .utils import human_seconds, load_model, save_model, sizeof_fmt
27
+
28
+
29
+ @dataclass
30
+ class SavedState:
31
+ metrics: list = field(default_factory=list)
32
+ last_state: dict = None
33
+ best_state: dict = None
34
+ optimizer: dict = None
35
+
36
+
37
+ def main():
38
+ parser = get_parser()
39
+ args = parser.parse_args()
40
+ name = get_name(parser, args)
41
+ print(f"Experiment {name}")
42
+
43
+ if args.musdb is None and args.rank == 0:
44
+ print(
45
+ "You must provide the path to the MusDB dataset with the --musdb flag. "
46
+ "To download the MusDB dataset, see https://sigsep.github.io/datasets/musdb.html.",
47
+ file=sys.stderr)
48
+ sys.exit(1)
49
+
50
+ eval_folder = args.evals / name
51
+ eval_folder.mkdir(exist_ok=True, parents=True)
52
+ args.logs.mkdir(exist_ok=True)
53
+ metrics_path = args.logs / f"{name}.json"
54
+ eval_folder.mkdir(exist_ok=True, parents=True)
55
+ args.checkpoints.mkdir(exist_ok=True, parents=True)
56
+ args.models.mkdir(exist_ok=True, parents=True)
57
+
58
+ if args.device is None:
59
+ device = "cpu"
60
+ if th.cuda.is_available():
61
+ device = "cuda"
62
+ else:
63
+ device = args.device
64
+
65
+ th.manual_seed(args.seed)
66
+ # Prevents too many threads from being started when running `museval`, as it can be quite
67
+ # inefficient on NUMA architectures.
68
+ os.environ["OMP_NUM_THREADS"] = "1"
69
+
70
+ if args.world_size > 1:
71
+ if device != "cuda" and args.rank == 0:
72
+ print("Error: distributed training is only available with cuda device", file=sys.stderr)
73
+ sys.exit(1)
74
+ th.cuda.set_device(args.rank % th.cuda.device_count())
75
+ distributed.init_process_group(backend="nccl",
76
+ init_method="tcp://" + args.master,
77
+ rank=args.rank,
78
+ world_size=args.world_size)
79
+
80
+ checkpoint = args.checkpoints / f"{name}.th"
81
+ checkpoint_tmp = args.checkpoints / f"{name}.th.tmp"
82
+ if args.restart and checkpoint.exists():
83
+ checkpoint.unlink()
84
+
85
+ if args.test:
86
+ args.epochs = 1
87
+ args.repeat = 0
88
+ model = load_model(args.models / args.test)
89
+ elif args.tasnet:
90
+ model = ConvTasNet(audio_channels=args.audio_channels, samplerate=args.samplerate, X=args.X)
91
+ else:
92
+ model = Demucs(
93
+ audio_channels=args.audio_channels,
94
+ channels=args.channels,
95
+ context=args.context,
96
+ depth=args.depth,
97
+ glu=args.glu,
98
+ growth=args.growth,
99
+ kernel_size=args.kernel_size,
100
+ lstm_layers=args.lstm_layers,
101
+ rescale=args.rescale,
102
+ rewrite=args.rewrite,
103
+ sources=4,
104
+ stride=args.conv_stride,
105
+ upsample=args.upsample,
106
+ samplerate=args.samplerate
107
+ )
108
+ model.to(device)
109
+ if args.show:
110
+ print(model)
111
+ size = sizeof_fmt(4 * sum(p.numel() for p in model.parameters()))
112
+ print(f"Model size {size}")
113
+ return
114
+
115
+ optimizer = th.optim.Adam(model.parameters(), lr=args.lr)
116
+
117
+ try:
118
+ saved = th.load(checkpoint, map_location='cpu')
119
+ except IOError:
120
+ saved = SavedState()
121
+ else:
122
+ model.load_state_dict(saved.last_state)
123
+ optimizer.load_state_dict(saved.optimizer)
124
+
125
+ if args.save_model:
126
+ if args.rank == 0:
127
+ model.to("cpu")
128
+ model.load_state_dict(saved.best_state)
129
+ save_model(model, args.models / f"{name}.th")
130
+ return
131
+
132
+ if args.rank == 0:
133
+ done = args.logs / f"{name}.done"
134
+ if done.exists():
135
+ done.unlink()
136
+
137
+ if args.augment:
138
+ augment = nn.Sequential(FlipSign(), FlipChannels(), Shift(args.data_stride),
139
+ Remix(group_size=args.remix_group_size)).to(device)
140
+ else:
141
+ augment = Shift(args.data_stride)
142
+
143
+ if args.mse:
144
+ criterion = nn.MSELoss()
145
+ else:
146
+ criterion = nn.L1Loss()
147
+
148
+ # Setting number of samples so that all convolution windows are full.
149
+ # Prevents hard to debug mistake with the prediction being shifted compared
150
+ # to the input mixture.
151
+ samples = model.valid_length(args.samples)
152
+ print(f"Number of training samples adjusted to {samples}")
153
+
154
+ if args.raw:
155
+ train_set = Rawset(args.raw / "train",
156
+ samples=samples + args.data_stride,
157
+ channels=args.audio_channels,
158
+ streams=[0, 1, 2, 3, 4],
159
+ stride=args.data_stride)
160
+
161
+ valid_set = Rawset(args.raw / "valid", channels=args.audio_channels)
162
+ else:
163
+ if not args.metadata.is_file() and args.rank == 0:
164
+ build_musdb_metadata(args.metadata, args.musdb, args.workers)
165
+ if args.world_size > 1:
166
+ distributed.barrier()
167
+ metadata = json.load(open(args.metadata))
168
+ duration = Fraction(samples + args.data_stride, args.samplerate)
169
+ stride = Fraction(args.data_stride, args.samplerate)
170
+ train_set = StemsSet(get_musdb_tracks(args.musdb, subsets=["train"], split="train"),
171
+ metadata,
172
+ duration=duration,
173
+ stride=stride,
174
+ samplerate=args.samplerate,
175
+ channels=args.audio_channels)
176
+ valid_set = StemsSet(get_musdb_tracks(args.musdb, subsets=["train"], split="valid"),
177
+ metadata,
178
+ samplerate=args.samplerate,
179
+ channels=args.audio_channels)
180
+
181
+ best_loss = float("inf")
182
+ for epoch, metrics in enumerate(saved.metrics):
183
+ print(f"Epoch {epoch:03d}: "
184
+ f"train={metrics['train']:.8f} "
185
+ f"valid={metrics['valid']:.8f} "
186
+ f"best={metrics['best']:.4f} "
187
+ f"duration={human_seconds(metrics['duration'])}")
188
+ best_loss = metrics['best']
189
+
190
+ if args.world_size > 1:
191
+ dmodel = DistributedDataParallel(model,
192
+ device_ids=[th.cuda.current_device()],
193
+ output_device=th.cuda.current_device())
194
+ else:
195
+ dmodel = model
196
+
197
+ for epoch in range(len(saved.metrics), args.epochs):
198
+ begin = time.time()
199
+ model.train()
200
+ train_loss = train_model(epoch,
201
+ train_set,
202
+ dmodel,
203
+ criterion,
204
+ optimizer,
205
+ augment,
206
+ batch_size=args.batch_size,
207
+ device=device,
208
+ repeat=args.repeat,
209
+ seed=args.seed,
210
+ workers=args.workers,
211
+ world_size=args.world_size)
212
+ model.eval()
213
+ valid_loss = validate_model(epoch,
214
+ valid_set,
215
+ model,
216
+ criterion,
217
+ device=device,
218
+ rank=args.rank,
219
+ split=args.split_valid,
220
+ world_size=args.world_size)
221
+
222
+ duration = time.time() - begin
223
+ if valid_loss < best_loss:
224
+ best_loss = valid_loss
225
+ saved.best_state = {
226
+ key: value.to("cpu").clone()
227
+ for key, value in model.state_dict().items()
228
+ }
229
+ saved.metrics.append({
230
+ "train": train_loss,
231
+ "valid": valid_loss,
232
+ "best": best_loss,
233
+ "duration": duration
234
+ })
235
+ if args.rank == 0:
236
+ json.dump(saved.metrics, open(metrics_path, "w"))
237
+
238
+ saved.last_state = model.state_dict()
239
+ saved.optimizer = optimizer.state_dict()
240
+ if args.rank == 0 and not args.test:
241
+ th.save(saved, checkpoint_tmp)
242
+ checkpoint_tmp.rename(checkpoint)
243
+
244
+ print(f"Epoch {epoch:03d}: "
245
+ f"train={train_loss:.8f} valid={valid_loss:.8f} best={best_loss:.4f} "
246
+ f"duration={human_seconds(duration)}")
247
+
248
+ del dmodel
249
+ model.load_state_dict(saved.best_state)
250
+ if args.eval_cpu:
251
+ device = "cpu"
252
+ model.to(device)
253
+ model.eval()
254
+ evaluate(model,
255
+ args.musdb,
256
+ eval_folder,
257
+ rank=args.rank,
258
+ world_size=args.world_size,
259
+ device=device,
260
+ save=args.save,
261
+ split=args.split_valid,
262
+ shifts=args.shifts,
263
+ workers=args.eval_workers)
264
+ model.to("cpu")
265
+ save_model(model, args.models / f"{name}.th")
266
+ if args.rank == 0:
267
+ print("done")
268
+ done.write_text("done")
269
+
270
+
271
+ if __name__ == "__main__":
272
+ main()
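The `main()` loop above resumes training from a `SavedState` checkpoint: metrics for already-completed epochs are replayed first, and training continues from `len(saved.metrics)`. A minimal sketch of that resume pattern, with purely hypothetical metric values:

```python
# Illustrative only: mirrors how main() above replays saved.metrics and then
# resumes training at the next epoch index. The history below is made up.
metrics = [{"train": 0.120, "valid": 0.105, "best": 0.105, "duration": 3600.0}]
total_epochs = 3

for epoch, m in enumerate(metrics):                 # replay completed epochs
    print(f"Epoch {epoch:03d}: valid={m['valid']:.4f} (restored)")

for epoch in range(len(metrics), total_epochs):     # continue where training stopped
    print(f"Epoch {epoch:03d}: training...")
```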
demucs/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (166 Bytes). View file
 
demucs/__pycache__/apply.cpython-310.pyc ADDED
Binary file (8.15 kB). View file
 
demucs/__pycache__/demucs.cpython-310.pyc ADDED
Binary file (14.1 kB). View file
 
demucs/__pycache__/filtering.cpython-310.pyc ADDED
Binary file (17 kB). View file
 
demucs/__pycache__/hdemucs.cpython-310.pyc ADDED
Binary file (21.3 kB). View file
 
demucs/__pycache__/model.cpython-310.pyc ADDED
Binary file (6.25 kB). View file
 
demucs/__pycache__/model_v2.cpython-310.pyc ADDED
Binary file (6.33 kB). View file
 
demucs/__pycache__/pretrained.cpython-310.pyc ADDED
Binary file (5.07 kB). View file
 
demucs/__pycache__/repo.cpython-310.pyc ADDED
Binary file (5.97 kB). View file
 
demucs/__pycache__/spec.cpython-310.pyc ADDED
Binary file (1.26 kB). View file
 
demucs/__pycache__/states.cpython-310.pyc ADDED
Binary file (4.47 kB). View file
 
demucs/__pycache__/tasnet_v2.cpython-310.pyc ADDED
Binary file (12 kB). View file
 
demucs/__pycache__/utils.cpython-310.pyc ADDED
Binary file (13.9 kB). View file
 
demucs/apply.py ADDED
@@ -0,0 +1,305 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """
7
+ Code to apply a model to a mix. It will handle chunking with overlaps and
8
+ interpolation between chunks, as well as the "shift trick".
9
+ """
10
+ from concurrent.futures import ThreadPoolExecutor
11
+ import random
12
+ import typing as tp
13
+ from multiprocessing import Process,Queue,Pipe
14
+
15
+ import torch as th
16
+ from torch import nn
17
+ from torch.nn import functional as F
18
+ import tqdm
19
+ import tkinter as tk
20
+
21
+ from .demucs import Demucs
22
+ from .hdemucs import HDemucs
23
+ from .utils import center_trim, DummyPoolExecutor
24
+
25
+ Model = tp.Union[Demucs, HDemucs]
26
+
27
+ progress_bar_num = 0
28
+
29
+ class BagOfModels(nn.Module):
30
+ def __init__(self, models: tp.List[Model],
31
+ weights: tp.Optional[tp.List[tp.List[float]]] = None,
32
+ segment: tp.Optional[float] = None):
33
+ """
34
+ Represents a bag of models with specific weights.
35
+ You should call `apply_model` rather than calling the forward directly here for
36
+ optimal performance.
37
+
38
+ Args:
39
+ models (list[nn.Module]): list of Demucs/HDemucs models.
40
+ weights (list[list[float]]): list of weights. If None, assumed to
41
+ be all ones, otherwise it should be a list of N lists (N = number of models),
42
+ each containing S floats (S = number of sources).
43
+ segment (None or float): overrides the `segment` attribute of each model
44
+ (this is performed inplace, be careful if you reuse the models passed).
45
+ """
46
+
47
+ super().__init__()
48
+ assert len(models) > 0
49
+ first = models[0]
50
+ for other in models:
51
+ assert other.sources == first.sources
52
+ assert other.samplerate == first.samplerate
53
+ assert other.audio_channels == first.audio_channels
54
+ if segment is not None:
55
+ other.segment = segment
56
+
57
+ self.audio_channels = first.audio_channels
58
+ self.samplerate = first.samplerate
59
+ self.sources = first.sources
60
+ self.models = nn.ModuleList(models)
61
+
62
+ if weights is None:
63
+ weights = [[1. for _ in first.sources] for _ in models]
64
+ else:
65
+ assert len(weights) == len(models)
66
+ for weight in weights:
67
+ assert len(weight) == len(first.sources)
68
+ self.weights = weights
69
+
70
+ def forward(self, x):
71
+ raise NotImplementedError("Call `apply_model` on this.")
72
+
73
+ class TensorChunk:
74
+ def __init__(self, tensor, offset=0, length=None):
75
+ total_length = tensor.shape[-1]
76
+ assert offset >= 0
77
+ assert offset < total_length
78
+
79
+ if length is None:
80
+ length = total_length - offset
81
+ else:
82
+ length = min(total_length - offset, length)
83
+
84
+ if isinstance(tensor, TensorChunk):
85
+ self.tensor = tensor.tensor
86
+ self.offset = offset + tensor.offset
87
+ else:
88
+ self.tensor = tensor
89
+ self.offset = offset
90
+ self.length = length
91
+ self.device = tensor.device
92
+
93
+ @property
94
+ def shape(self):
95
+ shape = list(self.tensor.shape)
96
+ shape[-1] = self.length
97
+ return shape
98
+
99
+ def padded(self, target_length):
100
+ delta = target_length - self.length
101
+ total_length = self.tensor.shape[-1]
102
+ assert delta >= 0
103
+
104
+ start = self.offset - delta // 2
105
+ end = start + target_length
106
+
107
+ correct_start = max(0, start)
108
+ correct_end = min(total_length, end)
109
+
110
+ pad_left = correct_start - start
111
+ pad_right = end - correct_end
112
+
113
+ out = F.pad(self.tensor[..., correct_start:correct_end], (pad_left, pad_right))
114
+ assert out.shape[-1] == target_length
115
+ return out
116
+
117
+ def tensor_chunk(tensor_or_chunk):
118
+ if isinstance(tensor_or_chunk, TensorChunk):
119
+ return tensor_or_chunk
120
+ else:
121
+ assert isinstance(tensor_or_chunk, th.Tensor)
122
+ return TensorChunk(tensor_or_chunk)
123
+
124
+ def apply_model(model,
125
+ mix,
126
+ shifts=1,
127
+ split=True,
128
+ overlap=0.25,
129
+ transition_power=1.,
130
+ static_shifts=1,
131
+ set_progress_bar=None,
132
+ device=None,
133
+ progress=False,
134
+ num_workers=0,
135
+ pool=None):
136
+ """
137
+ Apply model to a given mixture.
138
+
139
+ Args:
140
+ shifts (int): if > 0, will shift in time `mix` by a random amount between 0 and 0.5 sec
141
+ and apply the opposite shift to the output. This is repeated `shifts` times and
142
+ all predictions are averaged. This effectively makes the model time equivariant
143
+ and improves SDR by up to 0.2 points.
144
+ split (bool): if True, the input will be broken down into 8-second extracts
145
+ and predictions will be performed individually on each and concatenated.
146
+ Useful for models with a large memory footprint like Tasnet.
147
+ progress (bool): if True, show a progress bar (requires split=True)
148
+ device (torch.device, str, or None): if provided, device on which to
149
+ execute the computation, otherwise `mix.device` is assumed.
150
+ When `device` is different from `mix.device`, only local computations will
151
+ be on `device`, while the entire tracks will be stored on `mix.device`.
152
+ """
153
+
154
+ global fut_length
155
+ global bag_num
156
+ global prog_bar
157
+
158
+ if device is None:
159
+ device = mix.device
160
+ else:
161
+ device = th.device(device)
162
+ if pool is None:
163
+ if num_workers > 0 and device.type == 'cpu':
164
+ pool = ThreadPoolExecutor(num_workers)
165
+ else:
166
+ pool = DummyPoolExecutor()
167
+
168
+ kwargs = {
169
+ 'shifts': shifts,
170
+ 'split': split,
171
+ 'overlap': overlap,
172
+ 'transition_power': transition_power,
173
+ 'progress': progress,
174
+ 'device': device,
175
+ 'pool': pool,
176
+ 'set_progress_bar': set_progress_bar,
177
+ 'static_shifts': static_shifts,
178
+ }
179
+
180
+ if isinstance(model, BagOfModels):
181
+ # Special treatment for bag of models.
182
+ # We explicitly apply `apply_model` multiple times so that the random shifts
183
+ # are different for each model.
184
+
185
+ estimates = 0
186
+ totals = [0] * len(model.sources)
187
+ bag_num = len(model.models)
188
+ fut_length = 0
189
+ prog_bar = 0
190
+ current_model = 0 #(bag_num + 1)
191
+ for sub_model, weight in zip(model.models, model.weights):
192
+ original_model_device = next(iter(sub_model.parameters())).device
193
+ sub_model.to(device)
194
+ fut_length += fut_length
195
+ current_model += 1
196
+ out = apply_model(sub_model, mix, **kwargs)
197
+ sub_model.to(original_model_device)
198
+ for k, inst_weight in enumerate(weight):
199
+ out[:, k, :, :] *= inst_weight
200
+ totals[k] += inst_weight
201
+ estimates += out
202
+ del out
203
+
204
+ for k in range(estimates.shape[1]):
205
+ estimates[:, k, :, :] /= totals[k]
206
+ return estimates
207
+
208
+ model.to(device)
209
+ model.eval()
210
+ assert transition_power >= 1, "transition_power < 1 leads to weird behavior."
211
+ batch, channels, length = mix.shape
212
+
213
+ if shifts:
214
+ kwargs['shifts'] = 0
215
+ max_shift = int(0.5 * model.samplerate)
216
+ mix = tensor_chunk(mix)
217
+ padded_mix = mix.padded(length + 2 * max_shift)
218
+ out = 0
219
+ for _ in range(shifts):
220
+ offset = random.randint(0, max_shift)
221
+ shifted = TensorChunk(padded_mix, offset, length + max_shift - offset)
222
+ shifted_out = apply_model(model, shifted, **kwargs)
223
+ out += shifted_out[..., max_shift - offset:]
224
+ out /= shifts
225
+ return out
226
+ elif split:
227
+ kwargs['split'] = False
228
+ out = th.zeros(batch, len(model.sources), channels, length, device=mix.device)
229
+ sum_weight = th.zeros(length, device=mix.device)
230
+ segment = int(model.samplerate * model.segment)
231
+ stride = int((1 - overlap) * segment)
232
+ offsets = range(0, length, stride)
233
+ scale = float(format(stride / model.samplerate, ".2f"))
234
+ # We start from a triangle shaped weight, with maximal weight in the middle
235
+ # of the segment. Then we normalize and take to the power `transition_power`.
236
+ # Large values of transition power will lead to sharper transitions.
237
+ weight = th.cat([th.arange(1, segment // 2 + 1, device=device),
238
+ th.arange(segment - segment // 2, 0, -1, device=device)])
239
+ assert len(weight) == segment
240
+ # If the overlap < 50%, this will translate to linear transition when
241
+ # transition_power is 1.
242
+ weight = (weight / weight.max())**transition_power
243
+ futures = []
244
+ for offset in offsets:
245
+ chunk = TensorChunk(mix, offset, segment)
246
+ future = pool.submit(apply_model, model, chunk, **kwargs)
247
+ futures.append((future, offset))
248
+ offset += segment
249
+ if progress:
250
+ futures = tqdm.tqdm(futures, unit_scale=scale, ncols=120, unit='seconds')
251
+ for future, offset in futures:
252
+ if set_progress_bar:
253
+ fut_length = (len(futures) * bag_num * static_shifts)
254
+ prog_bar += 1
255
+ set_progress_bar(0.1, (0.8/fut_length*prog_bar))
256
+ chunk_out = future.result()
257
+ chunk_length = chunk_out.shape[-1]
258
+ out[..., offset:offset + segment] += (weight[:chunk_length] * chunk_out).to(mix.device)
259
+ sum_weight[offset:offset + segment] += weight[:chunk_length].to(mix.device)
260
+ assert sum_weight.min() > 0
261
+ out /= sum_weight
262
+ return out
263
+ else:
264
+ if hasattr(model, 'valid_length'):
265
+ valid_length = model.valid_length(length)
266
+ else:
267
+ valid_length = length
268
+ mix = tensor_chunk(mix)
269
+ padded_mix = mix.padded(valid_length).to(device)
270
+ with th.no_grad():
271
+ out = model(padded_mix)
272
+ return center_trim(out, length)
273
+
274
+ def demucs_segments(demucs_segment, demucs_model):
275
+
276
+ if demucs_segment == 'Default':
277
+ segment = None
278
+ if isinstance(demucs_model, BagOfModels):
279
+ if segment is not None:
280
+ for sub in demucs_model.models:
281
+ sub.segment = segment
282
+ else:
283
+ if segment is not None:
284
+ sub.segment = segment
285
+ else:
286
+ try:
287
+ segment = int(demucs_segment)
288
+ if isinstance(demucs_model, BagOfModels):
289
+ if segment is not None:
290
+ for sub in demucs_model.models:
291
+ sub.segment = segment
292
+ else:
293
+ if segment is not None:
294
+ sub.segment = segment
295
+ except:
296
+ segment = None
297
+ if isinstance(demucs_model, BagOfModels):
298
+ if segment is not None:
299
+ for sub in demucs_model.models:
300
+ sub.segment = segment
301
+ else:
302
+ if segment is not None:
303
+ sub.segment = segment
304
+
305
+ return demucs_model
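For reference, when `split=True` the `apply_model` function above blends overlapping chunk predictions with a triangular weight and then normalizes by the accumulated weight. A minimal, self-contained sketch of that overlap-add interpolation, using an identity "model" and illustrative sizes:

```python
import torch

# Illustrative sizes only; apply_model derives segment and stride from the model.
length, segment, overlap, transition_power = 20, 8, 0.25, 1.0
stride = int((1 - overlap) * segment)

# Triangular weight peaking in the middle of the segment, as in apply_model.
weight = torch.cat([torch.arange(1, segment // 2 + 1),
                    torch.arange(segment - segment // 2, 0, -1)]).float()
weight = (weight / weight.max()) ** transition_power

signal = torch.randn(length)        # stand-in for one source estimate
out = torch.zeros(length)
sum_weight = torch.zeros(length)

for offset in range(0, length, stride):
    chunk = signal[offset:offset + segment]         # "model output" for this chunk
    n = chunk.shape[-1]
    out[offset:offset + n] += weight[:n] * chunk
    sum_weight[offset:offset + n] += weight[:n]

out /= sum_weight                   # every sample is covered, so sum_weight > 0
assert torch.allclose(out, signal)  # an identity "model" recovers the input exactly
```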
demucs/demucs.py ADDED
@@ -0,0 +1,459 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+ import typing as tp
9
+
10
+ import julius
11
+ import torch
12
+ from torch import nn
13
+ from torch.nn import functional as F
14
+
15
+ from .states import capture_init
16
+ from .utils import center_trim, unfold
17
+
18
+
19
+ class BLSTM(nn.Module):
20
+ """
21
+ BiLSTM with same hidden units as input dim.
22
+ If `max_steps` is not None, the input will be split into overlapping
23
+ chunks and the LSTM applied separately on each chunk.
24
+ """
25
+ def __init__(self, dim, layers=1, max_steps=None, skip=False):
26
+ super().__init__()
27
+ assert max_steps is None or max_steps % 4 == 0
28
+ self.max_steps = max_steps
29
+ self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim)
30
+ self.linear = nn.Linear(2 * dim, dim)
31
+ self.skip = skip
32
+
33
+ def forward(self, x):
34
+ B, C, T = x.shape
35
+ y = x
36
+ framed = False
37
+ if self.max_steps is not None and T > self.max_steps:
38
+ width = self.max_steps
39
+ stride = width // 2
40
+ frames = unfold(x, width, stride)
41
+ nframes = frames.shape[2]
42
+ framed = True
43
+ x = frames.permute(0, 2, 1, 3).reshape(-1, C, width)
44
+
45
+ x = x.permute(2, 0, 1)
46
+
47
+ x = self.lstm(x)[0]
48
+ x = self.linear(x)
49
+ x = x.permute(1, 2, 0)
50
+ if framed:
51
+ out = []
52
+ frames = x.reshape(B, -1, C, width)
53
+ limit = stride // 2
54
+ for k in range(nframes):
55
+ if k == 0:
56
+ out.append(frames[:, k, :, :-limit])
57
+ elif k == nframes - 1:
58
+ out.append(frames[:, k, :, limit:])
59
+ else:
60
+ out.append(frames[:, k, :, limit:-limit])
61
+ out = torch.cat(out, -1)
62
+ out = out[..., :T]
63
+ x = out
64
+ if self.skip:
65
+ x = x + y
66
+ return x
67
+
68
+
69
+ def rescale_conv(conv, reference):
70
+ """Rescale initial weight scale. It is unclear why it helps but it certainly does.
71
+ """
72
+ std = conv.weight.std().detach()
73
+ scale = (std / reference)**0.5
74
+ conv.weight.data /= scale
75
+ if conv.bias is not None:
76
+ conv.bias.data /= scale
77
+
78
+
79
+ def rescale_module(module, reference):
80
+ for sub in module.modules():
81
+ if isinstance(sub, (nn.Conv1d, nn.ConvTranspose1d, nn.Conv2d, nn.ConvTranspose2d)):
82
+ rescale_conv(sub, reference)
83
+
84
+
85
+ class LayerScale(nn.Module):
86
+ """Layer scale from [Touvron et al 2021] (https://arxiv.org/pdf/2103.17239.pdf).
87
+ This diagonally rescales residual outputs close to 0 initially; the scale is then learnt.
88
+ """
89
+ def __init__(self, channels: int, init: float = 0):
90
+ super().__init__()
91
+ self.scale = nn.Parameter(torch.zeros(channels, requires_grad=True))
92
+ self.scale.data[:] = init
93
+
94
+ def forward(self, x):
95
+ return self.scale[:, None] * x
96
+
97
+
98
+ class DConv(nn.Module):
99
+ """
100
+ New residual branches in each encoder layer.
101
+ This alternates dilated convolutions, potentially with LSTMs and attention.
102
+ Also, before entering each residual branch, the dimension is projected onto a smaller subspace,
103
+ e.g. of dim `channels // compress`.
104
+ """
105
+ def __init__(self, channels: int, compress: float = 4, depth: int = 2, init: float = 1e-4,
106
+ norm=True, attn=False, heads=4, ndecay=4, lstm=False, gelu=True,
107
+ kernel=3, dilate=True):
108
+ """
109
+ Args:
110
+ channels: input/output channels for residual branch.
111
+ compress: amount of channel compression inside the branch.
112
+ depth: number of layers in the residual branch. Each layer has its own
113
+ projection, and potentially LSTM and attention.
114
+ init: initial scale for LayerNorm.
115
+ norm: use GroupNorm.
116
+ attn: use LocalAttention.
117
+ heads: number of heads for the LocalAttention.
118
+ ndecay: number of decay controls in the LocalAttention.
119
+ lstm: use LSTM.
120
+ gelu: Use GELU activation.
121
+ kernel: kernel size for the (dilated) convolutions.
122
+ dilate: if true, use dilation, increasing with the depth.
123
+ """
124
+
125
+ super().__init__()
126
+ assert kernel % 2 == 1
127
+ self.channels = channels
128
+ self.compress = compress
129
+ self.depth = abs(depth)
130
+ dilate = depth > 0
131
+
132
+ norm_fn: tp.Callable[[int], nn.Module]
133
+ norm_fn = lambda d: nn.Identity() # noqa
134
+ if norm:
135
+ norm_fn = lambda d: nn.GroupNorm(1, d) # noqa
136
+
137
+ hidden = int(channels / compress)
138
+
139
+ act: tp.Type[nn.Module]
140
+ if gelu:
141
+ act = nn.GELU
142
+ else:
143
+ act = nn.ReLU
144
+
145
+ self.layers = nn.ModuleList([])
146
+ for d in range(self.depth):
147
+ dilation = 2 ** d if dilate else 1
148
+ padding = dilation * (kernel // 2)
149
+ mods = [
150
+ nn.Conv1d(channels, hidden, kernel, dilation=dilation, padding=padding),
151
+ norm_fn(hidden), act(),
152
+ nn.Conv1d(hidden, 2 * channels, 1),
153
+ norm_fn(2 * channels), nn.GLU(1),
154
+ LayerScale(channels, init),
155
+ ]
156
+ if attn:
157
+ mods.insert(3, LocalState(hidden, heads=heads, ndecay=ndecay))
158
+ if lstm:
159
+ mods.insert(3, BLSTM(hidden, layers=2, max_steps=200, skip=True))
160
+ layer = nn.Sequential(*mods)
161
+ self.layers.append(layer)
162
+
163
+ def forward(self, x):
164
+ for layer in self.layers:
165
+ x = x + layer(x)
166
+ return x
167
+
168
+
169
+ class LocalState(nn.Module):
170
+ """Local state allows to have attention based only on data (no positional embedding),
171
+ but while setting a constraint on the time window (e.g. decaying penalty term).
172
+
173
+ Also a failed experiment at providing some frequency-based attention.
174
+ """
175
+ def __init__(self, channels: int, heads: int = 4, nfreqs: int = 0, ndecay: int = 4):
176
+ super().__init__()
177
+ assert channels % heads == 0, (channels, heads)
178
+ self.heads = heads
179
+ self.nfreqs = nfreqs
180
+ self.ndecay = ndecay
181
+ self.content = nn.Conv1d(channels, channels, 1)
182
+ self.query = nn.Conv1d(channels, channels, 1)
183
+ self.key = nn.Conv1d(channels, channels, 1)
184
+ if nfreqs:
185
+ self.query_freqs = nn.Conv1d(channels, heads * nfreqs, 1)
186
+ if ndecay:
187
+ self.query_decay = nn.Conv1d(channels, heads * ndecay, 1)
188
+ # Initialize decay close to zero (there is a sigmoid), for maximum initial window.
189
+ self.query_decay.weight.data *= 0.01
190
+ assert self.query_decay.bias is not None # stupid type checker
191
+ self.query_decay.bias.data[:] = -2
192
+ self.proj = nn.Conv1d(channels + heads * nfreqs, channels, 1)
193
+
194
+ def forward(self, x):
195
+ B, C, T = x.shape
196
+ heads = self.heads
197
+ indexes = torch.arange(T, device=x.device, dtype=x.dtype)
198
+ # left index are keys, right index are queries
199
+ delta = indexes[:, None] - indexes[None, :]
200
+
201
+ queries = self.query(x).view(B, heads, -1, T)
202
+ keys = self.key(x).view(B, heads, -1, T)
203
+ # t are keys, s are queries
204
+ dots = torch.einsum("bhct,bhcs->bhts", keys, queries)
205
+ dots /= keys.shape[2]**0.5
206
+ if self.nfreqs:
207
+ periods = torch.arange(1, self.nfreqs + 1, device=x.device, dtype=x.dtype)
208
+ freq_kernel = torch.cos(2 * math.pi * delta / periods.view(-1, 1, 1))
209
+ freq_q = self.query_freqs(x).view(B, heads, -1, T) / self.nfreqs ** 0.5
210
+ dots += torch.einsum("fts,bhfs->bhts", freq_kernel, freq_q)
211
+ if self.ndecay:
212
+ decays = torch.arange(1, self.ndecay + 1, device=x.device, dtype=x.dtype)
213
+ decay_q = self.query_decay(x).view(B, heads, -1, T)
214
+ decay_q = torch.sigmoid(decay_q) / 2
215
+ decay_kernel = - decays.view(-1, 1, 1) * delta.abs() / self.ndecay**0.5
216
+ dots += torch.einsum("fts,bhfs->bhts", decay_kernel, decay_q)
217
+
218
+ # Kill self reference.
219
+ dots.masked_fill_(torch.eye(T, device=dots.device, dtype=torch.bool), -100)
220
+ weights = torch.softmax(dots, dim=2)
221
+
222
+ content = self.content(x).view(B, heads, -1, T)
223
+ result = torch.einsum("bhts,bhct->bhcs", weights, content)
224
+ if self.nfreqs:
225
+ time_sig = torch.einsum("bhts,fts->bhfs", weights, freq_kernel)
226
+ result = torch.cat([result, time_sig], 2)
227
+ result = result.reshape(B, -1, T)
228
+ return x + self.proj(result)
229
+
230
+
231
+ class Demucs(nn.Module):
232
+ @capture_init
233
+ def __init__(self,
234
+ sources,
235
+ # Channels
236
+ audio_channels=2,
237
+ channels=64,
238
+ growth=2.,
239
+ # Main structure
240
+ depth=6,
241
+ rewrite=True,
242
+ lstm_layers=0,
243
+ # Convolutions
244
+ kernel_size=8,
245
+ stride=4,
246
+ context=1,
247
+ # Activations
248
+ gelu=True,
249
+ glu=True,
250
+ # Normalization
251
+ norm_starts=4,
252
+ norm_groups=4,
253
+ # DConv residual branch
254
+ dconv_mode=1,
255
+ dconv_depth=2,
256
+ dconv_comp=4,
257
+ dconv_attn=4,
258
+ dconv_lstm=4,
259
+ dconv_init=1e-4,
260
+ # Pre/post processing
261
+ normalize=True,
262
+ resample=True,
263
+ # Weight init
264
+ rescale=0.1,
265
+ # Metadata
266
+ samplerate=44100,
267
+ segment=4 * 10):
268
+ """
269
+ Args:
270
+ sources (list[str]): list of source names
271
+ audio_channels (int): stereo or mono
272
+ channels (int): first convolution channels
273
+ depth (int): number of encoder/decoder layers
274
+ growth (float): multiply (resp divide) number of channels by that
275
+ for each layer of the encoder (resp decoder)
276
+ depth (int): number of layers in the encoder and in the decoder.
277
+ rewrite (bool): add 1x1 convolution to each layer.
278
+ lstm_layers (int): number of lstm layers, 0 = no lstm. Deactivated
279
+ by default, as this is now replaced by the smaller and faster small LSTMs
280
+ in the DConv branches.
281
+ kernel_size (int): kernel size for convolutions
282
+ stride (int): stride for convolutions
283
+ context (int): kernel size of the convolution in the
284
+ decoder before the transposed convolution. If > 1,
285
+ will provide some context from neighboring time steps.
286
+ gelu: use GELU activation function.
287
+ glu (bool): use glu instead of ReLU for the 1x1 rewrite conv.
288
+ norm_starts: layer at which group norm starts being used.
289
+ decoder layers are numbered in reverse order.
290
+ norm_groups: number of groups for group norm.
291
+ dconv_mode: if 1: dconv in encoder only, 2: decoder only, 3: both.
292
+ dconv_depth: depth of residual DConv branch.
293
+ dconv_comp: compression of DConv branch.
294
+ dconv_attn: adds attention layers in DConv branch starting at this layer.
295
+ dconv_lstm: adds a LSTM layer in DConv branch starting at this layer.
296
+ dconv_init: initial scale for the DConv branch LayerScale.
297
+ normalize (bool): normalizes the input audio on the fly, and scales back
298
+ the output by the same amount.
299
+ resample (bool): upsample x2 the input and downsample /2 the output.
300
+ rescale (int): rescale initial weights of convolutions
301
+ to get their standard deviation closer to `rescale`.
302
+ samplerate (int): stored as meta information for easing
303
+ future evaluations of the model.
304
+ segment (float): duration of the chunks of audio to ideally evaluate the model on.
305
+ This is used by `demucs.apply.apply_model`.
306
+ """
307
+
308
+ super().__init__()
309
+ self.audio_channels = audio_channels
310
+ self.sources = sources
311
+ self.kernel_size = kernel_size
312
+ self.context = context
313
+ self.stride = stride
314
+ self.depth = depth
315
+ self.resample = resample
316
+ self.channels = channels
317
+ self.normalize = normalize
318
+ self.samplerate = samplerate
319
+ self.segment = segment
320
+ self.encoder = nn.ModuleList()
321
+ self.decoder = nn.ModuleList()
322
+ self.skip_scales = nn.ModuleList()
323
+
324
+ if glu:
325
+ activation = nn.GLU(dim=1)
326
+ ch_scale = 2
327
+ else:
328
+ activation = nn.ReLU()
329
+ ch_scale = 1
330
+ if gelu:
331
+ act2 = nn.GELU
332
+ else:
333
+ act2 = nn.ReLU
334
+
335
+ in_channels = audio_channels
336
+ padding = 0
337
+ for index in range(depth):
338
+ norm_fn = lambda d: nn.Identity() # noqa
339
+ if index >= norm_starts:
340
+ norm_fn = lambda d: nn.GroupNorm(norm_groups, d) # noqa
341
+
342
+ encode = []
343
+ encode += [
344
+ nn.Conv1d(in_channels, channels, kernel_size, stride),
345
+ norm_fn(channels),
346
+ act2(),
347
+ ]
348
+ attn = index >= dconv_attn
349
+ lstm = index >= dconv_lstm
350
+ if dconv_mode & 1:
351
+ encode += [DConv(channels, depth=dconv_depth, init=dconv_init,
352
+ compress=dconv_comp, attn=attn, lstm=lstm)]
353
+ if rewrite:
354
+ encode += [
355
+ nn.Conv1d(channels, ch_scale * channels, 1),
356
+ norm_fn(ch_scale * channels), activation]
357
+ self.encoder.append(nn.Sequential(*encode))
358
+
359
+ decode = []
360
+ if index > 0:
361
+ out_channels = in_channels
362
+ else:
363
+ out_channels = len(self.sources) * audio_channels
364
+ if rewrite:
365
+ decode += [
366
+ nn.Conv1d(channels, ch_scale * channels, 2 * context + 1, padding=context),
367
+ norm_fn(ch_scale * channels), activation]
368
+ if dconv_mode & 2:
369
+ decode += [DConv(channels, depth=dconv_depth, init=dconv_init,
370
+ compress=dconv_comp, attn=attn, lstm=lstm)]
371
+ decode += [nn.ConvTranspose1d(channels, out_channels,
372
+ kernel_size, stride, padding=padding)]
373
+ if index > 0:
374
+ decode += [norm_fn(out_channels), act2()]
375
+ self.decoder.insert(0, nn.Sequential(*decode))
376
+ in_channels = channels
377
+ channels = int(growth * channels)
378
+
379
+ channels = in_channels
380
+ if lstm_layers:
381
+ self.lstm = BLSTM(channels, lstm_layers)
382
+ else:
383
+ self.lstm = None
384
+
385
+ if rescale:
386
+ rescale_module(self, reference=rescale)
387
+
388
+ def valid_length(self, length):
389
+ """
390
+ Return the nearest valid length to use with the model so that
391
+ there are no time steps left over in a convolution, e.g. for all
392
+ layers, size of the input - kernel_size % stride = 0.
393
+
394
+ Note that inputs are automatically padded if necessary to ensure that the output
395
+ has the same length as the input.
396
+ """
397
+ if self.resample:
398
+ length *= 2
399
+
400
+ for _ in range(self.depth):
401
+ length = math.ceil((length - self.kernel_size) / self.stride) + 1
402
+ length = max(1, length)
403
+
404
+ for idx in range(self.depth):
405
+ length = (length - 1) * self.stride + self.kernel_size
406
+
407
+ if self.resample:
408
+ length = math.ceil(length / 2)
409
+ return int(length)
410
+
411
+ def forward(self, mix):
412
+ x = mix
413
+ length = x.shape[-1]
414
+
415
+ if self.normalize:
416
+ mono = mix.mean(dim=1, keepdim=True)
417
+ mean = mono.mean(dim=-1, keepdim=True)
418
+ std = mono.std(dim=-1, keepdim=True)
419
+ x = (x - mean) / (1e-5 + std)
420
+ else:
421
+ mean = 0
422
+ std = 1
423
+
424
+ delta = self.valid_length(length) - length
425
+ x = F.pad(x, (delta // 2, delta - delta // 2))
426
+
427
+ if self.resample:
428
+ x = julius.resample_frac(x, 1, 2)
429
+
430
+ saved = []
431
+ for encode in self.encoder:
432
+ x = encode(x)
433
+ saved.append(x)
434
+
435
+ if self.lstm:
436
+ x = self.lstm(x)
437
+
438
+ for decode in self.decoder:
439
+ skip = saved.pop(-1)
440
+ skip = center_trim(skip, x)
441
+ x = decode(x + skip)
442
+
443
+ if self.resample:
444
+ x = julius.resample_frac(x, 2, 1)
445
+ x = x * std + mean
446
+ x = center_trim(x, length)
447
+ x = x.view(x.size(0), len(self.sources), self.audio_channels, x.size(-1))
448
+ return x
449
+
450
+ def load_state_dict(self, state, strict=True):
451
+ # fix a mismatch with previous generation Demucs models.
452
+ for idx in range(self.depth):
453
+ for a in ['encoder', 'decoder']:
454
+ for b in ['bias', 'weight']:
455
+ new = f'{a}.{idx}.3.{b}'
456
+ old = f'{a}.{idx}.2.{b}'
457
+ if old in state and new not in state:
458
+ state[new] = state.pop(old)
459
+ super().load_state_dict(state, strict=strict)
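`Demucs.valid_length` above pads inputs so the strided encoder/decoder convolutions leave no samples over. A standalone sketch of the same computation with the default hyperparameters (kernel_size=8, stride=4, depth=6, resample=True); the input length is just an example:

```python
import math

def valid_length(length, kernel_size=8, stride=4, depth=6, resample=True):
    if resample:
        length *= 2                                   # account for the x2 upsampling
    for _ in range(depth):                            # encoder: strided convolutions
        length = math.ceil((length - kernel_size) / stride) + 1
        length = max(1, length)
    for _ in range(depth):                            # decoder: transposed convolutions
        length = (length - 1) * stride + kernel_size
    if resample:
        length = math.ceil(length / 2)
    return int(length)

print(valid_length(44100))  # one second at 44.1 kHz, rounded up to the next valid size
```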
demucs/filtering.py ADDED
@@ -0,0 +1,502 @@
1
+ from typing import Optional
2
+ import torch
3
+ import torch.nn as nn
4
+ from torch import Tensor
5
+ from torch.utils.data import DataLoader
6
+
7
+ def atan2(y, x):
8
+ r"""Element-wise arctangent function of y/x.
9
+ Returns a new tensor with signed angles in radians.
10
+ It is an alternative implementation of torch.atan2
11
+
12
+ Args:
13
+ y (Tensor): First input tensor
14
+ x (Tensor): Second input tensor [shape=y.shape]
15
+
16
+ Returns:
17
+ Tensor: [shape=y.shape].
18
+ """
19
+ pi = 2 * torch.asin(torch.tensor(1.0))
20
+ x += ((x == 0) & (y == 0)) * 1.0
21
+ out = torch.atan(y / x)
22
+ out += ((y >= 0) & (x < 0)) * pi
23
+ out -= ((y < 0) & (x < 0)) * pi
24
+ out *= 1 - ((y > 0) & (x == 0)) * 1.0
25
+ out += ((y > 0) & (x == 0)) * (pi / 2)
26
+ out *= 1 - ((y < 0) & (x == 0)) * 1.0
27
+ out += ((y < 0) & (x == 0)) * (-pi / 2)
28
+ return out
29
+
30
+
31
+ # Define basic complex operations on torch.Tensor objects whose last dimension
32
+ # consists in the concatenation of the real and imaginary parts.
33
+
34
+
35
+ def _norm(x: torch.Tensor) -> torch.Tensor:
36
+ r"""Computes the norm value of a torch Tensor, assuming that it
37
+ comes as real and imaginary part in its last dimension.
38
+
39
+ Args:
40
+ x (Tensor): Input Tensor of shape [shape=(..., 2)]
41
+
42
+ Returns:
43
+ Tensor: shape as x excluding the last dimension.
44
+ """
45
+ return torch.abs(x[..., 0]) ** 2 + torch.abs(x[..., 1]) ** 2
46
+
47
+
48
+ def _mul_add(a: torch.Tensor, b: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor:
49
+ """Element-wise multiplication of two complex Tensors described
50
+ through their real and imaginary parts.
51
+ The result is added to the `out` tensor"""
52
+
53
+ # check `out` and allocate it if needed
54
+ target_shape = torch.Size([max(sa, sb) for (sa, sb) in zip(a.shape, b.shape)])
55
+ if out is None or out.shape != target_shape:
56
+ out = torch.zeros(target_shape, dtype=a.dtype, device=a.device)
57
+ if out is a:
58
+ real_a = a[..., 0]
59
+ out[..., 0] = out[..., 0] + (real_a * b[..., 0] - a[..., 1] * b[..., 1])
60
+ out[..., 1] = out[..., 1] + (real_a * b[..., 1] + a[..., 1] * b[..., 0])
61
+ else:
62
+ out[..., 0] = out[..., 0] + (a[..., 0] * b[..., 0] - a[..., 1] * b[..., 1])
63
+ out[..., 1] = out[..., 1] + (a[..., 0] * b[..., 1] + a[..., 1] * b[..., 0])
64
+ return out
65
+
66
+
67
+ def _mul(a: torch.Tensor, b: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor:
68
+ """Element-wise multiplication of two complex Tensors described
69
+ through their real and imaginary parts
70
+ can work in place when `out` is `a`."""
71
+ target_shape = torch.Size([max(sa, sb) for (sa, sb) in zip(a.shape, b.shape)])
72
+ if out is None or out.shape != target_shape:
73
+ out = torch.zeros(target_shape, dtype=a.dtype, device=a.device)
74
+ if out is a:
75
+ real_a = a[..., 0]
76
+ out[..., 0] = real_a * b[..., 0] - a[..., 1] * b[..., 1]
77
+ out[..., 1] = real_a * b[..., 1] + a[..., 1] * b[..., 0]
78
+ else:
79
+ out[..., 0] = a[..., 0] * b[..., 0] - a[..., 1] * b[..., 1]
80
+ out[..., 1] = a[..., 0] * b[..., 1] + a[..., 1] * b[..., 0]
81
+ return out
82
+
83
+
84
+ def _inv(z: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor:
85
+ """Element-wise multiplicative inverse of a Tensor with complex
86
+ entries described through their real and imaginary parts.
87
+ can work in place when `out` is `z`."""
88
+ ez = _norm(z)
89
+ if out is None or out.shape != z.shape:
90
+ out = torch.zeros_like(z)
91
+ out[..., 0] = z[..., 0] / ez
92
+ out[..., 1] = -z[..., 1] / ez
93
+ return out
94
+
95
+
96
+ def _conj(z, out: Optional[torch.Tensor] = None) -> torch.Tensor:
97
+ """Element-wise complex conjugate of a Tensor with complex entries
98
+ described through their real and imaginary parts.
99
+ can work in place when `out` is `z`."""
100
+ if out is None or out.shape != z.shape:
101
+ out = torch.zeros_like(z)
102
+ out[..., 0] = z[..., 0]
103
+ out[..., 1] = -z[..., 1]
104
+ return out
105
+
106
+
107
+ def _invert(M: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor:
108
+ """
109
+ Invert 1x1 or 2x2 matrices
110
+
111
+ Will generate errors if the matrices are singular: user must handle this
112
+ through their own regularization schemes.
113
+
114
+ Args:
115
+ M (Tensor): [shape=(..., nb_channels, nb_channels, 2)]
116
+ matrices to invert: must be square along dimensions -3 and -2
117
+
118
+ Returns:
119
+ invM (Tensor): [shape=M.shape]
120
+ inverses of M
121
+ """
122
+ nb_channels = M.shape[-2]
123
+
124
+ if out is None or out.shape != M.shape:
125
+ out = torch.empty_like(M)
126
+
127
+ if nb_channels == 1:
128
+ # scalar case
129
+ out = _inv(M, out)
130
+ elif nb_channels == 2:
131
+ # two channels case: analytical expression
132
+
133
+ # first compute the determinant
134
+ det = _mul(M[..., 0, 0, :], M[..., 1, 1, :])
135
+ det = det - _mul(M[..., 0, 1, :], M[..., 1, 0, :])
136
+ # invert it
137
+ invDet = _inv(det)
138
+
139
+ # then fill out the matrix with the inverse
140
+ out[..., 0, 0, :] = _mul(invDet, M[..., 1, 1, :], out[..., 0, 0, :])
141
+ out[..., 1, 0, :] = _mul(-invDet, M[..., 1, 0, :], out[..., 1, 0, :])
142
+ out[..., 0, 1, :] = _mul(-invDet, M[..., 0, 1, :], out[..., 0, 1, :])
143
+ out[..., 1, 1, :] = _mul(invDet, M[..., 0, 0, :], out[..., 1, 1, :])
144
+ else:
145
+ raise Exception("Only 2 channels are supported for the torch version.")
146
+ return out
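+ # The 2x2 branch above is the usual closed form
+ # inv([[a, b], [c, d]]) = 1 / (a*d - b*c) * [[d, -b], [-c, a]],
+ # evaluated with the complex helpers on the (real, imag) last dimension.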
147
+
148
+
149
+ # Now define the signal-processing low-level functions used by the Separator
150
+
151
+
152
+ def expectation_maximization(
153
+ y: torch.Tensor,
154
+ x: torch.Tensor,
155
+ iterations: int = 2,
156
+ eps: float = 1e-10,
157
+ batch_size: int = 200,
158
+ ):
159
+ r"""Expectation maximization algorithm, for refining source separation
160
+ estimates.
161
+
162
+ This algorithm improves source separation results by
163
+ enforcing multichannel consistency for the estimates. This usually means
164
+ a better perceptual quality in terms of spatial artifacts.
165
+
166
+ The implementation follows the details presented in [1]_, taking
167
+ inspiration from the original EM algorithm proposed in [2]_ and its
168
+ weighted refinement proposed in [3]_, [4]_.
169
+ It works by iteratively:
170
+
171
+ * Re-estimate source parameters (power spectral densities and spatial
172
+ covariance matrices) through :func:`get_local_gaussian_model`.
173
+
174
+ * Separate again the mixture with the new parameters by first computing
175
+ the new modelled mixture covariance matrices with :func:`get_mix_model`,
176
+ prepare the Wiener filters through :func:`wiener_gain` and apply them
177
+ with :func:`apply_filter`.
178
+
179
+ References
180
+ ----------
181
+ .. [1] S. Uhlich and M. Porcu and F. Giron and M. Enenkl and T. Kemp and
182
+ N. Takahashi and Y. Mitsufuji, "Improving music source separation based
183
+ on deep neural networks through data augmentation and network
184
+ blending." 2017 IEEE International Conference on Acoustics, Speech
185
+ and Signal Processing (ICASSP). IEEE, 2017.
186
+
187
+ .. [2] N.Q. Duong and E. Vincent and R.Gribonval. "Under-determined
188
+ reverberant audio source separation using a full-rank spatial
189
+ covariance model." IEEE Transactions on Audio, Speech, and Language
190
+ Processing 18.7 (2010): 1830-1840.
191
+
192
+ .. [3] A. Nugraha and A. Liutkus and E. Vincent. "Multichannel audio source
193
+ separation with deep neural networks." IEEE/ACM Transactions on Audio,
194
+ Speech, and Language Processing 24.9 (2016): 1652-1664.
195
+
196
+ .. [4] A. Nugraha and A. Liutkus and E. Vincent. "Multichannel music
197
+ separation with deep neural networks." 2016 24th European Signal
198
+ Processing Conference (EUSIPCO). IEEE, 2016.
199
+
200
+ .. [5] A. Liutkus and R. Badeau and G. Richard "Kernel additive models for
201
+ source separation." IEEE Transactions on Signal Processing
202
+ 62.16 (2014): 4298-4310.
203
+
204
+ Args:
205
+ y (Tensor): [shape=(nb_frames, nb_bins, nb_channels, 2, nb_sources)]
206
+ initial estimates for the sources
207
+ x (Tensor): [shape=(nb_frames, nb_bins, nb_channels, 2)]
208
+ complex STFT of the mixture signal
209
+ iterations (int): [scalar]
210
+ number of iterations for the EM algorithm.
211
+ eps (float or None): [scalar]
212
+ The epsilon value to use for regularization and filters.
213
+
214
+ Returns:
215
+ y (Tensor): [shape=(nb_frames, nb_bins, nb_channels, 2, nb_sources)]
216
+ estimated sources after iterations
217
+ v (Tensor): [shape=(nb_frames, nb_bins, nb_sources)]
218
+ estimated power spectral densities
219
+ R (Tensor): [shape=(nb_bins, nb_channels, nb_channels, 2, nb_sources)]
220
+ estimated spatial covariance matrices
221
+
222
+ Notes:
223
+ * You need an initial estimate for the sources to apply this
224
+ algorithm. This is precisely what the :func:`wiener` function does.
225
+ * This algorithm *is not* an implementation of the "exact" EM
226
+ proposed in [1]_. In particular, it does not compute the posterior
227
+ covariance matrices the same (exact) way. Instead, it uses the
228
+ simplified approximate scheme initially proposed in [5]_ and further
229
+ refined in [3]_, [4]_, which boils down to just taking the empirical
230
+ covariance of the recent source estimates, followed by a weighted
231
+ average for the update of the spatial covariance matrix. It has been
232
+ empirically demonstrated that this simplified algorithm is more
233
+ robust for music separation.
234
+
235
+ Warning:
236
+ It is *very* important to make sure `x.dtype` is `torch.float64`
237
+ if you want double precision, because this function will **not**
238
+ do such conversion for you from `torch.complex32`, in case you want the
239
+ smaller RAM usage on purpose.
240
+
241
+ It is usually better in terms of quality to have double
242
+ precision, by e.g. calling :func:`expectation_maximization`
243
+ with ``x.to(torch.float64)``.
244
+ """
245
+ # dimensions
246
+ (nb_frames, nb_bins, nb_channels) = x.shape[:-1]
247
+ nb_sources = y.shape[-1]
248
+
249
+ regularization = torch.cat(
250
+ (
251
+ torch.eye(nb_channels, dtype=x.dtype, device=x.device)[..., None],
252
+ torch.zeros((nb_channels, nb_channels, 1), dtype=x.dtype, device=x.device),
253
+ ),
254
+ dim=2,
255
+ )
256
+ regularization = torch.sqrt(torch.as_tensor(eps)) * (
257
+ regularization[None, None, ...].expand((-1, nb_bins, -1, -1, -1))
258
+ )
259
+
260
+ # allocate the spatial covariance matrices
261
+ R = [
262
+ torch.zeros((nb_bins, nb_channels, nb_channels, 2), dtype=x.dtype, device=x.device)
263
+ for j in range(nb_sources)
264
+ ]
265
+ weight: torch.Tensor = torch.zeros((nb_bins,), dtype=x.dtype, device=x.device)
266
+
267
+ v: torch.Tensor = torch.zeros((nb_frames, nb_bins, nb_sources), dtype=x.dtype, device=x.device)
268
+ for it in range(iterations):
269
+ # constructing the mixture covariance matrix. Doing it with a loop
270
+ # to avoid storing anytime in RAM the whole 6D tensor
271
+
272
+ # update the PSD as the average spectrogram over channels
273
+ v = torch.mean(torch.abs(y[..., 0, :]) ** 2 + torch.abs(y[..., 1, :]) ** 2, dim=-2)
274
+
275
+ # update spatial covariance matrices (weighted update)
276
+ for j in range(nb_sources):
277
+ R[j] = torch.tensor(0.0, device=x.device)
278
+ weight = torch.tensor(eps, device=x.device)
279
+ pos: int = 0
280
+ batch_size = batch_size if batch_size else nb_frames
281
+ while pos < nb_frames:
282
+ t = torch.arange(pos, min(nb_frames, pos + batch_size))
283
+ pos = int(t[-1]) + 1
284
+
285
+ R[j] = R[j] + torch.sum(_covariance(y[t, ..., j]), dim=0)
286
+ weight = weight + torch.sum(v[t, ..., j], dim=0)
287
+ R[j] = R[j] / weight[..., None, None, None]
288
+ weight = torch.zeros_like(weight)
289
+
290
+ # cloning y if we track gradient, because we're going to update it
291
+ if y.requires_grad:
292
+ y = y.clone()
293
+
294
+ pos = 0
295
+ while pos < nb_frames:
296
+ t = torch.arange(pos, min(nb_frames, pos + batch_size))
297
+ pos = int(t[-1]) + 1
298
+
299
+ y[t, ...] = torch.tensor(0.0, device=x.device, dtype=x.dtype)
300
+
301
+ # compute mix covariance matrix
302
+ Cxx = regularization
303
+ for j in range(nb_sources):
304
+ Cxx = Cxx + (v[t, ..., j, None, None, None] * R[j][None, ...].clone())
305
+
306
+ # invert it
307
+ inv_Cxx = _invert(Cxx)
308
+
309
+ # separate the sources
310
+ for j in range(nb_sources):
311
+
312
+ # create a wiener gain for this source
313
+ gain = torch.zeros_like(inv_Cxx)
314
+
315
+ # computes multichannel Wiener gain as v_j R_j inv_Cxx
316
+ indices = torch.cartesian_prod(
317
+ torch.arange(nb_channels),
318
+ torch.arange(nb_channels),
319
+ torch.arange(nb_channels),
320
+ )
321
+ for index in indices:
322
+ gain[:, :, index[0], index[1], :] = _mul_add(
323
+ R[j][None, :, index[0], index[2], :].clone(),
324
+ inv_Cxx[:, :, index[2], index[1], :],
325
+ gain[:, :, index[0], index[1], :],
326
+ )
327
+ gain = gain * v[t, ..., None, None, None, j]
328
+
329
+ # apply it to the mixture
330
+ for i in range(nb_channels):
331
+ y[t, ..., j] = _mul_add(gain[..., i, :], x[t, ..., i, None, :], y[t, ..., j])
332
+
333
+ return y, v, R
334
+
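+ # Usage sketch (illustrative): this function is normally reached through `wiener`
+ # below, which builds initial source estimates from the mixture STFT and then
+ # calls e.g. `y, v, R = expectation_maximization(y, mix_stft, iterations=2)`.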
335
+
336
+ def wiener(
337
+ targets_spectrograms: torch.Tensor,
338
+ mix_stft: torch.Tensor,
339
+ iterations: int = 1,
340
+ softmask: bool = False,
341
+ residual: bool = False,
342
+ scale_factor: float = 10.0,
343
+ eps: float = 1e-10,
344
+ ):
345
+ """Wiener-based separation for multichannel audio.
346
+
347
+ The method uses the (possibly multichannel) spectrograms of the
348
+ sources to separate the (complex) Short Term Fourier Transform of the
349
+ mix. Separation is done in a sequential way by:
350
+
351
+ * Getting an initial estimate. This can be done in two ways: either by
352
+ directly using the spectrograms with the mixture phase, or
353
+ by using a softmasking strategy. This initial phase is controlled
354
+ by the `softmask` flag.
355
+
356
+ * If required, adding an additional residual target as the mix minus
357
+ all targets.
358
+
359
+ * Refining these initial estimates through a call to
360
+ :func:`expectation_maximization` if the number of iterations is nonzero.
361
+
362
+ This implementation also allows specifying the epsilon value used for
363
+ regularization. It is based on [1]_, [2]_, [3]_, [4]_.
364
+
365
+ References
366
+ ----------
367
+ .. [1] S. Uhlich and M. Porcu and F. Giron and M. Enenkl and T. Kemp and
368
+ N. Takahashi and Y. Mitsufuji, "Improving music source separation based
369
+ on deep neural networks through data augmentation and network
370
+ blending." 2017 IEEE International Conference on Acoustics, Speech
371
+ and Signal Processing (ICASSP). IEEE, 2017.
372
+
373
+ .. [2] A. Nugraha and A. Liutkus and E. Vincent. "Multichannel audio source
374
+ separation with deep neural networks." IEEE/ACM Transactions on Audio,
375
+ Speech, and Language Processing 24.9 (2016): 1652-1664.
376
+
377
+ .. [3] A. Nugraha and A. Liutkus and E. Vincent. "Multichannel music
378
+ separation with deep neural networks." 2016 24th European Signal
379
+ Processing Conference (EUSIPCO). IEEE, 2016.
380
+
381
+ .. [4] A. Liutkus and R. Badeau and G. Richard "Kernel additive models for
382
+ source separation." IEEE Transactions on Signal Processing
383
+ 62.16 (2014): 4298-4310.
384
+
385
+ Args:
386
+ targets_spectrograms (Tensor): spectrograms of the sources
387
+ [shape=(nb_frames, nb_bins, nb_channels, nb_sources)].
388
+ This is a nonnegative tensor that is
389
+ usually the output of the actual separation method of the user. The
390
+ spectrograms may be mono, but they need to be 4-dimensional in all
391
+ cases.
392
+ mix_stft (Tensor): [shape=(nb_frames, nb_bins, nb_channels, complex=2)]
393
+ STFT of the mixture signal.
394
+ iterations (int): [scalar]
395
+ number of iterations for the EM algorithm
396
+ softmask (bool): Describes how the initial estimates are obtained.
397
+ * if `False`, then the mixture phase will directly be used with the
398
+ spectrogram as initial estimates.
399
+ * if `True`, initial estimates are obtained by multiplying the
400
+ complex mix element-wise with the ratio of each target spectrogram
401
+ with the sum of them all. This strategy is better if the models are
402
+ not really good, and worse otherwise.
403
+ residual (bool): if `True`, an additional target is created, which is
404
+ equal to the mixture minus the other targets, before application of
405
+ expectation maximization
406
+ eps (float): Epsilon value to use for computing the separations.
407
+ This is used whenever division with a model energy is
408
+ performed, i.e. when softmasking and when iterating the EM.
409
+ It can be understood as the energy of the additional white noise
410
+ that is taken out when separating.
411
+
412
+ Returns:
413
+ Tensor: shape=(nb_frames, nb_bins, nb_channels, complex=2, nb_sources)
414
+ STFT of estimated sources
415
+
416
+ Notes:
417
+ * Be careful that you need *magnitude spectrogram estimates* for the
418
+ case `softmask==False`.
419
+ * `softmask=False` is recommended
420
+ * The epsilon value will have a huge impact on performance. If it's
421
+ large, only the parts of the signal with a significant energy will
422
+ be kept in the sources. This epsilon then directly controls the
423
+ energy of the reconstruction error.
424
+
425
+ Warning:
426
+ As in :func:`expectation_maximization`, we recommend converting the
427
+ mixture `x` to double precision `torch.float64` *before* calling
428
+ :func:`wiener`.
429
+ """
430
+ if softmask:
431
+ # if we use softmask, we compute the ratio mask for all targets and
432
+ # multiply by the mix stft
433
+ y = (
434
+ mix_stft[..., None]
435
+ * (
436
+ targets_spectrograms
437
+ / (eps + torch.sum(targets_spectrograms, dim=-1, keepdim=True).to(mix_stft.dtype))
438
+ )[..., None, :]
439
+ )
440
+ else:
441
+ # otherwise, we just multiply the targets spectrograms with mix phase
442
+ # we tacitly assume that we have magnitude estimates.
443
+ angle = atan2(mix_stft[..., 1], mix_stft[..., 0])[..., None]
444
+ nb_sources = targets_spectrograms.shape[-1]
445
+ y = torch.zeros(
446
+ mix_stft.shape + (nb_sources,), dtype=mix_stft.dtype, device=mix_stft.device
447
+ )
448
+ y[..., 0, :] = targets_spectrograms * torch.cos(angle)
449
+ y[..., 1, :] = targets_spectrograms * torch.sin(angle)
450
+
451
+ if residual:
452
+ # if required, adding an additional target as the mix minus
453
+ # available targets
454
+ y = torch.cat([y, mix_stft[..., None] - y.sum(dim=-1, keepdim=True)], dim=-1)
455
+
456
+ if iterations == 0:
457
+ return y
458
+
459
+ # we need to refine the estimates. Scales down the estimates for
460
+ # numerical stability
461
+ max_abs = torch.max(
462
+ torch.as_tensor(1.0, dtype=mix_stft.dtype, device=mix_stft.device),
463
+ torch.sqrt(_norm(mix_stft)).max() / scale_factor,
464
+ )
465
+
466
+ mix_stft = mix_stft / max_abs
467
+ y = y / max_abs
468
+
469
+ # call expectation maximization
470
+ y = expectation_maximization(y, mix_stft, iterations, eps=eps)[0]
471
+
472
+ # scale estimates up again
473
+ y = y * max_abs
474
+ return y
475
+
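+ # Usage sketch (shapes as in the docstring): for magnitude estimates `specs` of shape
+ # (nb_frames, nb_bins, nb_channels, nb_sources) and a mixture STFT `mix` of shape
+ # (nb_frames, nb_bins, nb_channels, 2), `wiener(specs.double(), mix.double(), iterations=1)`
+ # returns a tensor of shape (nb_frames, nb_bins, nb_channels, 2, nb_sources).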
476
+
477
+ def _covariance(y_j):
478
+ """
479
+ Compute the empirical covariance for a source.
480
+
481
+ Args:
482
+ y_j (Tensor): complex stft of the source.
483
+ [shape=(nb_frames, nb_bins, nb_channels, 2)].
484
+
485
+ Returns:
486
+ Cj (Tensor): [shape=(nb_frames, nb_bins, nb_channels, nb_channels, 2)]
487
+ just y_j * conj(y_j.T): empirical covariance for each TF bin.
488
+ """
489
+ (nb_frames, nb_bins, nb_channels) = y_j.shape[:-1]
490
+ Cj = torch.zeros(
491
+ (nb_frames, nb_bins, nb_channels, nb_channels, 2),
492
+ dtype=y_j.dtype,
493
+ device=y_j.device,
494
+ )
495
+ indices = torch.cartesian_prod(torch.arange(nb_channels), torch.arange(nb_channels))
496
+ for index in indices:
497
+ Cj[:, :, index[0], index[1], :] = _mul_add(
498
+ y_j[:, :, index[0], :],
499
+ _conj(y_j[:, :, index[1], :]),
500
+ Cj[:, :, index[0], index[1], :],
501
+ )
502
+ return Cj
demucs/hdemucs.py ADDED
@@ -0,0 +1,796 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """
7
+ This code contains the spectrogram and Hybrid version of Demucs.
8
+ """
9
+ from copy import deepcopy
10
+ import math
11
+ import typing as tp
12
+ import torch
13
+ from torch import nn
14
+ from torch.nn import functional as F
15
+ from .filtering import wiener
16
+ from .demucs import DConv, rescale_module
17
+ from .states import capture_init
18
+ from .spec import spectro, ispectro
19
+
20
+ def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'constant', value: float = 0.):
21
+ """Tiny wrapper around F.pad, just to allow for reflect padding on small input.
22
+ If this is the case, we insert extra 0 padding to the right before the reflection happens."""
23
+ x0 = x
24
+ length = x.shape[-1]
25
+ padding_left, padding_right = paddings
26
+ if mode == 'reflect':
27
+ max_pad = max(padding_left, padding_right)
28
+ if length <= max_pad:
29
+ extra_pad = max_pad - length + 1
30
+ extra_pad_right = min(padding_right, extra_pad)
31
+ extra_pad_left = extra_pad - extra_pad_right
32
+ paddings = (padding_left - extra_pad_left, padding_right - extra_pad_right)
33
+ x = F.pad(x, (extra_pad_left, extra_pad_right))
34
+ out = F.pad(x, paddings, mode, value)
35
+ assert out.shape[-1] == length + padding_left + padding_right
36
+ assert (out[..., padding_left: padding_left + length] == x0).all()
37
+ return out
38
+
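+ # Illustration of the workaround above: reflect-padding a length-2 signal by (0, 3)
+ # would be rejected by F.pad, so pad1d first zero-pads the input until the requested
+ # reflection becomes valid.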
39
+ class ScaledEmbedding(nn.Module):
40
+ """
41
+ Boost learning rate for embeddings (with `scale`).
42
+ Also, can make embeddings continuous with `smooth`.
43
+ """
44
+ def __init__(self, num_embeddings: int, embedding_dim: int,
45
+ scale: float = 10., smooth=False):
46
+ super().__init__()
47
+ self.embedding = nn.Embedding(num_embeddings, embedding_dim)
48
+ if smooth:
49
+ weight = torch.cumsum(self.embedding.weight.data, dim=0)
50
+ # when summing gaussians, the overall scale grows as sqrt(n), so we normalize by that.
51
+ weight = weight / torch.arange(1, num_embeddings + 1).to(weight).sqrt()[:, None]
52
+ self.embedding.weight.data[:] = weight
53
+ self.embedding.weight.data /= scale
54
+ self.scale = scale
55
+
56
+ @property
57
+ def weight(self):
58
+ return self.embedding.weight * self.scale
59
+
60
+ def forward(self, x):
61
+ out = self.embedding(x) * self.scale
62
+ return out
63
+
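+ # Note: the weight is stored divided by `scale` and multiplied back in `forward`,
+ # so gradients w.r.t. the stored weight are `scale` times larger, which is what
+ # effectively boosts the embedding's learning rate.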
64
+
65
+ class HEncLayer(nn.Module):
66
+ def __init__(self, chin, chout, kernel_size=8, stride=4, norm_groups=1, empty=False,
67
+ freq=True, dconv=True, norm=True, context=0, dconv_kw={}, pad=True,
68
+ rewrite=True):
69
+ """Encoder layer. This used both by the time and the frequency branch.
70
+
71
+ Args:
72
+ chin: number of input channels.
73
+ chout: number of output channels.
74
+ norm_groups: number of groups for group norm.
75
+ empty: used to make a layer with just the first conv. this is used
76
+ before merging the time and freq. branches.
77
+ freq: this is acting on frequencies.
78
+ dconv: insert DConv residual branches.
79
+ norm: use GroupNorm.
80
+ context: context size for the 1x1 conv.
81
+ dconv_kw: list of kwargs for the DConv class.
82
+ pad: pad the input. Padding is done so that the output size is
83
+ always the input size / stride.
84
+ rewrite: add 1x1 conv at the end of the layer.
85
+ """
86
+ super().__init__()
87
+ norm_fn = lambda d: nn.Identity() # noqa
88
+ if norm:
89
+ norm_fn = lambda d: nn.GroupNorm(norm_groups, d) # noqa
90
+ if pad:
91
+ pad = kernel_size // 4
92
+ else:
93
+ pad = 0
94
+ klass = nn.Conv1d
95
+ self.freq = freq
96
+ self.kernel_size = kernel_size
97
+ self.stride = stride
98
+ self.empty = empty
99
+ self.norm = norm
100
+ self.pad = pad
101
+ if freq:
102
+ kernel_size = [kernel_size, 1]
103
+ stride = [stride, 1]
104
+ pad = [pad, 0]
105
+ klass = nn.Conv2d
106
+ self.conv = klass(chin, chout, kernel_size, stride, pad)
107
+ if self.empty:
108
+ return
109
+ self.norm1 = norm_fn(chout)
110
+ self.rewrite = None
111
+ if rewrite:
112
+ self.rewrite = klass(chout, 2 * chout, 1 + 2 * context, 1, context)
113
+ self.norm2 = norm_fn(2 * chout)
114
+
115
+ self.dconv = None
116
+ if dconv:
117
+ self.dconv = DConv(chout, **dconv_kw)
118
+
119
+ def forward(self, x, inject=None):
120
+ """
121
+ `inject` is used to inject the result from the time branch into the frequency branch,
122
+ when both have the same stride.
123
+ """
124
+ if not self.freq and x.dim() == 4:
125
+ B, C, Fr, T = x.shape
126
+ x = x.view(B, -1, T)
127
+
128
+ if not self.freq:
129
+ le = x.shape[-1]
130
+ if not le % self.stride == 0:
131
+ x = F.pad(x, (0, self.stride - (le % self.stride)))
132
+ y = self.conv(x)
133
+ if self.empty:
134
+ return y
135
+ if inject is not None:
136
+ assert inject.shape[-1] == y.shape[-1], (inject.shape, y.shape)
137
+ if inject.dim() == 3 and y.dim() == 4:
138
+ inject = inject[:, :, None]
139
+ y = y + inject
140
+ y = F.gelu(self.norm1(y))
141
+ if self.dconv:
142
+ if self.freq:
143
+ B, C, Fr, T = y.shape
144
+ y = y.permute(0, 2, 1, 3).reshape(-1, C, T)
145
+ y = self.dconv(y)
146
+ if self.freq:
147
+ y = y.view(B, Fr, C, T).permute(0, 2, 1, 3)
148
+ if self.rewrite:
149
+ z = self.norm2(self.rewrite(y))
150
+ z = F.glu(z, dim=1)
151
+ else:
152
+ z = y
153
+ return z
154
+
155
+
156
+ class MultiWrap(nn.Module):
157
+ """
158
+ Takes one layer and replicates it N times. Each replica will act
159
+ on a frequency band. All is done so that if the N replicas have the same weights,
160
+ then this is exactly equivalent to applying the original module on all frequencies.
161
+
162
+ This is a bit over-engineered to avoid edge artifacts when splitting
163
+ the frequency bands, but it is possible the naive implementation would work as well...
164
+ """
165
+ def __init__(self, layer, split_ratios):
166
+ """
167
+ Args:
168
+ layer: module to clone, must be either HEncLayer or HDecLayer.
169
+ split_ratios: list of float indicating which ratio to keep for each band.
170
+ """
171
+ super().__init__()
172
+ self.split_ratios = split_ratios
173
+ self.layers = nn.ModuleList()
174
+ self.conv = isinstance(layer, HEncLayer)
175
+ assert not layer.norm
176
+ assert layer.freq
177
+ assert layer.pad
178
+ if not self.conv:
179
+ assert not layer.context_freq
180
+ for k in range(len(split_ratios) + 1):
181
+ lay = deepcopy(layer)
182
+ if self.conv:
183
+ lay.conv.padding = (0, 0)
184
+ else:
185
+ lay.pad = False
186
+ for m in lay.modules():
187
+ if hasattr(m, 'reset_parameters'):
188
+ m.reset_parameters()
189
+ self.layers.append(lay)
190
+
191
+ def forward(self, x, skip=None, length=None):
192
+ B, C, Fr, T = x.shape
193
+
194
+ ratios = list(self.split_ratios) + [1]
195
+ start = 0
196
+ outs = []
197
+ for ratio, layer in zip(ratios, self.layers):
198
+ if self.conv:
199
+ pad = layer.kernel_size // 4
200
+ if ratio == 1:
201
+ limit = Fr
202
+ frames = -1
203
+ else:
204
+ limit = int(round(Fr * ratio))
205
+ le = limit - start
206
+ if start == 0:
207
+ le += pad
208
+ frames = round((le - layer.kernel_size) / layer.stride + 1)
209
+ limit = start + (frames - 1) * layer.stride + layer.kernel_size
210
+ if start == 0:
211
+ limit -= pad
212
+ assert limit - start > 0, (limit, start)
213
+ assert limit <= Fr, (limit, Fr)
214
+ y = x[:, :, start:limit, :]
215
+ if start == 0:
216
+ y = F.pad(y, (0, 0, pad, 0))
217
+ if ratio == 1:
218
+ y = F.pad(y, (0, 0, 0, pad))
219
+ outs.append(layer(y))
220
+ start = limit - layer.kernel_size + layer.stride
221
+ else:
222
+ if ratio == 1:
223
+ limit = Fr
224
+ else:
225
+ limit = int(round(Fr * ratio))
226
+ last = layer.last
227
+ layer.last = True
228
+
229
+ y = x[:, :, start:limit]
230
+ s = skip[:, :, start:limit]
231
+ out, _ = layer(y, s, None)
232
+ if outs:
233
+ outs[-1][:, :, -layer.stride:] += (
234
+ out[:, :, :layer.stride] - layer.conv_tr.bias.view(1, -1, 1, 1))
235
+ out = out[:, :, layer.stride:]
236
+ if ratio == 1:
237
+ out = out[:, :, :-layer.stride // 2, :]
238
+ if start == 0:
239
+ out = out[:, :, layer.stride // 2:, :]
240
+ outs.append(out)
241
+ layer.last = last
242
+ start = limit
243
+ out = torch.cat(outs, dim=2)
244
+ if not self.conv and not last:
245
+ out = F.gelu(out)
246
+ if self.conv:
247
+ return out
248
+ else:
249
+ return out, None
250
+
251
+
252
+ class HDecLayer(nn.Module):
253
+ def __init__(self, chin, chout, last=False, kernel_size=8, stride=4, norm_groups=1, empty=False,
254
+ freq=True, dconv=True, norm=True, context=1, dconv_kw={}, pad=True,
255
+ context_freq=True, rewrite=True):
256
+ """
257
+ Same as HEncLayer but for decoder. See `HEncLayer` for documentation.
258
+ """
259
+ super().__init__()
260
+ norm_fn = lambda d: nn.Identity() # noqa
261
+ if norm:
262
+ norm_fn = lambda d: nn.GroupNorm(norm_groups, d) # noqa
263
+ if pad:
264
+ pad = kernel_size // 4
265
+ else:
266
+ pad = 0
267
+ self.pad = pad
268
+ self.last = last
269
+ self.freq = freq
270
+ self.chin = chin
271
+ self.empty = empty
272
+ self.stride = stride
273
+ self.kernel_size = kernel_size
274
+ self.norm = norm
275
+ self.context_freq = context_freq
276
+ klass = nn.Conv1d
277
+ klass_tr = nn.ConvTranspose1d
278
+ if freq:
279
+ kernel_size = [kernel_size, 1]
280
+ stride = [stride, 1]
281
+ klass = nn.Conv2d
282
+ klass_tr = nn.ConvTranspose2d
283
+ self.conv_tr = klass_tr(chin, chout, kernel_size, stride)
284
+ self.norm2 = norm_fn(chout)
285
+ if self.empty:
286
+ return
287
+ self.rewrite = None
288
+ if rewrite:
289
+ if context_freq:
290
+ self.rewrite = klass(chin, 2 * chin, 1 + 2 * context, 1, context)
291
+ else:
292
+ self.rewrite = klass(chin, 2 * chin, [1, 1 + 2 * context], 1,
293
+ [0, context])
294
+ self.norm1 = norm_fn(2 * chin)
295
+
296
+ self.dconv = None
297
+ if dconv:
298
+ self.dconv = DConv(chin, **dconv_kw)
299
+
300
+ def forward(self, x, skip, length):
301
+ if self.freq and x.dim() == 3:
302
+ B, C, T = x.shape
303
+ x = x.view(B, self.chin, -1, T)
304
+
305
+ if not self.empty:
306
+ x = x + skip
307
+
308
+ if self.rewrite:
309
+ y = F.glu(self.norm1(self.rewrite(x)), dim=1)
310
+ else:
311
+ y = x
312
+ if self.dconv:
313
+ if self.freq:
314
+ B, C, Fr, T = y.shape
315
+ y = y.permute(0, 2, 1, 3).reshape(-1, C, T)
316
+ y = self.dconv(y)
317
+ if self.freq:
318
+ y = y.view(B, Fr, C, T).permute(0, 2, 1, 3)
319
+ else:
320
+ y = x
321
+ assert skip is None
322
+ z = self.norm2(self.conv_tr(y))
323
+ if self.freq:
324
+ if self.pad:
325
+ z = z[..., self.pad:-self.pad, :]
326
+ else:
327
+ z = z[..., self.pad:self.pad + length]
328
+ assert z.shape[-1] == length, (z.shape[-1], length)
329
+ if not self.last:
330
+ z = F.gelu(z)
331
+ return z, y
332
+
333
+
334
+ class HDemucs(nn.Module):
335
+ """
336
+ Spectrogram and hybrid Demucs model.
337
+ The spectrogram model has the same structure as Demucs, except the first few layers are over the
338
+ frequency axis, until there is only 1 frequency, and then it moves to time convolutions.
339
+ Frequency layers can still access information across time steps thanks to the DConv residual.
340
+
341
+ Hybrid models have a parallel time branch. At some layer, the time branch has the same stride
342
+ as the frequency branch and then the two are combined. The opposite happens in the decoder.
343
+
344
+ Models can either use naive iSTFT from masking, Wiener filtering ([Uhlich et al. 2017]),
345
+ or complex as channels (CaC) [Choi et al. 2020]. Wiener filtering is based on
346
+ Open Unmix implementation [Stoter et al. 2019].
347
+
348
+ The loss is always on the temporal domain, by backpropagating through the above
349
+ output methods and iSTFT. This makes it possible to define hybrid models nicely. However, this
350
+ slightly breaks Wiener filtering, as doing more iterations at test time will change the spectrogram
351
+ contribution, without changing the one from the waveform, which will lead to worse performance.
352
+ I tried using the residual option in OpenUnmix Wiener implementation, but it didn't improve.
353
+ CaC on the other hand provides similar performance for hybrid, and works naturally with
354
+ hybrid models.
355
+
356
+ This model also uses frequency embeddings to improve the efficiency of convolutions
357
+ over the freq. axis, following [Isik et al. 2020] (https://arxiv.org/pdf/2008.04470.pdf).
358
+
359
+ Unlike classic Demucs, there is no resampling here, and normalization is always applied.
360
+ """
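+ # Usage sketch (illustrative): HDemucs(sources=["drums", "bass", "other", "vocals"])
+ # maps a waveform batch of shape (batch, audio_channels, time) to separated stems
+ # of shape (batch, len(sources), audio_channels, time).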
361
+ @capture_init
362
+ def __init__(self,
363
+ sources,
364
+ # Channels
365
+ audio_channels=2,
366
+ channels=48,
367
+ channels_time=None,
368
+ growth=2,
369
+ # STFT
370
+ nfft=4096,
371
+ wiener_iters=0,
372
+ end_iters=0,
373
+ wiener_residual=False,
374
+ cac=True,
375
+ # Main structure
376
+ depth=6,
377
+ rewrite=True,
378
+ hybrid=True,
379
+ hybrid_old=False,
380
+ # Frequency branch
381
+ multi_freqs=None,
382
+ multi_freqs_depth=2,
383
+ freq_emb=0.2,
384
+ emb_scale=10,
385
+ emb_smooth=True,
386
+ # Convolutions
387
+ kernel_size=8,
388
+ time_stride=2,
389
+ stride=4,
390
+ context=1,
391
+ context_enc=0,
392
+ # Normalization
393
+ norm_starts=4,
394
+ norm_groups=4,
395
+ # DConv residual branch
396
+ dconv_mode=1,
397
+ dconv_depth=2,
398
+ dconv_comp=4,
399
+ dconv_attn=4,
400
+ dconv_lstm=4,
401
+ dconv_init=1e-4,
402
+ # Weight init
403
+ rescale=0.1,
404
+ # Metadata
405
+ samplerate=44100,
406
+ segment=4 * 10):
407
+
408
+ """
409
+ Args:
410
+ sources (list[str]): list of source names.
411
+ audio_channels (int): input/output audio channels.
412
+ channels (int): initial number of hidden channels.
413
+ channels_time: if not None, use a different `channels` value for the time branch.
414
+ growth: increase the number of hidden channels by this factor at each layer.
415
+ nfft: number of fft bins. Note that changing this requires careful computation of
416
+ various shape parameters and will not work out of the box for hybrid models.
417
+ wiener_iters: when using Wiener filtering, number of iterations at test time.
418
+ end_iters: same but at train time. For a hybrid model, must be equal to `wiener_iters`.
419
+ wiener_residual: add residual source before wiener filtering.
420
+ cac: uses complex as channels, i.e. complex numbers are 2 channels each
421
+ in input and output. no further processing is done before ISTFT.
422
+ depth (int): number of layers in the encoder and in the decoder.
423
+ rewrite (bool): add 1x1 convolution to each layer.
424
+ hybrid (bool): make a hybrid time/frequency domain, otherwise frequency only.
425
+ hybrid_old: some models trained for MDX had a padding bug. This replicates
426
+ this bug to avoid retraining them.
427
+ multi_freqs: list of frequency ratios for splitting frequency bands with `MultiWrap`.
428
+ multi_freqs_depth: how many layers to wrap with `MultiWrap`. Only the outermost
429
+ layers will be wrapped.
430
+ freq_emb: add frequency embedding after the first frequency layer if > 0,
431
+ the actual value controls the weight of the embedding.
432
+ emb_scale: equivalent to scaling the embedding learning rate
433
+ emb_smooth: initialize the embedding with a smooth one (with respect to frequencies).
434
+ kernel_size: kernel_size for encoder and decoder layers.
435
+ stride: stride for encoder and decoder layers.
436
+ time_stride: stride for the final time layer, after the merge.
437
+ context: context for 1x1 conv in the decoder.
438
+ context_enc: context for 1x1 conv in the encoder.
439
+ norm_starts: layer at which group norm starts being used.
440
+ decoder layers are numbered in reverse order.
441
+ norm_groups: number of groups for group norm.
442
+ dconv_mode: if 1: dconv in encoder only, 2: decoder only, 3: both.
443
+ dconv_depth: depth of residual DConv branch.
444
+ dconv_comp: compression of DConv branch.
445
+ dconv_attn: adds attention layers in DConv branch starting at this layer.
446
+ dconv_lstm: adds a LSTM layer in DConv branch starting at this layer.
447
+ dconv_init: initial scale for the DConv branch LayerScale.
448
+ rescale: weight rescaling trick
449
+
450
+ """
451
+ super().__init__()
452
+
453
+ self.cac = cac
454
+ self.wiener_residual = wiener_residual
455
+ self.audio_channels = audio_channels
456
+ self.sources = sources
457
+ self.kernel_size = kernel_size
458
+ self.context = context
459
+ self.stride = stride
460
+ self.depth = depth
461
+ self.channels = channels
462
+ self.samplerate = samplerate
463
+ self.segment = segment
464
+
465
+ self.nfft = nfft
466
+ self.hop_length = nfft // 4
467
+ self.wiener_iters = wiener_iters
468
+ self.end_iters = end_iters
469
+ self.freq_emb = None
470
+ self.hybrid = hybrid
471
+ self.hybrid_old = hybrid_old
472
+ if hybrid_old:
473
+ assert hybrid, "hybrid_old must come with hybrid=True"
474
+ if hybrid:
475
+ assert wiener_iters == end_iters
476
+
477
+ self.encoder = nn.ModuleList()
478
+ self.decoder = nn.ModuleList()
479
+
480
+ if hybrid:
481
+ self.tencoder = nn.ModuleList()
482
+ self.tdecoder = nn.ModuleList()
483
+
484
+ chin = audio_channels
485
+ chin_z = chin # number of channels for the freq branch
486
+ if self.cac:
487
+ chin_z *= 2
488
+ chout = channels_time or channels
489
+ chout_z = channels
490
+ freqs = nfft // 2
491
+
492
+ for index in range(depth):
493
+ lstm = index >= dconv_lstm
494
+ attn = index >= dconv_attn
495
+ norm = index >= norm_starts
496
+ freq = freqs > 1
497
+ stri = stride
498
+ ker = kernel_size
499
+ if not freq:
500
+ assert freqs == 1
501
+ ker = time_stride * 2
502
+ stri = time_stride
503
+
504
+ pad = True
505
+ last_freq = False
506
+ if freq and freqs <= kernel_size:
507
+ ker = freqs
508
+ pad = False
509
+ last_freq = True
510
+
511
+ kw = {
512
+ 'kernel_size': ker,
513
+ 'stride': stri,
514
+ 'freq': freq,
515
+ 'pad': pad,
516
+ 'norm': norm,
517
+ 'rewrite': rewrite,
518
+ 'norm_groups': norm_groups,
519
+ 'dconv_kw': {
520
+ 'lstm': lstm,
521
+ 'attn': attn,
522
+ 'depth': dconv_depth,
523
+ 'compress': dconv_comp,
524
+ 'init': dconv_init,
525
+ 'gelu': True,
526
+ }
527
+ }
528
+ kwt = dict(kw)
529
+ kwt['freq'] = 0
530
+ kwt['kernel_size'] = kernel_size
531
+ kwt['stride'] = stride
532
+ kwt['pad'] = True
533
+ kw_dec = dict(kw)
534
+ multi = False
535
+ if multi_freqs and index < multi_freqs_depth:
536
+ multi = True
537
+ kw_dec['context_freq'] = False
538
+
539
+ if last_freq:
540
+ chout_z = max(chout, chout_z)
541
+ chout = chout_z
542
+
543
+ enc = HEncLayer(chin_z, chout_z,
544
+ dconv=dconv_mode & 1, context=context_enc, **kw)
545
+ if hybrid and freq:
546
+ tenc = HEncLayer(chin, chout, dconv=dconv_mode & 1, context=context_enc,
547
+ empty=last_freq, **kwt)
548
+ self.tencoder.append(tenc)
549
+
550
+ if multi:
551
+ enc = MultiWrap(enc, multi_freqs)
552
+ self.encoder.append(enc)
553
+ if index == 0:
554
+ chin = self.audio_channels * len(self.sources)
555
+ chin_z = chin
556
+ if self.cac:
557
+ chin_z *= 2
558
+ dec = HDecLayer(chout_z, chin_z, dconv=dconv_mode & 2,
559
+ last=index == 0, context=context, **kw_dec)
560
+ if multi:
561
+ dec = MultiWrap(dec, multi_freqs)
562
+ if hybrid and freq:
563
+ tdec = HDecLayer(chout, chin, dconv=dconv_mode & 2, empty=last_freq,
564
+ last=index == 0, context=context, **kwt)
565
+ self.tdecoder.insert(0, tdec)
566
+ self.decoder.insert(0, dec)
567
+
568
+ chin = chout
569
+ chin_z = chout_z
570
+ chout = int(growth * chout)
571
+ chout_z = int(growth * chout_z)
572
+ if freq:
573
+ if freqs <= kernel_size:
574
+ freqs = 1
575
+ else:
576
+ freqs //= stride
577
+ if index == 0 and freq_emb:
578
+ self.freq_emb = ScaledEmbedding(
579
+ freqs, chin_z, smooth=emb_smooth, scale=emb_scale)
580
+ self.freq_emb_scale = freq_emb
581
+
582
+ if rescale:
583
+ rescale_module(self, reference=rescale)
584
+
585
+ def _spec(self, x):
586
+ hl = self.hop_length
587
+ nfft = self.nfft
588
+ x0 = x # noqa
589
+
590
+ if self.hybrid:
591
+ # We re-pad the signal in order to keep the property
592
+ # that the size of the output is exactly the size of the input
593
+ # divided by the stride (here hop_length), when divisible.
594
+ # This is achieved by padding by 1/4th of the kernel size (here nfft).
595
+ # which is not supported by torch.stft.
596
+ # Having all convolution operations follow this convention allow to easily
597
+ # align the time and frequency branches later on.
598
+ assert hl == nfft // 4
599
+ le = int(math.ceil(x.shape[-1] / hl))
600
+ pad = hl // 2 * 3
601
+ if not self.hybrid_old:
602
+ x = pad1d(x, (pad, pad + le * hl - x.shape[-1]), mode='reflect')
603
+ else:
604
+ x = pad1d(x, (pad, pad + le * hl - x.shape[-1]))
605
+
606
+ z = spectro(x, nfft, hl)[..., :-1, :]
607
+ if self.hybrid:
608
+ assert z.shape[-1] == le + 4, (z.shape, x.shape, le)
609
+ z = z[..., 2:2+le]
610
+ return z
611
+
612
+ def _ispec(self, z, length=None, scale=0):
613
+ hl = self.hop_length // (4 ** scale)
614
+ z = F.pad(z, (0, 0, 0, 1))
615
+ if self.hybrid:
616
+ z = F.pad(z, (2, 2))
617
+ pad = hl // 2 * 3
618
+ if not self.hybrid_old:
619
+ le = hl * int(math.ceil(length / hl)) + 2 * pad
620
+ else:
621
+ le = hl * int(math.ceil(length / hl))
622
+ x = ispectro(z, hl, length=le)
623
+ if not self.hybrid_old:
624
+ x = x[..., pad:pad + length]
625
+ else:
626
+ x = x[..., :length]
627
+ else:
628
+ x = ispectro(z, hl, length)
629
+ return x
630
+
631
+ def _magnitude(self, z):
632
+ # return the magnitude of the spectrogram, except when cac is True,
633
+ # in which case we just move the complex dimension to the channel one.
634
+ if self.cac:
635
+ B, C, Fr, T = z.shape
636
+ m = torch.view_as_real(z).permute(0, 1, 4, 2, 3)
637
+ m = m.reshape(B, C * 2, Fr, T)
638
+ else:
639
+ m = z.abs()
640
+ return m
641
+
642
+ def _mask(self, z, m):
643
+ # Apply masking given the mixture spectrogram `z` and the estimated mask `m`.
644
+ # If `cac` is True, `m` is actually a full spectrogram and `z` is ignored.
645
+ niters = self.wiener_iters
646
+ if self.cac:
647
+ B, S, C, Fr, T = m.shape
648
+ out = m.view(B, S, -1, 2, Fr, T).permute(0, 1, 2, 4, 5, 3)
649
+ out = torch.view_as_complex(out.contiguous())
650
+ return out
651
+ if self.training:
652
+ niters = self.end_iters
653
+ if niters < 0:
654
+ z = z[:, None]
655
+ return z / (1e-8 + z.abs()) * m
656
+ else:
657
+ return self._wiener(m, z, niters)
658
+
659
+ def _wiener(self, mag_out, mix_stft, niters):
660
+ # apply wiener filtering from OpenUnmix.
661
+ init = mix_stft.dtype
662
+ wiener_win_len = 300
663
+ residual = self.wiener_residual
664
+
665
+ B, S, C, Fq, T = mag_out.shape
666
+ mag_out = mag_out.permute(0, 4, 3, 2, 1)
667
+ mix_stft = torch.view_as_real(mix_stft.permute(0, 3, 2, 1))
668
+
669
+ outs = []
670
+ for sample in range(B):
671
+ pos = 0
672
+ out = []
673
+ for pos in range(0, T, wiener_win_len):
674
+ frame = slice(pos, pos + wiener_win_len)
675
+ z_out = wiener(
676
+ mag_out[sample, frame], mix_stft[sample, frame], niters,
677
+ residual=residual)
678
+ out.append(z_out.transpose(-1, -2))
679
+ outs.append(torch.cat(out, dim=0))
680
+ out = torch.view_as_complex(torch.stack(outs, 0))
681
+ out = out.permute(0, 4, 3, 2, 1).contiguous()
682
+ if residual:
683
+ out = out[:, :-1]
684
+ assert list(out.shape) == [B, S, C, Fq, T]
685
+ return out.to(init)
686
+
687
+ def forward(self, mix):
688
+ x = mix
689
+ length = x.shape[-1]
690
+
691
+ z = self._spec(mix)
692
+ mag = self._magnitude(z).to(mix.device)
693
+ x = mag
694
+
695
+ B, C, Fq, T = x.shape
696
+
697
+ # unlike previous Demucs, we always normalize because it is easier.
698
+ mean = x.mean(dim=(1, 2, 3), keepdim=True)
699
+ std = x.std(dim=(1, 2, 3), keepdim=True)
700
+ x = (x - mean) / (1e-5 + std)
701
+ # x will be the freq. branch input.
702
+
703
+ if self.hybrid:
704
+ # Prepare the time branch input.
705
+ xt = mix
706
+ meant = xt.mean(dim=(1, 2), keepdim=True)
707
+ stdt = xt.std(dim=(1, 2), keepdim=True)
708
+ xt = (xt - meant) / (1e-5 + stdt)
709
+
710
+ # okay, this is a giant mess I know...
711
+ saved = [] # skip connections, freq.
712
+ saved_t = [] # skip connections, time.
713
+ lengths = [] # saved lengths to properly remove padding, freq branch.
714
+ lengths_t = [] # saved lengths for time branch.
715
+ for idx, encode in enumerate(self.encoder):
716
+ lengths.append(x.shape[-1])
717
+ inject = None
718
+ if self.hybrid and idx < len(self.tencoder):
719
+ # we have not yet merged branches.
720
+ lengths_t.append(xt.shape[-1])
721
+ tenc = self.tencoder[idx]
722
+ xt = tenc(xt)
723
+ if not tenc.empty:
724
+ # save for skip connection
725
+ saved_t.append(xt)
726
+ else:
727
+ # tenc contains just the first conv., so that now time and freq.
728
+ # branches have the same shape and can be merged.
729
+ inject = xt
730
+ x = encode(x, inject)
731
+ if idx == 0 and self.freq_emb is not None:
732
+ # add frequency embedding to allow for non equivariant convolutions
733
+ # over the frequency axis.
734
+ frs = torch.arange(x.shape[-2], device=x.device)
735
+ emb = self.freq_emb(frs).t()[None, :, :, None].expand_as(x)
736
+ x = x + self.freq_emb_scale * emb
737
+
738
+ saved.append(x)
739
+
740
+ x = torch.zeros_like(x)
741
+ if self.hybrid:
742
+ xt = torch.zeros_like(x)
743
+ # initialize everything to zero (signal will go through u-net skips).
744
+
745
+ for idx, decode in enumerate(self.decoder):
746
+ skip = saved.pop(-1)
747
+ x, pre = decode(x, skip, lengths.pop(-1))
748
+ # `pre` contains the output just before final transposed convolution,
749
+ # which is used when the freq. and time branch separate.
750
+
751
+ if self.hybrid:
752
+ offset = self.depth - len(self.tdecoder)
753
+ if self.hybrid and idx >= offset:
754
+ tdec = self.tdecoder[idx - offset]
755
+ length_t = lengths_t.pop(-1)
756
+ if tdec.empty:
757
+ assert pre.shape[2] == 1, pre.shape
758
+ pre = pre[:, :, 0]
759
+ xt, _ = tdec(pre, None, length_t)
760
+ else:
761
+ skip = saved_t.pop(-1)
762
+ xt, _ = tdec(xt, skip, length_t)
763
+
764
+ # Let's make sure we used all stored skip connections.
765
+ assert len(saved) == 0
766
+ assert len(lengths_t) == 0
767
+ assert len(saved_t) == 0
768
+
769
+ S = len(self.sources)
770
+ x = x.view(B, S, -1, Fq, T)
771
+ x = x * std[:, None] + mean[:, None]
772
+
773
+ # to cpu as non-cuda GPUs don't support complex numbers
774
+ # demucs issue #435 ##432
775
+ # NOTE: in this case z already is on cpu
776
+ # TODO: remove this when mps supports complex numbers
777
+
778
+ device_type = x.device.type
779
+ device_load = f"{device_type}:{x.device.index}" if not device_type == 'mps' else device_type
780
+ x_is_other_gpu = not device_type in ["cuda", "cpu"]
781
+
782
+ if x_is_other_gpu:
783
+ x = x.cpu()
784
+
785
+ zout = self._mask(z, x)
786
+ x = self._ispec(zout, length)
787
+
788
+ # back to other device
789
+ if x_is_other_gpu:
790
+ x = x.to(device_load)
791
+
792
+ if self.hybrid:
793
+ xt = xt.view(B, S, -1, length)
794
+ xt = xt * stdt[:, None] + meant[:, None]
795
+ x = xt + x
796
+ return x
demucs/htdemucs.py ADDED
@@ -0,0 +1,664 @@
1
+ # Copyright (c) Meta, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # First author is Simon Rouard.
7
+ """
8
+ This code contains the spectrogram and Hybrid version of Demucs.
9
+ """
10
+ import math
11
+
12
+ from .filtering import wiener
13
+ import torch
14
+ from torch import nn
15
+ from torch.nn import functional as F
16
+ from fractions import Fraction
17
+ from einops import rearrange
18
+
19
+ from .transformer import CrossTransformerEncoder
20
+
21
+ from .demucs import rescale_module
22
+ from .states import capture_init
23
+ from .spec import spectro, ispectro
24
+ from .hdemucs import pad1d, ScaledEmbedding, HEncLayer, MultiWrap, HDecLayer
25
+
26
+
27
+ class HTDemucs(nn.Module):
28
+ """
29
+ Spectrogram and hybrid Demucs model.
30
+ The spectrogram model has the same structure as Demucs, except the first few layers are over the
31
+ frequency axis, until there is only 1 frequency, and then it moves to time convolutions.
32
+ Frequency layers can still access information across time steps thanks to the DConv residual.
33
+
34
+ Hybrid models have a parallel time branch. At some layer, the time branch has the same stride
35
+ as the frequency branch and then the two are combined. The opposite happens in the decoder.
36
+
37
+ Models can either use naive iSTFT from masking, Wiener filtering ([Uhlich et al. 2017]),
38
+ or complex as channels (CaC) [Choi et al. 2020]. Wiener filtering is based on
39
+ Open Unmix implementation [Stoter et al. 2019].
40
+
41
+ The loss is always on the temporal domain, by backpropagating through the above
42
+ output methods and iSTFT. This makes it possible to define hybrid models nicely. However, this
43
+ slightly breaks Wiener filtering, as doing more iterations at test time will change the spectrogram
44
+ contribution, without changing the one from the waveform, which will lead to worse performance.
45
+ I tried using the residual option in OpenUnmix Wiener implementation, but it didn't improve.
46
+ CaC on the other hand provides similar performance for hybrid, and works naturally with
47
+ hybrid models.
48
+
49
+ This model also uses frequency embeddings to improve the efficiency of convolutions
50
+ over the freq. axis, following [Isik et al. 2020] (https://arxiv.org/pdf/2008.04470.pdf).
51
+
52
+ Unlike classic Demucs, there is no resampling here, and normalization is always applied.
53
+ """
54
+
55
+ @capture_init
56
+ def __init__(
57
+ self,
58
+ sources,
59
+ # Channels
60
+ audio_channels=2,
61
+ channels=48,
62
+ channels_time=None,
63
+ growth=2,
64
+ # STFT
65
+ nfft=4096,
66
+ wiener_iters=0,
67
+ end_iters=0,
68
+ wiener_residual=False,
69
+ cac=True,
70
+ # Main structure
71
+ depth=4,
72
+ rewrite=True,
73
+ # Frequency branch
74
+ multi_freqs=None,
75
+ multi_freqs_depth=3,
76
+ freq_emb=0.2,
77
+ emb_scale=10,
78
+ emb_smooth=True,
79
+ # Convolutions
80
+ kernel_size=8,
81
+ time_stride=2,
82
+ stride=4,
83
+ context=1,
84
+ context_enc=0,
85
+ # Normalization
86
+ norm_starts=4,
87
+ norm_groups=4,
88
+ # DConv residual branch
89
+ dconv_mode=1,
90
+ dconv_depth=2,
91
+ dconv_comp=8,
92
+ dconv_init=1e-3,
93
+ # Before the Transformer
94
+ bottom_channels=0,
95
+ # Transformer
96
+ t_layers=5,
97
+ t_emb="sin",
98
+ t_hidden_scale=4.0,
99
+ t_heads=8,
100
+ t_dropout=0.0,
101
+ t_max_positions=10000,
102
+ t_norm_in=True,
103
+ t_norm_in_group=False,
104
+ t_group_norm=False,
105
+ t_norm_first=True,
106
+ t_norm_out=True,
107
+ t_max_period=10000.0,
108
+ t_weight_decay=0.0,
109
+ t_lr=None,
110
+ t_layer_scale=True,
111
+ t_gelu=True,
112
+ t_weight_pos_embed=1.0,
113
+ t_sin_random_shift=0,
114
+ t_cape_mean_normalize=True,
115
+ t_cape_augment=True,
116
+ t_cape_glob_loc_scale=[5000.0, 1.0, 1.4],
117
+ t_sparse_self_attn=False,
118
+ t_sparse_cross_attn=False,
119
+ t_mask_type="diag",
120
+ t_mask_random_seed=42,
121
+ t_sparse_attn_window=500,
122
+ t_global_window=100,
123
+ t_sparsity=0.95,
124
+ t_auto_sparsity=False,
125
+ # ------ Particuliar parameters
126
+ t_cross_first=False,
127
+ # Weight init
128
+ rescale=0.1,
129
+ # Metadata
130
+ samplerate=44100,
131
+ segment=10,
132
+ use_train_segment=True,
133
+ ):
134
+ """
135
+ Args:
136
+ sources (list[str]): list of source names.
137
+ audio_channels (int): input/output audio channels.
138
+ channels (int): initial number of hidden channels.
139
+ channels_time: if not None, use a different `channels` value for the time branch.
140
+ growth: increase the number of hidden channels by this factor at each layer.
141
+ nfft: number of fft bins. Note that changing this requires careful computation of
142
+ various shape parameters and will not work out of the box for hybrid models.
143
+ wiener_iters: when using Wiener filtering, number of iterations at test time.
144
+ end_iters: same but at train time. For a hybrid model, must be equal to `wiener_iters`.
145
+ wiener_residual: add residual source before wiener filtering.
146
+ cac: uses complex as channels, i.e. complex numbers are 2 channels each
147
+ in input and output. no further processing is done before ISTFT.
148
+ depth (int): number of layers in the encoder and in the decoder.
149
+ rewrite (bool): add 1x1 convolution to each layer.
150
+ multi_freqs: list of frequency ratios for splitting frequency bands with `MultiWrap`.
151
+ multi_freqs_depth: how many layers to wrap with `MultiWrap`. Only the outermost
152
+ layers will be wrapped.
153
+ freq_emb: add frequency embedding after the first frequency layer if > 0,
154
+ the actual value controls the weight of the embedding.
155
+ emb_scale: equivalent to scaling the embedding learning rate
156
+ emb_smooth: initialize the embedding with a smooth one (with respect to frequencies).
157
+ kernel_size: kernel_size for encoder and decoder layers.
158
+ stride: stride for encoder and decoder layers.
159
+ time_stride: stride for the final time layer, after the merge.
160
+ context: context for 1x1 conv in the decoder.
161
+ context_enc: context for 1x1 conv in the encoder.
162
+ norm_starts: layer at which group norm starts being used.
163
+ decoder layers are numbered in reverse order.
164
+ norm_groups: number of groups for group norm.
165
+ dconv_mode: if 1: dconv in encoder only, 2: decoder only, 3: both.
166
+ dconv_depth: depth of residual DConv branch.
167
+ dconv_comp: compression of DConv branch.
168
+ dconv_attn: adds attention layers in DConv branch starting at this layer.
169
+ dconv_lstm: adds a LSTM layer in DConv branch starting at this layer.
170
+ dconv_init: initial scale for the DConv branch LayerScale.
171
+ bottom_channels: if >0 it adds a linear layer (1x1 Conv) before and after the
172
+ transformer in order to change the number of channels
173
+ t_layers: number of layers in each branch (waveform and spec) of the transformer
174
+ t_emb: "sin", "cape" or "scaled"
175
+ t_hidden_scale: the hidden scale of the Feedforward parts of the transformer
176
+ for instance if C = 384 (the number of channels in the transformer) and
177
+ t_hidden_scale = 4.0 then the intermediate layer of the FFN has dimension
178
+ 384 * 4 = 1536
179
+ t_heads: number of heads for the transformer
180
+ t_dropout: dropout in the transformer
181
+ t_max_positions: max_positions for the "scaled" positional embedding, only
182
+ useful if t_emb="scaled"
183
+ t_norm_in: (bool) norm before adding the positional embedding and getting into the
184
+ transformer layers
185
+ t_norm_in_group: (bool) if True while t_norm_in=True, the norm is on all the
186
+ timesteps (GroupNorm with group=1)
187
+ t_group_norm: (bool) if True, the norms of the Encoder Layers are on all the
188
+ timesteps (GroupNorm with group=1)
189
+ t_norm_first: (bool) if True the norm is before the attention and before the FFN
190
+ t_norm_out: (bool) if True, there is a GroupNorm (group=1) at the end of each layer
191
+ t_max_period: (float) denominator in the sinusoidal embedding expression
192
+ t_weight_decay: (float) weight decay for the transformer
193
+ t_lr: (float) specific learning rate for the transformer
194
+ t_layer_scale: (bool) Layer Scale for the transformer
195
+ t_gelu: (bool) activations of the transformer are GeLU if True, ReLU else
196
+ t_weight_pos_embed: (float) weighting of the positional embedding
197
+ t_cape_mean_normalize: (bool) if t_emb="cape", normalisation of positional embeddings
198
+ see: https://arxiv.org/abs/2106.03143
199
+ t_cape_augment: (bool) if t_emb="cape", must be True during training and False
200
+ during the inference, see: https://arxiv.org/abs/2106.03143
201
+ t_cape_glob_loc_scale: (list of 3 floats) if t_emb="cape", CAPE parameters
202
+ see: https://arxiv.org/abs/2106.03143
203
+ t_sparse_self_attn: (bool) if True, the self attentions are sparse
204
+ t_sparse_cross_attn: (bool) if True, the cross-attentions are sparse (don't use it
205
+ unless you designed really specific masks)
206
+ t_mask_type: (str) can be "diag", "jmask", "random", "global" or any combination
207
+ with '_' between: i.e. "diag_jmask_random" (note that this is permutation
208
+ invariant i.e. "diag_jmask_random" is equivalent to "jmask_random_diag")
209
+ t_mask_random_seed: (int) if "random" is in t_mask_type, controls the seed
210
+ that generated the random part of the mask
211
+ t_sparse_attn_window: (int) if "diag" is in t_mask_type, for a query (i), and
212
+ a key (j), the mask is True if |i-j|<=t_sparse_attn_window
213
+ t_global_window: (int) if "global" is in t_mask_type, mask[:t_global_window, :]
214
+ and mask[:, :t_global_window] will be True
215
+ t_sparsity: (float) if "random" is in t_mask_type, t_sparsity is the sparsity
216
+ level of the random part of the mask.
217
+ t_cross_first: (bool) if True cross attention is the first layer of the
218
+ transformer (False seems to be better)
219
+ rescale: weight rescaling trick
220
+ use_train_segment: (bool) if True, the actual size that is used during the
221
+ training is used during inference.
222
+ """
223
+ super().__init__()
224
+ self.cac = cac
225
+ self.wiener_residual = wiener_residual
226
+ self.audio_channels = audio_channels
227
+ self.sources = sources
228
+ self.kernel_size = kernel_size
229
+ self.context = context
230
+ self.stride = stride
231
+ self.depth = depth
232
+ self.bottom_channels = bottom_channels
233
+ self.channels = channels
234
+ self.samplerate = samplerate
235
+ self.segment = segment
236
+ self.use_train_segment = use_train_segment
237
+ self.nfft = nfft
238
+ self.hop_length = nfft // 4
239
+ self.wiener_iters = wiener_iters
240
+ self.end_iters = end_iters
241
+ self.freq_emb = None
242
+ assert wiener_iters == end_iters
243
+
244
+ self.encoder = nn.ModuleList()
245
+ self.decoder = nn.ModuleList()
246
+
247
+ self.tencoder = nn.ModuleList()
248
+ self.tdecoder = nn.ModuleList()
249
+
250
+ chin = audio_channels
251
+ chin_z = chin # number of channels for the freq branch
252
+ if self.cac:
253
+ chin_z *= 2
254
+ chout = channels_time or channels
255
+ chout_z = channels
256
+ freqs = nfft // 2
257
+
258
+ for index in range(depth):
259
+ norm = index >= norm_starts
260
+ freq = freqs > 1
261
+ stri = stride
262
+ ker = kernel_size
263
+ if not freq:
264
+ assert freqs == 1
265
+ ker = time_stride * 2
266
+ stri = time_stride
267
+
268
+ pad = True
269
+ last_freq = False
270
+ if freq and freqs <= kernel_size:
271
+ ker = freqs
272
+ pad = False
273
+ last_freq = True
274
+
275
+ kw = {
276
+ "kernel_size": ker,
277
+ "stride": stri,
278
+ "freq": freq,
279
+ "pad": pad,
280
+ "norm": norm,
281
+ "rewrite": rewrite,
282
+ "norm_groups": norm_groups,
283
+ "dconv_kw": {
284
+ "depth": dconv_depth,
285
+ "compress": dconv_comp,
286
+ "init": dconv_init,
287
+ "gelu": True,
288
+ },
289
+ }
290
+ kwt = dict(kw)
291
+ kwt["freq"] = 0
292
+ kwt["kernel_size"] = kernel_size
293
+ kwt["stride"] = stride
294
+ kwt["pad"] = True
295
+ kw_dec = dict(kw)
296
+ multi = False
297
+ if multi_freqs and index < multi_freqs_depth:
298
+ multi = True
299
+ kw_dec["context_freq"] = False
300
+
301
+ if last_freq:
302
+ chout_z = max(chout, chout_z)
303
+ chout = chout_z
304
+
305
+ enc = HEncLayer(
306
+ chin_z, chout_z, dconv=dconv_mode & 1, context=context_enc, **kw
307
+ )
308
+ if freq:
309
+ tenc = HEncLayer(
310
+ chin,
311
+ chout,
312
+ dconv=dconv_mode & 1,
313
+ context=context_enc,
314
+ empty=last_freq,
315
+ **kwt
316
+ )
317
+ self.tencoder.append(tenc)
318
+
319
+ if multi:
320
+ enc = MultiWrap(enc, multi_freqs)
321
+ self.encoder.append(enc)
322
+ if index == 0:
323
+ chin = self.audio_channels * len(self.sources)
324
+ chin_z = chin
325
+ if self.cac:
326
+ chin_z *= 2
327
+ dec = HDecLayer(
328
+ chout_z,
329
+ chin_z,
330
+ dconv=dconv_mode & 2,
331
+ last=index == 0,
332
+ context=context,
333
+ **kw_dec
334
+ )
335
+ if multi:
336
+ dec = MultiWrap(dec, multi_freqs)
337
+ if freq:
338
+ tdec = HDecLayer(
339
+ chout,
340
+ chin,
341
+ dconv=dconv_mode & 2,
342
+ empty=last_freq,
343
+ last=index == 0,
344
+ context=context,
345
+ **kwt
346
+ )
347
+ self.tdecoder.insert(0, tdec)
348
+ self.decoder.insert(0, dec)
349
+
350
+ chin = chout
351
+ chin_z = chout_z
352
+ chout = int(growth * chout)
353
+ chout_z = int(growth * chout_z)
354
+ if freq:
355
+ if freqs <= kernel_size:
356
+ freqs = 1
357
+ else:
358
+ freqs //= stride
359
+ if index == 0 and freq_emb:
360
+ self.freq_emb = ScaledEmbedding(
361
+ freqs, chin_z, smooth=emb_smooth, scale=emb_scale
362
+ )
363
+ self.freq_emb_scale = freq_emb
364
+
365
+ if rescale:
366
+ rescale_module(self, reference=rescale)
367
+
368
+ transformer_channels = channels * growth ** (depth - 1)
369
+ if bottom_channels:
370
+ self.channel_upsampler = nn.Conv1d(transformer_channels, bottom_channels, 1)
371
+ self.channel_downsampler = nn.Conv1d(
372
+ bottom_channels, transformer_channels, 1
373
+ )
374
+ self.channel_upsampler_t = nn.Conv1d(
375
+ transformer_channels, bottom_channels, 1
376
+ )
377
+ self.channel_downsampler_t = nn.Conv1d(
378
+ bottom_channels, transformer_channels, 1
379
+ )
380
+
381
+ transformer_channels = bottom_channels
382
+
383
+ if t_layers > 0:
384
+ self.crosstransformer = CrossTransformerEncoder(
385
+ dim=transformer_channels,
386
+ emb=t_emb,
387
+ hidden_scale=t_hidden_scale,
388
+ num_heads=t_heads,
389
+ num_layers=t_layers,
390
+ cross_first=t_cross_first,
391
+ dropout=t_dropout,
392
+ max_positions=t_max_positions,
393
+ norm_in=t_norm_in,
394
+ norm_in_group=t_norm_in_group,
395
+ group_norm=t_group_norm,
396
+ norm_first=t_norm_first,
397
+ norm_out=t_norm_out,
398
+ max_period=t_max_period,
399
+ weight_decay=t_weight_decay,
400
+ lr=t_lr,
401
+ layer_scale=t_layer_scale,
402
+ gelu=t_gelu,
403
+ sin_random_shift=t_sin_random_shift,
404
+ weight_pos_embed=t_weight_pos_embed,
405
+ cape_mean_normalize=t_cape_mean_normalize,
406
+ cape_augment=t_cape_augment,
407
+ cape_glob_loc_scale=t_cape_glob_loc_scale,
408
+ sparse_self_attn=t_sparse_self_attn,
409
+ sparse_cross_attn=t_sparse_cross_attn,
410
+ mask_type=t_mask_type,
411
+ mask_random_seed=t_mask_random_seed,
412
+ sparse_attn_window=t_sparse_attn_window,
413
+ global_window=t_global_window,
414
+ sparsity=t_sparsity,
415
+ auto_sparsity=t_auto_sparsity,
416
+ )
417
+ else:
418
+ self.crosstransformer = None
419
+
420
+ def _spec(self, x):
421
+ hl = self.hop_length
422
+ nfft = self.nfft
423
+ x0 = x # noqa
424
+
425
+ # We re-pad the signal in order to keep the property
426
+ # that the size of the output is exactly the size of the input
427
+ # divided by the stride (here hop_length), when divisible.
428
+ # This is achieved by padding by 1/4th of the kernel size (here nfft),
429
+ # which is not supported by torch.stft.
430
+ # Having all convolution operations follow this convention allows us to easily
431
+ # align the time and frequency branches later on.
432
+ assert hl == nfft // 4
433
+ le = int(math.ceil(x.shape[-1] / hl))
434
+ pad = hl // 2 * 3
435
+ x = pad1d(x, (pad, pad + le * hl - x.shape[-1]), mode="reflect")
436
+
437
+ z = spectro(x, nfft, hl)[..., :-1, :]
438
+ assert z.shape[-1] == le + 4, (z.shape, x.shape, le)
439
+ z = z[..., 2: 2 + le]
440
+ return z
441
+
442
+ def _ispec(self, z, length=None, scale=0):
443
+ hl = self.hop_length // (4**scale)
444
+ z = F.pad(z, (0, 0, 0, 1))
445
+ z = F.pad(z, (2, 2))
446
+ pad = hl // 2 * 3
447
+ le = hl * int(math.ceil(length / hl)) + 2 * pad
448
+ x = ispectro(z, hl, length=le)
449
+ x = x[..., pad: pad + length]
450
+ return x
451
+
452
+ def _magnitude(self, z):
453
+ # return the magnitude of the spectrogram, except when cac is True,
454
+ # in which case we just move the complex dimension to the channel one.
455
+ if self.cac:
456
+ B, C, Fr, T = z.shape
457
+ m = torch.view_as_real(z).permute(0, 1, 4, 2, 3)
458
+ m = m.reshape(B, C * 2, Fr, T)
459
+ else:
460
+ m = z.abs()
461
+ return m
462
+
463
+ def _mask(self, z, m):
464
+ # Apply masking given the mixture spectrogram `z` and the estimated mask `m`.
465
+ # If `cac` is True, `m` is actually a full spectrogram and `z` is ignored.
466
+ niters = self.wiener_iters
467
+ if self.cac:
468
+ B, S, C, Fr, T = m.shape
469
+ out = m.view(B, S, -1, 2, Fr, T).permute(0, 1, 2, 4, 5, 3)
470
+ out = torch.view_as_complex(out.contiguous())
471
+ return out
472
+ if self.training:
473
+ niters = self.end_iters
474
+ if niters < 0:
475
+ z = z[:, None]
476
+ return z / (1e-8 + z.abs()) * m
477
+ else:
478
+ return self._wiener(m, z, niters)
479
+
480
+ def _wiener(self, mag_out, mix_stft, niters):
481
+ # apply wiener filtering from OpenUnmix.
482
+ init = mix_stft.dtype
483
+ wiener_win_len = 300
484
+ residual = self.wiener_residual
485
+
486
+ B, S, C, Fq, T = mag_out.shape
487
+ mag_out = mag_out.permute(0, 4, 3, 2, 1)
488
+ mix_stft = torch.view_as_real(mix_stft.permute(0, 3, 2, 1))
489
+
490
+ outs = []
491
+ for sample in range(B):
492
+ pos = 0
493
+ out = []
494
+ for pos in range(0, T, wiener_win_len):
495
+ frame = slice(pos, pos + wiener_win_len)
496
+ z_out = wiener(
497
+ mag_out[sample, frame],
498
+ mix_stft[sample, frame],
499
+ niters,
500
+ residual=residual,
501
+ )
502
+ out.append(z_out.transpose(-1, -2))
503
+ outs.append(torch.cat(out, dim=0))
504
+ out = torch.view_as_complex(torch.stack(outs, 0))
505
+ out = out.permute(0, 4, 3, 2, 1).contiguous()
506
+ if residual:
507
+ out = out[:, :-1]
508
+ assert list(out.shape) == [B, S, C, Fq, T]
509
+ return out.to(init)
510
+
511
+ def valid_length(self, length: int):
512
+ """
513
+ Return a length that is appropriate for evaluation.
514
+ In our case, always return the training length, unless
515
+ it is smaller than the given length, in which case this
516
+ raises an error.
517
+ """
518
+ if not self.use_train_segment:
519
+ return length
520
+ training_length = int(self.segment * self.samplerate)
521
+ if training_length < length:
522
+ raise ValueError(
523
+ f"Given length {length} is longer than "
524
+ f"training length {training_length}")
525
+ return training_length
526
+
527
+ def forward(self, mix):
528
+ length = mix.shape[-1]
529
+ length_pre_pad = None
530
+ if self.use_train_segment:
531
+ if self.training:
532
+ self.segment = Fraction(mix.shape[-1], self.samplerate)
533
+ else:
534
+ training_length = int(self.segment * self.samplerate)
535
+ if mix.shape[-1] < training_length:
536
+ length_pre_pad = mix.shape[-1]
537
+ mix = F.pad(mix, (0, training_length - length_pre_pad))
538
+ z = self._spec(mix)
539
+ mag = self._magnitude(z).to(mix.device)
540
+ x = mag
541
+
542
+ B, C, Fq, T = x.shape
543
+
544
+ # unlike previous Demucs, we always normalize because it is easier.
545
+ mean = x.mean(dim=(1, 2, 3), keepdim=True)
546
+ std = x.std(dim=(1, 2, 3), keepdim=True)
547
+ x = (x - mean) / (1e-5 + std)
548
+ # x will be the freq. branch input.
549
+
550
+ # Prepare the time branch input.
551
+ xt = mix
552
+ meant = xt.mean(dim=(1, 2), keepdim=True)
553
+ stdt = xt.std(dim=(1, 2), keepdim=True)
554
+ xt = (xt - meant) / (1e-5 + stdt)
555
+
556
+ # okay, this is a giant mess I know...
557
+ saved = [] # skip connections, freq.
558
+ saved_t = [] # skip connections, time.
559
+ lengths = [] # saved lengths to properly remove padding, freq branch.
560
+ lengths_t = [] # saved lengths for time branch.
561
+ for idx, encode in enumerate(self.encoder):
562
+ lengths.append(x.shape[-1])
563
+ inject = None
564
+ if idx < len(self.tencoder):
565
+ # we have not yet merged branches.
566
+ lengths_t.append(xt.shape[-1])
567
+ tenc = self.tencoder[idx]
568
+ xt = tenc(xt)
569
+ if not tenc.empty:
570
+ # save for skip connection
571
+ saved_t.append(xt)
572
+ else:
573
+ # tenc contains just the first conv., so that now time and freq.
574
+ # branches have the same shape and can be merged.
575
+ inject = xt
576
+ x = encode(x, inject)
577
+ if idx == 0 and self.freq_emb is not None:
578
+ # add frequency embedding to allow for non equivariant convolutions
579
+ # over the frequency axis.
580
+ frs = torch.arange(x.shape[-2], device=x.device)
581
+ emb = self.freq_emb(frs).t()[None, :, :, None].expand_as(x)
582
+ x = x + self.freq_emb_scale * emb
583
+
584
+ saved.append(x)
585
+ if self.crosstransformer:
586
+ if self.bottom_channels:
587
+ b, c, f, t = x.shape
588
+ x = rearrange(x, "b c f t-> b c (f t)")
589
+ x = self.channel_upsampler(x)
590
+ x = rearrange(x, "b c (f t)-> b c f t", f=f)
591
+ xt = self.channel_upsampler_t(xt)
592
+
593
+ x, xt = self.crosstransformer(x, xt)
594
+
595
+ if self.bottom_channels:
596
+ x = rearrange(x, "b c f t-> b c (f t)")
597
+ x = self.channel_downsampler(x)
598
+ x = rearrange(x, "b c (f t)-> b c f t", f=f)
599
+ xt = self.channel_downsampler_t(xt)
600
+
601
+ for idx, decode in enumerate(self.decoder):
602
+ skip = saved.pop(-1)
603
+ x, pre = decode(x, skip, lengths.pop(-1))
604
+ # `pre` contains the output just before final transposed convolution,
605
+ # which is used when the freq. and time branch separate.
606
+
607
+ offset = self.depth - len(self.tdecoder)
608
+ if idx >= offset:
609
+ tdec = self.tdecoder[idx - offset]
610
+ length_t = lengths_t.pop(-1)
611
+ if tdec.empty:
612
+ assert pre.shape[2] == 1, pre.shape
613
+ pre = pre[:, :, 0]
614
+ xt, _ = tdec(pre, None, length_t)
615
+ else:
616
+ skip = saved_t.pop(-1)
617
+ xt, _ = tdec(xt, skip, length_t)
618
+
619
+ # Let's make sure we used all stored skip connections.
620
+ assert len(saved) == 0
621
+ assert len(lengths_t) == 0
622
+ assert len(saved_t) == 0
623
+
624
+ S = len(self.sources)
625
+ x = x.view(B, S, -1, Fq, T)
626
+ x = x * std[:, None] + mean[:, None]
627
+
628
+ # to cpu as non-cuda GPUs don't support complex numbers
629
+ # demucs issues #435 and #432
630
+ # NOTE: in this case z already is on cpu
631
+ # TODO: remove this when mps supports complex numbers
632
+
633
+ device_type = x.device.type
634
+ device_load = f"{device_type}:{x.device.index}" if device_type != 'mps' else device_type
635
+ x_is_other_gpu = device_type not in ["cuda", "cpu"]
636
+
637
+ if x_is_other_gpu:
638
+ x = x.cpu()
639
+
640
+ zout = self._mask(z, x)
641
+ if self.use_train_segment:
642
+ if self.training:
643
+ x = self._ispec(zout, length)
644
+ else:
645
+ x = self._ispec(zout, training_length)
646
+ else:
647
+ x = self._ispec(zout, length)
648
+
649
+ # back to other device
650
+ if x_is_other_gpu:
651
+ x = x.to(device_load)
652
+
653
+ if self.use_train_segment:
654
+ if self.training:
655
+ xt = xt.view(B, S, -1, length)
656
+ else:
657
+ xt = xt.view(B, S, -1, training_length)
658
+ else:
659
+ xt = xt.view(B, S, -1, length)
660
+ xt = xt * stdt[:, None] + meant[:, None]
661
+ x = xt + x
662
+ if length_pre_pad:
663
+ x = x[..., :length_pre_pad]
664
+ return x
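A minimal sketch (not part of the committed code) of the complex-as-channels (`cac=True`) representation used by `_magnitude` and `_mask` above; the spectrogram tensor and its shape are illustrative:

import torch

z = torch.randn(1, 2, 512, 100, dtype=torch.complex64)  # hypothetical [B, C, Fr, T] mixture spectrogram
B, C, Fr, T = z.shape
# forward: real/imag become extra channels, as in _magnitude
m = torch.view_as_real(z).permute(0, 1, 4, 2, 3).reshape(B, C * 2, Fr, T)
# inverse: fold the channel pairs back into a complex tensor, as in _mask
back = torch.view_as_complex(m.view(B, C, 2, Fr, T).permute(0, 1, 3, 4, 2).contiguous())
assert torch.equal(back, z)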
demucs/model.py ADDED
@@ -0,0 +1,218 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+
9
+ import torch as th
10
+ from torch import nn
11
+
12
+ from .utils import capture_init, center_trim
13
+
14
+
15
+ class BLSTM(nn.Module):
16
+ def __init__(self, dim, layers=1):
17
+ super().__init__()
18
+ self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim)
19
+ self.linear = nn.Linear(2 * dim, dim)
20
+
21
+ def forward(self, x):
22
+ x = x.permute(2, 0, 1)
23
+ x = self.lstm(x)[0]
24
+ x = self.linear(x)
25
+ x = x.permute(1, 2, 0)
26
+ return x
27
+
28
+
29
+ def rescale_conv(conv, reference):
30
+ std = conv.weight.std().detach()
31
+ scale = (std / reference)**0.5
32
+ conv.weight.data /= scale
33
+ if conv.bias is not None:
34
+ conv.bias.data /= scale
35
+
36
+
37
+ def rescale_module(module, reference):
38
+ for sub in module.modules():
39
+ if isinstance(sub, (nn.Conv1d, nn.ConvTranspose1d)):
40
+ rescale_conv(sub, reference)
41
+
42
+
43
+ def upsample(x, stride):
44
+ """
45
+ Linear upsampling, the output will be `stride` times longer.
46
+ """
47
+ batch, channels, time = x.size()
48
+ weight = th.arange(stride, device=x.device, dtype=th.float) / stride
49
+ x = x.view(batch, channels, time, 1)
50
+ out = x[..., :-1, :] * (1 - weight) + x[..., 1:, :] * weight
51
+ return out.reshape(batch, channels, -1)
52
+
53
+
54
+ def downsample(x, stride):
55
+ """
56
+ Downsample x by decimation.
57
+ """
58
+ return x[:, :, ::stride]
59
+
60
+
61
+ class Demucs(nn.Module):
62
+ @capture_init
63
+ def __init__(self,
64
+ sources=4,
65
+ audio_channels=2,
66
+ channels=64,
67
+ depth=6,
68
+ rewrite=True,
69
+ glu=True,
70
+ upsample=False,
71
+ rescale=0.1,
72
+ kernel_size=8,
73
+ stride=4,
74
+ growth=2.,
75
+ lstm_layers=2,
76
+ context=3,
77
+ samplerate=44100):
78
+ """
79
+ Args:
80
+ sources (int): number of sources to separate
81
+ audio_channels (int): stereo or mono
82
+ channels (int): first convolution channels
83
+ depth (int): number of encoder/decoder layers
84
+ rewrite (bool): add 1x1 convolution to each encoder layer
85
+ and a convolution to each decoder layer.
86
+ For the decoder layer, `context` gives the kernel size.
87
+ glu (bool): use glu instead of ReLU
88
+ upsample (bool): use linear upsampling with convolutions
89
+ Wave-U-Net style, instead of transposed convolutions
90
+ rescale (int): rescale initial weights of convolutions
91
+ to get their standard deviation closer to `rescale`
92
+ kernel_size (int): kernel size for convolutions
93
+ stride (int): stride for convolutions
94
+ growth (float): multiply (resp divide) number of channels by that
95
+ for each layer of the encoder (resp decoder)
96
+ lstm_layers (int): number of lstm layers, 0 = no lstm
97
+ context (int): kernel size of the convolution in the
98
+ decoder before the transposed convolution. If > 1,
99
+ will provide some context from neighboring time
100
+ steps.
101
+ """
102
+
103
+ super().__init__()
104
+ self.audio_channels = audio_channels
105
+ self.sources = sources
106
+ self.kernel_size = kernel_size
107
+ self.context = context
108
+ self.stride = stride
109
+ self.depth = depth
110
+ self.upsample = upsample
111
+ self.channels = channels
112
+ self.samplerate = samplerate
113
+
114
+ self.encoder = nn.ModuleList()
115
+ self.decoder = nn.ModuleList()
116
+
117
+ self.final = None
118
+ if upsample:
119
+ self.final = nn.Conv1d(channels + audio_channels, sources * audio_channels, 1)
120
+ stride = 1
121
+
122
+ if glu:
123
+ activation = nn.GLU(dim=1)
124
+ ch_scale = 2
125
+ else:
126
+ activation = nn.ReLU()
127
+ ch_scale = 1
128
+ in_channels = audio_channels
129
+ for index in range(depth):
130
+ encode = []
131
+ encode += [nn.Conv1d(in_channels, channels, kernel_size, stride), nn.ReLU()]
132
+ if rewrite:
133
+ encode += [nn.Conv1d(channels, ch_scale * channels, 1), activation]
134
+ self.encoder.append(nn.Sequential(*encode))
135
+
136
+ decode = []
137
+ if index > 0:
138
+ out_channels = in_channels
139
+ else:
140
+ if upsample:
141
+ out_channels = channels
142
+ else:
143
+ out_channels = sources * audio_channels
144
+ if rewrite:
145
+ decode += [nn.Conv1d(channels, ch_scale * channels, context), activation]
146
+ if upsample:
147
+ decode += [
148
+ nn.Conv1d(channels, out_channels, kernel_size, stride=1),
149
+ ]
150
+ else:
151
+ decode += [nn.ConvTranspose1d(channels, out_channels, kernel_size, stride)]
152
+ if index > 0:
153
+ decode.append(nn.ReLU())
154
+ self.decoder.insert(0, nn.Sequential(*decode))
155
+ in_channels = channels
156
+ channels = int(growth * channels)
157
+
158
+ channels = in_channels
159
+
160
+ if lstm_layers:
161
+ self.lstm = BLSTM(channels, lstm_layers)
162
+ else:
163
+ self.lstm = None
164
+
165
+ if rescale:
166
+ rescale_module(self, reference=rescale)
167
+
168
+ def valid_length(self, length):
169
+ """
170
+ Return the nearest valid length to use with the model so that
171
+ there are no time steps left over in the convolutions, e.g. for all
172
+ layers, size of the input - kernel_size % stride = 0.
173
+
174
+ If the mixture has a valid length, the estimated sources
175
+ will have exactly the same length when context = 1. If context > 1,
176
+ the two signals can be center trimmed to match.
177
+
178
+ For training, extracts should have a valid length. For evaluation
179
+ on full tracks we recommend passing `pad = True` to :method:`forward`.
180
+ """
181
+ for _ in range(self.depth):
182
+ if self.upsample:
183
+ length = math.ceil(length / self.stride) + self.kernel_size - 1
184
+ else:
185
+ length = math.ceil((length - self.kernel_size) / self.stride) + 1
186
+ length = max(1, length)
187
+ length += self.context - 1
188
+ for _ in range(self.depth):
189
+ if self.upsample:
190
+ length = length * self.stride + self.kernel_size - 1
191
+ else:
192
+ length = (length - 1) * self.stride + self.kernel_size
193
+
194
+ return int(length)
195
+
196
+ def forward(self, mix):
197
+ x = mix
198
+ saved = [x]
199
+ for encode in self.encoder:
200
+ x = encode(x)
201
+ saved.append(x)
202
+ if self.upsample:
203
+ x = downsample(x, self.stride)
204
+ if self.lstm:
205
+ x = self.lstm(x)
206
+ for decode in self.decoder:
207
+ if self.upsample:
208
+ x = upsample(x, stride=self.stride)
209
+ skip = center_trim(saved.pop(-1), x)
210
+ x = x + skip
211
+ x = decode(x)
212
+ if self.final:
213
+ skip = center_trim(saved.pop(-1), x)
214
+ x = th.cat([x, skip], dim=1)
215
+ x = self.final(x)
216
+
217
+ x = x.view(x.size(0), self.sources, self.audio_channels, x.size(-1))
218
+ return x
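A hedged usage sketch for the Demucs model defined above: pad the mixture to a valid length so the strided convolutions lose no time steps, then separate (the mixture tensor and padding split are illustrative):

import torch
import torch.nn.functional as F

model = Demucs(sources=4, audio_channels=2)       # defaults from this file
mix = torch.randn(1, 2, 44100)                    # hypothetical one-second stereo mixture
valid = model.valid_length(mix.shape[-1])
delta = valid - mix.shape[-1]
padded = F.pad(mix, (delta // 2, delta - delta // 2))
with torch.no_grad():
    out = model(padded)                           # [batch, sources, audio_channels, time]
# with context > 1 the output is a little shorter than the input; center-trim to compare.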
demucs/model_v2.py ADDED
@@ -0,0 +1,218 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+
9
+ import julius
10
+ from torch import nn
11
+ from .tasnet_v2 import ConvTasNet
12
+
13
+ from .utils import capture_init, center_trim
14
+
15
+
16
+ class BLSTM(nn.Module):
17
+ def __init__(self, dim, layers=1):
18
+ super().__init__()
19
+ self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim)
20
+ self.linear = nn.Linear(2 * dim, dim)
21
+
22
+ def forward(self, x):
23
+ x = x.permute(2, 0, 1)
24
+ x = self.lstm(x)[0]
25
+ x = self.linear(x)
26
+ x = x.permute(1, 2, 0)
27
+ return x
28
+
29
+
30
+ def rescale_conv(conv, reference):
31
+ std = conv.weight.std().detach()
32
+ scale = (std / reference)**0.5
33
+ conv.weight.data /= scale
34
+ if conv.bias is not None:
35
+ conv.bias.data /= scale
36
+
37
+
38
+ def rescale_module(module, reference):
39
+ for sub in module.modules():
40
+ if isinstance(sub, (nn.Conv1d, nn.ConvTranspose1d)):
41
+ rescale_conv(sub, reference)
42
+
43
+ def auto_load_demucs_model_v2(sources, demucs_model_name):
44
+
45
+ if '48' in demucs_model_name:
46
+ channels=48
47
+ elif 'unittest' in demucs_model_name:
48
+ channels=4
49
+ else:
50
+ channels=64
51
+
52
+ if 'tasnet' in demucs_model_name:
53
+ init_demucs_model = ConvTasNet(sources, X=10)
54
+ else:
55
+ init_demucs_model = Demucs(sources, channels=channels)
56
+
57
+ return init_demucs_model
58
+
59
+ class Demucs(nn.Module):
60
+ @capture_init
61
+ def __init__(self,
62
+ sources,
63
+ audio_channels=2,
64
+ channels=64,
65
+ depth=6,
66
+ rewrite=True,
67
+ glu=True,
68
+ rescale=0.1,
69
+ resample=True,
70
+ kernel_size=8,
71
+ stride=4,
72
+ growth=2.,
73
+ lstm_layers=2,
74
+ context=3,
75
+ normalize=False,
76
+ samplerate=44100,
77
+ segment_length=4 * 10 * 44100):
78
+ """
79
+ Args:
80
+ sources (list[str]): list of source names
81
+ audio_channels (int): stereo or mono
82
+ channels (int): first convolution channels
83
+ depth (int): number of encoder/decoder layers
84
+ rewrite (bool): add 1x1 convolution to each encoder layer
85
+ and a convolution to each decoder layer.
86
+ For the decoder layer, `context` gives the kernel size.
87
+ glu (bool): use glu instead of ReLU
88
+ resample (bool): upsample the input x2 and downsample the output /2.
89
+ rescale (int): rescale initial weights of convolutions
90
+ to get their standard deviation closer to `rescale`
91
+ kernel_size (int): kernel size for convolutions
92
+ stride (int): stride for convolutions
93
+ growth (float): multiply (resp divide) number of channels by that
94
+ for each layer of the encoder (resp decoder)
95
+ lstm_layers (int): number of lstm layers, 0 = no lstm
96
+ context (int): kernel size of the convolution in the
97
+ decoder before the transposed convolution. If > 1,
98
+ will provide some context from neighboring time
99
+ steps.
100
+ samplerate (int): stored as meta information for easing
101
+ future evaluations of the model.
102
+ segment_length (int): stored as meta information for easing
103
+ future evaluations of the model. Length of the segments on which
104
+ the model was trained.
105
+ """
106
+
107
+ super().__init__()
108
+ self.audio_channels = audio_channels
109
+ self.sources = sources
110
+ self.kernel_size = kernel_size
111
+ self.context = context
112
+ self.stride = stride
113
+ self.depth = depth
114
+ self.resample = resample
115
+ self.channels = channels
116
+ self.normalize = normalize
117
+ self.samplerate = samplerate
118
+ self.segment_length = segment_length
119
+
120
+ self.encoder = nn.ModuleList()
121
+ self.decoder = nn.ModuleList()
122
+
123
+ if glu:
124
+ activation = nn.GLU(dim=1)
125
+ ch_scale = 2
126
+ else:
127
+ activation = nn.ReLU()
128
+ ch_scale = 1
129
+ in_channels = audio_channels
130
+ for index in range(depth):
131
+ encode = []
132
+ encode += [nn.Conv1d(in_channels, channels, kernel_size, stride), nn.ReLU()]
133
+ if rewrite:
134
+ encode += [nn.Conv1d(channels, ch_scale * channels, 1), activation]
135
+ self.encoder.append(nn.Sequential(*encode))
136
+
137
+ decode = []
138
+ if index > 0:
139
+ out_channels = in_channels
140
+ else:
141
+ out_channels = len(self.sources) * audio_channels
142
+ if rewrite:
143
+ decode += [nn.Conv1d(channels, ch_scale * channels, context), activation]
144
+ decode += [nn.ConvTranspose1d(channels, out_channels, kernel_size, stride)]
145
+ if index > 0:
146
+ decode.append(nn.ReLU())
147
+ self.decoder.insert(0, nn.Sequential(*decode))
148
+ in_channels = channels
149
+ channels = int(growth * channels)
150
+
151
+ channels = in_channels
152
+
153
+ if lstm_layers:
154
+ self.lstm = BLSTM(channels, lstm_layers)
155
+ else:
156
+ self.lstm = None
157
+
158
+ if rescale:
159
+ rescale_module(self, reference=rescale)
160
+
161
+ def valid_length(self, length):
162
+ """
163
+ Return the nearest valid length to use with the model so that
164
+ there are no time steps left over in the convolutions, e.g. for all
165
+ layers, size of the input - kernel_size % stride = 0.
166
+
167
+ If the mixture has a valid length, the estimated sources
168
+ will have exactly the same length when context = 1. If context > 1,
169
+ the two signals can be center trimmed to match.
170
+
171
+ For training, extracts should have a valid length. For evaluation
172
+ on full tracks we recommend passing `pad = True` to :method:`forward`.
173
+ """
174
+ if self.resample:
175
+ length *= 2
176
+ for _ in range(self.depth):
177
+ length = math.ceil((length - self.kernel_size) / self.stride) + 1
178
+ length = max(1, length)
179
+ length += self.context - 1
180
+ for _ in range(self.depth):
181
+ length = (length - 1) * self.stride + self.kernel_size
182
+
183
+ if self.resample:
184
+ length = math.ceil(length / 2)
185
+ return int(length)
186
+
187
+ def forward(self, mix):
188
+ x = mix
189
+
190
+ if self.normalize:
191
+ mono = mix.mean(dim=1, keepdim=True)
192
+ mean = mono.mean(dim=-1, keepdim=True)
193
+ std = mono.std(dim=-1, keepdim=True)
194
+ else:
195
+ mean = 0
196
+ std = 1
197
+
198
+ x = (x - mean) / (1e-5 + std)
199
+
200
+ if self.resample:
201
+ x = julius.resample_frac(x, 1, 2)
202
+
203
+ saved = []
204
+ for encode in self.encoder:
205
+ x = encode(x)
206
+ saved.append(x)
207
+ if self.lstm:
208
+ x = self.lstm(x)
209
+ for decode in self.decoder:
210
+ skip = center_trim(saved.pop(-1), x)
211
+ x = x + skip
212
+ x = decode(x)
213
+
214
+ if self.resample:
215
+ x = julius.resample_frac(x, 2, 1)
216
+ x = x * std + mean
217
+ x = x.view(x.size(0), len(self.sources), self.audio_channels, x.size(-1))
218
+ return x
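A small sketch of the resampling trick used in the forward pass above: the input is upsampled x2 with julius before the encoder and downsampled back afterwards (the tensor is illustrative):

import torch
import julius

x = torch.randn(1, 2, 44100)            # hypothetical stereo mixture
up = julius.resample_frac(x, 1, 2)      # roughly twice as many samples
down = julius.resample_frac(up, 2, 1)   # back to the original rate and length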
demucs/pretrained.py ADDED
@@ -0,0 +1,180 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """Loading pretrained models.
7
+ """
8
+
9
+ import logging
10
+ from pathlib import Path
11
+ import typing as tp
12
+
13
+ #from dora.log import fatal
14
+
15
+ import logging
16
+
17
+ from diffq import DiffQuantizer
18
+ import torch.hub
19
+
20
+ from .model import Demucs
21
+ from .tasnet_v2 import ConvTasNet
22
+ from .utils import set_state
23
+
24
+ from .hdemucs import HDemucs
25
+ from .repo import RemoteRepo, LocalRepo, ModelOnlyRepo, BagOnlyRepo, AnyModelRepo, ModelLoadingError # noqa
26
+
27
+ logger = logging.getLogger(__name__)
28
+ ROOT_URL = "https://dl.fbaipublicfiles.com/demucs/mdx_final/"
29
+ REMOTE_ROOT = Path(__file__).parent / 'remote'
30
+
31
+ SOURCES = ["drums", "bass", "other", "vocals"]
32
+
33
+
34
+ def demucs_unittest():
35
+ model = HDemucs(channels=4, sources=SOURCES)
36
+ return model
37
+
38
+
39
+ def add_model_flags(parser):
40
+ group = parser.add_mutually_exclusive_group(required=False)
41
+ group.add_argument("-s", "--sig", help="Locally trained XP signature.")
42
+ group.add_argument("-n", "--name", default="mdx_extra_q",
43
+ help="Pretrained model name or signature. Default is mdx_extra_q.")
44
+ parser.add_argument("--repo", type=Path,
45
+ help="Folder containing all pre-trained models for use with -n.")
46
+
47
+
48
+ def _parse_remote_files(remote_file_list) -> tp.Dict[str, str]:
49
+ root: str = ''
50
+ models: tp.Dict[str, str] = {}
51
+ for line in remote_file_list.read_text().split('\n'):
52
+ line = line.strip()
53
+ if line.startswith('#'):
54
+ continue
55
+ elif line.startswith('root:'):
56
+ root = line.split(':', 1)[1].strip()
57
+ else:
58
+ sig = line.split('-', 1)[0]
59
+ assert sig not in models
60
+ models[sig] = ROOT_URL + root + line
61
+ return models
62
+
63
+ def get_model(name: str,
64
+ repo: tp.Optional[Path] = None):
65
+ """`name` must be a bag of models name or a pretrained signature
66
+ from the remote AWS model repo or the specified local repo if `repo` is not None.
67
+ """
68
+ if name == 'demucs_unittest':
69
+ return demucs_unittest()
70
+ model_repo: ModelOnlyRepo
71
+ if repo is None:
72
+ models = _parse_remote_files(REMOTE_ROOT / 'files.txt')
73
+ model_repo = RemoteRepo(models)
74
+ bag_repo = BagOnlyRepo(REMOTE_ROOT, model_repo)
75
+ else:
76
+ if not repo.is_dir():
77
+ raise ModelLoadingError(f"{repo} must exist and be a directory.")
78
+ model_repo = LocalRepo(repo)
79
+ bag_repo = BagOnlyRepo(repo, model_repo)
80
+ any_repo = AnyModelRepo(model_repo, bag_repo)
81
+ model = any_repo.get_model(name)
82
+ model.eval()
83
+ return model
84
+
85
+ def get_model_from_args(args):
86
+ """
87
+ Load local model package or pre-trained model.
88
+ """
89
+ return get_model(name=args.name, repo=args.repo)
90
+
91
+ logger = logging.getLogger(__name__)
92
+ ROOT = "https://dl.fbaipublicfiles.com/demucs/v3.0/"
93
+
94
+ PRETRAINED_MODELS = {
95
+ 'demucs': 'e07c671f',
96
+ 'demucs48_hq': '28a1282c',
97
+ 'demucs_extra': '3646af93',
98
+ 'demucs_quantized': '07afea75',
99
+ 'tasnet': 'beb46fac',
100
+ 'tasnet_extra': 'df3777b2',
101
+ 'demucs_unittest': '09ebc15f',
102
+ }
103
+
104
+ SOURCES = ["drums", "bass", "other", "vocals"]
105
+
106
+
107
+ def get_url(name):
108
+ sig = PRETRAINED_MODELS[name]
109
+ return ROOT + name + "-" + sig[:8] + ".th"
110
+
111
+ def is_pretrained(name):
112
+ return name in PRETRAINED_MODELS
113
+
114
+
115
+ def load_pretrained(name):
116
+ if name == "demucs":
117
+ return demucs(pretrained=True)
118
+ elif name == "demucs48_hq":
119
+ return demucs(pretrained=True, hq=True, channels=48)
120
+ elif name == "demucs_extra":
121
+ return demucs(pretrained=True, extra=True)
122
+ elif name == "demucs_quantized":
123
+ return demucs(pretrained=True, quantized=True)
124
+ elif name == "demucs_unittest":
125
+ return demucs_unittest(pretrained=True)
126
+ elif name == "tasnet":
127
+ return tasnet(pretrained=True)
128
+ elif name == "tasnet_extra":
129
+ return tasnet(pretrained=True, extra=True)
130
+ else:
131
+ raise ValueError(f"Invalid pretrained name {name}")
132
+
133
+
134
+ def _load_state(name, model, quantizer=None):
135
+ url = get_url(name)
136
+ state = torch.hub.load_state_dict_from_url(url, map_location='cpu', check_hash=True)
137
+ set_state(model, quantizer, state)
138
+ if quantizer:
139
+ quantizer.detach()
140
+
141
+
142
+ def demucs_unittest(pretrained=True):
143
+ model = Demucs(channels=4, sources=SOURCES)
144
+ if pretrained:
145
+ _load_state('demucs_unittest', model)
146
+ return model
147
+
148
+
149
+ def demucs(pretrained=True, extra=False, quantized=False, hq=False, channels=64):
150
+ if not pretrained and (extra or quantized or hq):
151
+ raise ValueError("if extra or quantized is True, pretrained must be True.")
152
+ model = Demucs(sources=SOURCES, channels=channels)
153
+ if pretrained:
154
+ name = 'demucs'
155
+ if channels != 64:
156
+ name += str(channels)
157
+ quantizer = None
158
+ if sum([extra, quantized, hq]) > 1:
159
+ raise ValueError("Only one of extra, quantized, hq, can be True.")
160
+ if quantized:
161
+ quantizer = DiffQuantizer(model, group_size=8, min_size=1)
162
+ name += '_quantized'
163
+ if extra:
164
+ name += '_extra'
165
+ if hq:
166
+ name += '_hq'
167
+ _load_state(name, model, quantizer)
168
+ return model
169
+
170
+
171
+ def tasnet(pretrained=True, extra=False):
172
+ if not pretrained and extra:
173
+ raise ValueError("if extra is True, pretrained must be True.")
174
+ model = ConvTasNet(X=10, sources=SOURCES)
175
+ if pretrained:
176
+ name = 'tasnet'
177
+ if extra:
178
+ name = 'tasnet_extra'
179
+ _load_state(name, model)
180
+ return model
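A hedged usage sketch for the loaders above (weights are downloaded on first use; `get_model` assumes the `remote/files.txt` manifest is present next to this module):

model = get_model('mdx_extra_q')          # default bag-of-models name from add_model_flags
legacy = load_pretrained('demucs48_hq')   # any key of PRETRAINED_MODELS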
demucs/repo.py ADDED
@@ -0,0 +1,148 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """Represents a model repository, including pre-trained models and bags of models.
7
+ A repo can either be the main remote repository stored in AWS, or a local repository
8
+ with your own models.
9
+ """
10
+
11
+ from hashlib import sha256
12
+ from pathlib import Path
13
+ import typing as tp
14
+
15
+ import torch
16
+ import yaml
17
+
18
+ from .apply import BagOfModels, Model
19
+ from .states import load_model
20
+
21
+
22
+ AnyModel = tp.Union[Model, BagOfModels]
23
+
24
+
25
+ class ModelLoadingError(RuntimeError):
26
+ pass
27
+
28
+
29
+ def check_checksum(path: Path, checksum: str):
30
+ sha = sha256()
31
+ with open(path, 'rb') as file:
32
+ while True:
33
+ buf = file.read(2**20)
34
+ if not buf:
35
+ break
36
+ sha.update(buf)
37
+ actual_checksum = sha.hexdigest()[:len(checksum)]
38
+ if actual_checksum != checksum:
39
+ raise ModelLoadingError(f'Invalid checksum for file {path}, '
40
+ f'expected {checksum} but got {actual_checksum}')
41
+
42
+ class ModelOnlyRepo:
43
+ """Base class for all model only repos.
44
+ """
45
+ def has_model(self, sig: str) -> bool:
46
+ raise NotImplementedError()
47
+
48
+ def get_model(self, sig: str) -> Model:
49
+ raise NotImplementedError()
50
+
51
+
52
+ class RemoteRepo(ModelOnlyRepo):
53
+ def __init__(self, models: tp.Dict[str, str]):
54
+ self._models = models
55
+
56
+ def has_model(self, sig: str) -> bool:
57
+ return sig in self._models
58
+
59
+ def get_model(self, sig: str) -> Model:
60
+ try:
61
+ url = self._models[sig]
62
+ except KeyError:
63
+ raise ModelLoadingError(f'Could not find a pre-trained model with signature {sig}.')
64
+ pkg = torch.hub.load_state_dict_from_url(url, map_location='cpu', check_hash=True)
65
+ return load_model(pkg)
66
+
67
+
68
+ class LocalRepo(ModelOnlyRepo):
69
+ def __init__(self, root: Path):
70
+ self.root = root
71
+ self.scan()
72
+
73
+ def scan(self):
74
+ self._models = {}
75
+ self._checksums = {}
76
+ for file in self.root.iterdir():
77
+ if file.suffix == '.th':
78
+ if '-' in file.stem:
79
+ xp_sig, checksum = file.stem.split('-')
80
+ self._checksums[xp_sig] = checksum
81
+ else:
82
+ xp_sig = file.stem
83
+ if xp_sig in self._models:
84
+ print('Whats xp? ', xp_sig)
85
+ raise ModelLoadingError(
86
+ f'Duplicate pre-trained model exist for signature {xp_sig}. '
87
+ 'Please delete all but one.')
88
+ self._models[xp_sig] = file
89
+
90
+ def has_model(self, sig: str) -> bool:
91
+ return sig in self._models
92
+
93
+ def get_model(self, sig: str) -> Model:
94
+ try:
95
+ file = self._models[sig]
96
+ except KeyError:
97
+ raise ModelLoadingError(f'Could not find pre-trained model with signature {sig}.')
98
+ if sig in self._checksums:
99
+ check_checksum(file, self._checksums[sig])
100
+ return load_model(file)
101
+
102
+
103
+ class BagOnlyRepo:
104
+ """Handles only YAML files containing bag of models, leaving the actual
105
+ model loading to some Repo.
106
+ """
107
+ def __init__(self, root: Path, model_repo: ModelOnlyRepo):
108
+ self.root = root
109
+ self.model_repo = model_repo
110
+ self.scan()
111
+
112
+ def scan(self):
113
+ self._bags = {}
114
+ for file in self.root.iterdir():
115
+ if file.suffix == '.yaml':
116
+ self._bags[file.stem] = file
117
+
118
+ def has_model(self, name: str) -> bool:
119
+ return name in self._bags
120
+
121
+ def get_model(self, name: str) -> BagOfModels:
122
+ try:
123
+ yaml_file = self._bags[name]
124
+ except KeyError:
125
+ raise ModelLoadingError(f'{name} is neither a single pre-trained model nor '
126
+ 'a bag of models.')
127
+ bag = yaml.safe_load(open(yaml_file))
128
+ signatures = bag['models']
129
+ models = [self.model_repo.get_model(sig) for sig in signatures]
130
+ weights = bag.get('weights')
131
+ segment = bag.get('segment')
132
+ return BagOfModels(models, weights, segment)
133
+
134
+
135
+ class AnyModelRepo:
136
+ def __init__(self, model_repo: ModelOnlyRepo, bag_repo: BagOnlyRepo):
137
+ self.model_repo = model_repo
138
+ self.bag_repo = bag_repo
139
+
140
+ def has_model(self, name_or_sig: str) -> bool:
141
+ return self.model_repo.has_model(name_or_sig) or self.bag_repo.has_model(name_or_sig)
142
+
143
+ def get_model(self, name_or_sig: str) -> AnyModel:
144
+ print('name_or_sig: ', name_or_sig)
145
+ if self.model_repo.has_model(name_or_sig):
146
+ return self.model_repo.get_model(name_or_sig)
147
+ else:
148
+ return self.bag_repo.get_model(name_or_sig)
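A short sketch of how these repos compose, mirroring get_model() in pretrained.py; the folder and model names are hypothetical:

from pathlib import Path

repo_dir = Path('my_models')                 # folder containing *.th weights and *.yaml bags
model_repo = LocalRepo(repo_dir)
bag_repo = BagOnlyRepo(repo_dir, model_repo)
any_repo = AnyModelRepo(model_repo, bag_repo)
model = any_repo.get_model('my_bag')         # either a .yaml bag name or a .th signature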
demucs/spec.py ADDED
@@ -0,0 +1,53 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """Conveniance wrapper to perform STFT and iSTFT"""
7
+
8
+ import torch as th
9
+
10
+
11
+ def spectro(x, n_fft=512, hop_length=None, pad=0):
12
+ *other, length = x.shape
13
+ x = x.reshape(-1, length)
14
+
15
+ device_type = x.device.type
16
+ is_other_gpu = device_type not in ["cuda", "cpu"]
17
+
18
+ if is_other_gpu:
19
+ x = x.cpu()
20
+ z = th.stft(x,
21
+ n_fft * (1 + pad),
22
+ hop_length or n_fft // 4,
23
+ window=th.hann_window(n_fft).to(x),
24
+ win_length=n_fft,
25
+ normalized=True,
26
+ center=True,
27
+ return_complex=True,
28
+ pad_mode='reflect')
29
+ _, freqs, frame = z.shape
30
+ return z.view(*other, freqs, frame)
31
+
32
+
33
+ def ispectro(z, hop_length=None, length=None, pad=0):
34
+ *other, freqs, frames = z.shape
35
+ n_fft = 2 * freqs - 2
36
+ z = z.view(-1, freqs, frames)
37
+ win_length = n_fft // (1 + pad)
38
+
39
+ device_type = z.device.type
40
+ is_other_gpu = device_type not in ["cuda", "cpu"]
41
+
42
+ if is_other_gpu:
43
+ z = z.cpu()
44
+ x = th.istft(z,
45
+ n_fft,
46
+ hop_length,
47
+ window=th.hann_window(win_length).to(z.real),
48
+ win_length=win_length,
49
+ normalized=True,
50
+ length=length,
51
+ center=True)
52
+ _, length = x.shape
53
+ return x.view(*other, length)
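A round-trip sketch for the two helpers above (tensor sizes are illustrative): an STFT followed by an iSTFT with the same hop length recovers the waveform up to numerical error.

import torch as th

x = th.randn(2, 44100)                        # hypothetical batch of mono signals
z = spectro(x, n_fft=4096, hop_length=1024)   # complex spectrogram, [2, 2049, frames]
y = ispectro(z, hop_length=1024, length=x.shape[-1])
print((x - y).abs().max())                    # near-zero reconstruction error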
demucs/states.py ADDED
@@ -0,0 +1,148 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """
7
+ Utilities to save and load models.
8
+ """
9
+ from contextlib import contextmanager
10
+
11
+ import functools
12
+ import hashlib
13
+ import inspect
14
+ import io
15
+ from pathlib import Path
16
+ import warnings
17
+
18
+ from omegaconf import OmegaConf
19
+ from diffq import DiffQuantizer, UniformQuantizer, restore_quantized_state
20
+ import torch
21
+
22
+
23
+ def get_quantizer(model, args, optimizer=None):
24
+ """Return the quantizer given the XP quantization args."""
25
+ quantizer = None
26
+ if args.diffq:
27
+ quantizer = DiffQuantizer(
28
+ model, min_size=args.min_size, group_size=args.group_size)
29
+ if optimizer is not None:
30
+ quantizer.setup_optimizer(optimizer)
31
+ elif args.qat:
32
+ quantizer = UniformQuantizer(
33
+ model, bits=args.qat, min_size=args.min_size)
34
+ return quantizer
35
+
36
+
37
+ def load_model(path_or_package, strict=False):
38
+ """Load a model from the given serialized model, either given as a dict (already loaded)
39
+ or a path to a file on disk."""
40
+ if isinstance(path_or_package, dict):
41
+ package = path_or_package
42
+ elif isinstance(path_or_package, (str, Path)):
43
+ with warnings.catch_warnings():
44
+ warnings.simplefilter("ignore")
45
+ path = path_or_package
46
+ package = torch.load(path, 'cpu')
47
+ else:
48
+ raise ValueError(f"Invalid type for {path_or_package}.")
49
+
50
+ klass = package["klass"]
51
+ args = package["args"]
52
+ kwargs = package["kwargs"]
53
+
54
+ if strict:
55
+ model = klass(*args, **kwargs)
56
+ else:
57
+ sig = inspect.signature(klass)
58
+ for key in list(kwargs):
59
+ if key not in sig.parameters:
60
+ warnings.warn("Dropping inexistant parameter " + key)
61
+ del kwargs[key]
62
+ model = klass(*args, **kwargs)
63
+
64
+ state = package["state"]
65
+
66
+ set_state(model, state)
67
+ return model
68
+
69
+
70
+ def get_state(model, quantizer, half=False):
71
+ """Get the state from a model, potentially with quantization applied.
72
+ If `half` is True, the model is stored in half precision, which shouldn't impact performance
73
+ but halves the state size."""
74
+ if quantizer is None:
75
+ dtype = torch.half if half else None
76
+ state = {k: p.data.to(device='cpu', dtype=dtype) for k, p in model.state_dict().items()}
77
+ else:
78
+ state = quantizer.get_quantized_state()
79
+ state['__quantized'] = True
80
+ return state
81
+
82
+
83
+ def set_state(model, state, quantizer=None):
84
+ """Set the state on a given model."""
85
+ if state.get('__quantized'):
86
+ if quantizer is not None:
87
+ quantizer.restore_quantized_state(model, state['quantized'])
88
+ else:
89
+ restore_quantized_state(model, state)
90
+ else:
91
+ model.load_state_dict(state)
92
+ return state
93
+
94
+
95
+ def save_with_checksum(content, path):
96
+ """Save the given value on disk, along with a sha256 hash.
97
+ Should be used with the output of either `serialize_model` or `get_state`."""
98
+ buf = io.BytesIO()
99
+ torch.save(content, buf)
100
+ sig = hashlib.sha256(buf.getvalue()).hexdigest()[:8]
101
+
102
+ path = path.parent / (path.stem + "-" + sig + path.suffix)
103
+ path.write_bytes(buf.getvalue())
104
+
105
+
106
+ def serialize_model(model, training_args, quantizer=None, half=True):
107
+ args, kwargs = model._init_args_kwargs
108
+ klass = model.__class__
109
+
110
+ state = get_state(model, quantizer, half)
111
+ return {
112
+ 'klass': klass,
113
+ 'args': args,
114
+ 'kwargs': kwargs,
115
+ 'state': state,
116
+ 'training_args': OmegaConf.to_container(training_args, resolve=True),
117
+ }
118
+
119
+
120
+ def copy_state(state):
121
+ return {k: v.cpu().clone() for k, v in state.items()}
122
+
123
+
124
+ @contextmanager
125
+ def swap_state(model, state):
126
+ """
127
+ Context manager that swaps the state of a model, e.g:
128
+
129
+ # model is in old state
130
+ with swap_state(model, new_state):
131
+ # model in new state
132
+ # model back to old state
133
+ """
134
+ old_state = copy_state(model.state_dict())
135
+ model.load_state_dict(state, strict=False)
136
+ try:
137
+ yield
138
+ finally:
139
+ model.load_state_dict(old_state)
140
+
141
+
142
+ def capture_init(init):
143
+ @functools.wraps(init)
144
+ def __init__(self, *args, **kwargs):
145
+ self._init_args_kwargs = (args, kwargs)
146
+ init(self, *args, **kwargs)
147
+
148
+ return __init__
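A minimal sketch of copy_state() and swap_state() above, using a stand-in module; the Linear layer and the alternative weights (e.g. an averaged copy) are illustrative:

import torch
from torch import nn

model = nn.Linear(4, 4)                      # stand-in for a Demucs model
alt_state = copy_state(model.state_dict())   # e.g. an EMA / averaged set of weights
with swap_state(model, alt_state):
    out = model(torch.randn(1, 4))           # evaluated with alt_state loaded
# the original weights are restored on exit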
demucs/tasnet.py ADDED
@@ -0,0 +1,447 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ #
7
+ # Created on 2018/12
8
+ # Author: Kaituo XU
9
+ # Modified on 2019/11 by Alexandre Defossez, added support for multiple output channels
10
+ # Here is the original license:
11
+ # The MIT License (MIT)
12
+ #
13
+ # Copyright (c) 2018 Kaituo XU
14
+ #
15
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
16
+ # of this software and associated documentation files (the "Software"), to deal
17
+ # in the Software without restriction, including without limitation the rights
18
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
19
+ # copies of the Software, and to permit persons to whom the Software is
20
+ # furnished to do so, subject to the following conditions:
21
+ #
22
+ # The above copyright notice and this permission notice shall be included in all
23
+ # copies or substantial portions of the Software.
24
+ #
25
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
26
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
27
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
28
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
29
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
30
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
31
+ # SOFTWARE.
32
+
33
+ import math
34
+
35
+ import torch
36
+ import torch.nn as nn
37
+ import torch.nn.functional as F
38
+
39
+ from .utils import capture_init
40
+
41
+ EPS = 1e-8
42
+
43
+
44
+ def overlap_and_add(signal, frame_step):
45
+ outer_dimensions = signal.size()[:-2]
46
+ frames, frame_length = signal.size()[-2:]
47
+
48
+ subframe_length = math.gcd(frame_length, frame_step) # gcd=Greatest Common Divisor
49
+ subframe_step = frame_step // subframe_length
50
+ subframes_per_frame = frame_length // subframe_length
51
+ output_size = frame_step * (frames - 1) + frame_length
52
+ output_subframes = output_size // subframe_length
53
+
54
+ subframe_signal = signal.view(*outer_dimensions, -1, subframe_length)
55
+
56
+ frame = torch.arange(0, output_subframes,
57
+ device=signal.device).unfold(0, subframes_per_frame, subframe_step)
58
+ frame = frame.long() # signal may be on GPU or CPU
59
+ frame = frame.contiguous().view(-1)
60
+
61
+ result = signal.new_zeros(*outer_dimensions, output_subframes, subframe_length)
62
+ result.index_add_(-2, frame, subframe_signal)
63
+ result = result.view(*outer_dimensions, -1)
64
+ return result
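# Illustrative sketch (frame sizes hypothetical): overlap_and_add sums
# 50%-overlapping frames back into a waveform, inverting the encoder framing.
# frames: [batch, n_frames, frame_length]; output length = step * (n_frames - 1) + frame_length.
# example_frames = torch.randn(3, 5, 20)
# example_signal = overlap_and_add(example_frames, 10)   # -> shape [3, 60]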
65
+
66
+
67
+ class ConvTasNet(nn.Module):
68
+ @capture_init
69
+ def __init__(self,
70
+ N=256,
71
+ L=20,
72
+ B=256,
73
+ H=512,
74
+ P=3,
75
+ X=8,
76
+ R=4,
77
+ C=4,
78
+ audio_channels=1,
79
+ samplerate=44100,
80
+ norm_type="gLN",
81
+ causal=False,
82
+ mask_nonlinear='relu'):
83
+ """
84
+ Args:
85
+ N: Number of filters in autoencoder
86
+ L: Length of the filters (in samples)
87
+ B: Number of channels in bottleneck 1 × 1-conv block
88
+ H: Number of channels in convolutional blocks
89
+ P: Kernel size in convolutional blocks
90
+ X: Number of convolutional blocks in each repeat
91
+ R: Number of repeats
92
+ C: Number of speakers
93
+ norm_type: BN, gLN, cLN
94
+ causal: causal or non-causal
95
+ mask_nonlinear: use which non-linear function to generate mask
96
+ """
97
+ super(ConvTasNet, self).__init__()
98
+ # Hyper-parameter
99
+ self.N, self.L, self.B, self.H, self.P, self.X, self.R, self.C = N, L, B, H, P, X, R, C
100
+ self.norm_type = norm_type
101
+ self.causal = causal
102
+ self.mask_nonlinear = mask_nonlinear
103
+ self.audio_channels = audio_channels
104
+ self.samplerate = samplerate
105
+ # Components
106
+ self.encoder = Encoder(L, N, audio_channels)
107
+ self.separator = TemporalConvNet(N, B, H, P, X, R, C, norm_type, causal, mask_nonlinear)
108
+ self.decoder = Decoder(N, L, audio_channels)
109
+ # init
110
+ for p in self.parameters():
111
+ if p.dim() > 1:
112
+ nn.init.xavier_normal_(p)
113
+
114
+ def valid_length(self, length):
115
+ return length
116
+
117
+ def forward(self, mixture):
118
+ """
119
+ Args:
120
+ mixture: [M, T], M is batch size, T is #samples
121
+ Returns:
122
+ est_source: [M, C, T]
123
+ """
124
+ mixture_w = self.encoder(mixture)
125
+ est_mask = self.separator(mixture_w)
126
+ est_source = self.decoder(mixture_w, est_mask)
127
+
128
+ # T changed after conv1d in encoder, fix it here
129
+ T_origin = mixture.size(-1)
130
+ T_conv = est_source.size(-1)
131
+ est_source = F.pad(est_source, (0, T_origin - T_conv))
132
+ return est_source
133
+
134
+
135
+ class Encoder(nn.Module):
136
+ """Estimation of the nonnegative mixture weight by a 1-D conv layer.
137
+ """
138
+ def __init__(self, L, N, audio_channels):
139
+ super(Encoder, self).__init__()
140
+ # Hyper-parameter
141
+ self.L, self.N = L, N
142
+ # Components
143
+ # 50% overlap
144
+ self.conv1d_U = nn.Conv1d(audio_channels, N, kernel_size=L, stride=L // 2, bias=False)
145
+
146
+ def forward(self, mixture):
147
+ """
148
+ Args:
149
+ mixture: [M, T], M is batch size, T is #samples
150
+ Returns:
151
+ mixture_w: [M, N, K], where K = (T-L)/(L/2)+1 = 2T/L-1
152
+ """
153
+ mixture_w = F.relu(self.conv1d_U(mixture)) # [M, N, K]
154
+ return mixture_w
155
+
156
+
157
+ class Decoder(nn.Module):
158
+ def __init__(self, N, L, audio_channels):
159
+ super(Decoder, self).__init__()
160
+ # Hyper-parameter
161
+ self.N, self.L = N, L
162
+ self.audio_channels = audio_channels
163
+ # Components
164
+ self.basis_signals = nn.Linear(N, audio_channels * L, bias=False)
165
+
166
+ def forward(self, mixture_w, est_mask):
167
+ """
168
+ Args:
169
+ mixture_w: [M, N, K]
170
+ est_mask: [M, C, N, K]
171
+ Returns:
172
+ est_source: [M, C, T]
173
+ """
174
+ # D = W * M
175
+ source_w = torch.unsqueeze(mixture_w, 1) * est_mask # [M, C, N, K]
176
+ source_w = torch.transpose(source_w, 2, 3) # [M, C, K, N]
177
+ # S = DV
178
+ est_source = self.basis_signals(source_w) # [M, C, K, ac * L]
179
+ m, c, k, _ = est_source.size()
180
+ est_source = est_source.view(m, c, k, self.audio_channels, -1).transpose(2, 3).contiguous()
181
+ est_source = overlap_and_add(est_source, self.L // 2) # M x C x ac x T
182
+ return est_source
183
+
184
+
185
+ class TemporalConvNet(nn.Module):
186
+ def __init__(self, N, B, H, P, X, R, C, norm_type="gLN", causal=False, mask_nonlinear='relu'):
187
+ """
188
+ Args:
189
+ N: Number of filters in autoencoder
190
+ B: Number of channels in bottleneck 1 × 1-conv block
191
+ H: Number of channels in convolutional blocks
192
+ P: Kernel size in convolutional blocks
193
+ X: Number of convolutional blocks in each repeat
194
+ R: Number of repeats
195
+ C: Number of speakers
196
+ norm_type: BN, gLN, cLN
197
+ causal: causal or non-causal
198
+ mask_nonlinear: use which non-linear function to generate mask
199
+ """
200
+ super(TemporalConvNet, self).__init__()
201
+ # Hyper-parameter
202
+ self.C = C
203
+ self.mask_nonlinear = mask_nonlinear
204
+ # Components
205
+ # [M, N, K] -> [M, N, K]
206
+ layer_norm = ChannelwiseLayerNorm(N)
207
+ # [M, N, K] -> [M, B, K]
208
+ bottleneck_conv1x1 = nn.Conv1d(N, B, 1, bias=False)
209
+ # [M, B, K] -> [M, B, K]
210
+ repeats = []
211
+ for r in range(R):
212
+ blocks = []
213
+ for x in range(X):
214
+ dilation = 2**x
215
+ padding = (P - 1) * dilation if causal else (P - 1) * dilation // 2
216
+ blocks += [
217
+ TemporalBlock(B,
218
+ H,
219
+ P,
220
+ stride=1,
221
+ padding=padding,
222
+ dilation=dilation,
223
+ norm_type=norm_type,
224
+ causal=causal)
225
+ ]
226
+ repeats += [nn.Sequential(*blocks)]
227
+ temporal_conv_net = nn.Sequential(*repeats)
228
+ # [M, B, K] -> [M, C*N, K]
229
+ mask_conv1x1 = nn.Conv1d(B, C * N, 1, bias=False)
230
+ # Put together
231
+ self.network = nn.Sequential(layer_norm, bottleneck_conv1x1, temporal_conv_net,
232
+ mask_conv1x1)
233
+
234
+ def forward(self, mixture_w):
235
+ """
236
+ Keep this API same with TasNet
237
+ Args:
238
+ mixture_w: [M, N, K], M is batch size
239
+ returns:
240
+ est_mask: [M, C, N, K]
241
+ """
242
+ M, N, K = mixture_w.size()
243
+ score = self.network(mixture_w) # [M, N, K] -> [M, C*N, K]
244
+ score = score.view(M, self.C, N, K) # [M, C*N, K] -> [M, C, N, K]
245
+ if self.mask_nonlinear == 'softmax':
246
+ est_mask = F.softmax(score, dim=1)
247
+ elif self.mask_nonlinear == 'relu':
248
+ est_mask = F.relu(score)
249
+ else:
250
+ raise ValueError("Unsupported mask non-linear function")
251
+ return est_mask
252
+
253
+
254
+ class TemporalBlock(nn.Module):
255
+ def __init__(self,
256
+ in_channels,
257
+ out_channels,
258
+ kernel_size,
259
+ stride,
260
+ padding,
261
+ dilation,
262
+ norm_type="gLN",
263
+ causal=False):
264
+ super(TemporalBlock, self).__init__()
265
+ # [M, B, K] -> [M, H, K]
266
+ conv1x1 = nn.Conv1d(in_channels, out_channels, 1, bias=False)
267
+ prelu = nn.PReLU()
268
+ norm = chose_norm(norm_type, out_channels)
269
+ # [M, H, K] -> [M, B, K]
270
+ dsconv = DepthwiseSeparableConv(out_channels, in_channels, kernel_size, stride, padding,
271
+ dilation, norm_type, causal)
272
+ # Put together
273
+ self.net = nn.Sequential(conv1x1, prelu, norm, dsconv)
274
+
275
+ def forward(self, x):
276
+ """
277
+ Args:
278
+ x: [M, B, K]
279
+ Returns:
280
+ [M, B, K]
281
+ """
282
+ residual = x
283
+ out = self.net(x)
284
+ # TODO: when P = 3 here works fine, but when P = 2 maybe need to pad?
285
+ return out + residual # look like w/o F.relu is better than w/ F.relu
286
+ # return F.relu(out + residual)
287
+
288
+
289
+ class DepthwiseSeparableConv(nn.Module):
290
+ def __init__(self,
291
+ in_channels,
292
+ out_channels,
293
+ kernel_size,
294
+ stride,
295
+ padding,
296
+ dilation,
297
+ norm_type="gLN",
298
+ causal=False):
299
+ super(DepthwiseSeparableConv, self).__init__()
300
+ # Use `groups` option to implement depthwise convolution
301
+ # [M, H, K] -> [M, H, K]
302
+ depthwise_conv = nn.Conv1d(in_channels,
303
+ in_channels,
304
+ kernel_size,
305
+ stride=stride,
306
+ padding=padding,
307
+ dilation=dilation,
308
+ groups=in_channels,
309
+ bias=False)
310
+ if causal:
311
+ chomp = Chomp1d(padding)
312
+ prelu = nn.PReLU()
313
+ norm = chose_norm(norm_type, in_channels)
314
+ # [M, H, K] -> [M, B, K]
315
+ pointwise_conv = nn.Conv1d(in_channels, out_channels, 1, bias=False)
316
+ # Put together
317
+ if causal:
318
+ self.net = nn.Sequential(depthwise_conv, chomp, prelu, norm, pointwise_conv)
319
+ else:
320
+ self.net = nn.Sequential(depthwise_conv, prelu, norm, pointwise_conv)
321
+
322
+ def forward(self, x):
323
+ """
324
+ Args:
325
+ x: [M, H, K]
326
+ Returns:
327
+ result: [M, B, K]
328
+ """
329
+ return self.net(x)
330
+
331
+
332
+ class Chomp1d(nn.Module):
333
+ """To ensure the output length is the same as the input.
334
+ """
335
+ def __init__(self, chomp_size):
336
+ super(Chomp1d, self).__init__()
337
+ self.chomp_size = chomp_size
338
+
339
+ def forward(self, x):
340
+ """
341
+ Args:
342
+ x: [M, H, Kpad]
343
+ Returns:
344
+ [M, H, K]
345
+ """
346
+ return x[:, :, :-self.chomp_size].contiguous()
347
+
348
+
349
+ def chose_norm(norm_type, channel_size):
350
+    """The input of normalization will be (M, C, K), where M is batch size,
351
+ C is channel size and K is sequence length.
352
+ """
353
+ if norm_type == "gLN":
354
+ return GlobalLayerNorm(channel_size)
355
+ elif norm_type == "cLN":
356
+ return ChannelwiseLayerNorm(channel_size)
357
+ elif norm_type == "id":
358
+ return nn.Identity()
359
+ else: # norm_type == "BN":
360
+        # Given input (M, C, K), nn.BatchNorm1d(C) will accumulate statistics
361
+ # along M and K, so this BN usage is right.
362
+ return nn.BatchNorm1d(channel_size)
363
+
364
+
365
+ # TODO: Use nn.LayerNorm to impl cLN to speed up
366
+ class ChannelwiseLayerNorm(nn.Module):
367
+ """Channel-wise Layer Normalization (cLN)"""
368
+ def __init__(self, channel_size):
369
+ super(ChannelwiseLayerNorm, self).__init__()
370
+ self.gamma = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1]
371
+ self.beta = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1]
372
+ self.reset_parameters()
373
+
374
+ def reset_parameters(self):
375
+ self.gamma.data.fill_(1)
376
+ self.beta.data.zero_()
377
+
378
+ def forward(self, y):
379
+ """
380
+ Args:
381
+ y: [M, N, K], M is batch size, N is channel size, K is length
382
+ Returns:
383
+ cLN_y: [M, N, K]
384
+ """
385
+ mean = torch.mean(y, dim=1, keepdim=True) # [M, 1, K]
386
+ var = torch.var(y, dim=1, keepdim=True, unbiased=False) # [M, 1, K]
387
+ cLN_y = self.gamma * (y - mean) / torch.pow(var + EPS, 0.5) + self.beta
388
+ return cLN_y
389
+
390
+
391
+ class GlobalLayerNorm(nn.Module):
392
+ """Global Layer Normalization (gLN)"""
393
+ def __init__(self, channel_size):
394
+ super(GlobalLayerNorm, self).__init__()
395
+ self.gamma = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1]
396
+ self.beta = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1]
397
+ self.reset_parameters()
398
+
399
+ def reset_parameters(self):
400
+ self.gamma.data.fill_(1)
401
+ self.beta.data.zero_()
402
+
403
+ def forward(self, y):
404
+ """
405
+ Args:
406
+ y: [M, N, K], M is batch size, N is channel size, K is length
407
+ Returns:
408
+ gLN_y: [M, N, K]
409
+ """
410
+ # TODO: in torch 1.0, torch.mean() support dim list
411
+ mean = y.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True) # [M, 1, 1]
412
+ var = (torch.pow(y - mean, 2)).mean(dim=1, keepdim=True).mean(dim=2, keepdim=True)
413
+ gLN_y = self.gamma * (y - mean) / torch.pow(var + EPS, 0.5) + self.beta
414
+ return gLN_y
415
+
416
+
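# A small numerical sketch of the difference between the two norms above,
# assuming the ChannelwiseLayerNorm and GlobalLayerNorm classes from this file:
# cLN normalizes over channels independently at each time step, while gLN pools
# statistics over channels and time. Shapes are illustrative.
import torch

M, N, K = 2, 4, 5
y = torch.randn(M, N, K)
cln = ChannelwiseLayerNorm(N)       # gamma = 1, beta = 0 at initialization
gln = GlobalLayerNorm(N)
print(cln(y).mean(dim=1))           # ~0 at every (batch, time) position
print(gln(y).mean(dim=(1, 2)))      # ~0 per batch element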
417
+ if __name__ == "__main__":
418
+ torch.manual_seed(123)
419
+ M, N, L, T = 2, 3, 4, 12
420
+ K = 2 * T // L - 1
421
+ B, H, P, X, R, C, norm_type, causal = 2, 3, 3, 3, 2, 2, "gLN", False
422
+ mixture = torch.randint(3, (M, T))
423
+ # test Encoder
424
+ encoder = Encoder(L, N)
425
+ encoder.conv1d_U.weight.data = torch.randint(2, encoder.conv1d_U.weight.size())
426
+ mixture_w = encoder(mixture)
427
+ print('mixture', mixture)
428
+ print('U', encoder.conv1d_U.weight)
429
+ print('mixture_w', mixture_w)
430
+ print('mixture_w size', mixture_w.size())
431
+
432
+ # test TemporalConvNet
433
+ separator = TemporalConvNet(N, B, H, P, X, R, C, norm_type=norm_type, causal=causal)
434
+ est_mask = separator(mixture_w)
435
+ print('est_mask', est_mask)
436
+
437
+ # test Decoder
438
+ decoder = Decoder(N, L)
439
+ est_mask = torch.randint(2, (B, K, C, N))
440
+ est_source = decoder(mixture_w, est_mask)
441
+ print('est_source', est_source)
442
+
443
+ # test Conv-TasNet
444
+ conv_tasnet = ConvTasNet(N, L, B, H, P, X, R, C, norm_type=norm_type)
445
+ est_source = conv_tasnet(mixture)
446
+ print('est_source', est_source)
447
+ print('est_source size', est_source.size())
demucs/tasnet_v2.py ADDED
@@ -0,0 +1,452 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ #
7
+ # Created on 2018/12
8
+ # Author: Kaituo XU
9
+ # Modified on 2019/11 by Alexandre Defossez, added support for multiple output channels
10
+ # Here is the original license:
11
+ # The MIT License (MIT)
12
+ #
13
+ # Copyright (c) 2018 Kaituo XU
14
+ #
15
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
16
+ # of this software and associated documentation files (the "Software"), to deal
17
+ # in the Software without restriction, including without limitation the rights
18
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
19
+ # copies of the Software, and to permit persons to whom the Software is
20
+ # furnished to do so, subject to the following conditions:
21
+ #
22
+ # The above copyright notice and this permission notice shall be included in all
23
+ # copies or substantial portions of the Software.
24
+ #
25
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
26
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
27
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
28
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
29
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
30
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
31
+ # SOFTWARE.
32
+
33
+ import math
34
+
35
+ import torch
36
+ import torch.nn as nn
37
+ import torch.nn.functional as F
38
+
39
+ from .utils import capture_init
40
+
41
+ EPS = 1e-8
42
+
43
+
44
+ def overlap_and_add(signal, frame_step):
45
+ outer_dimensions = signal.size()[:-2]
46
+ frames, frame_length = signal.size()[-2:]
47
+
48
+ subframe_length = math.gcd(frame_length, frame_step) # gcd=Greatest Common Divisor
49
+ subframe_step = frame_step // subframe_length
50
+ subframes_per_frame = frame_length // subframe_length
51
+ output_size = frame_step * (frames - 1) + frame_length
52
+ output_subframes = output_size // subframe_length
53
+
54
+ subframe_signal = signal.view(*outer_dimensions, -1, subframe_length)
55
+
56
+ frame = torch.arange(0, output_subframes,
57
+ device=signal.device).unfold(0, subframes_per_frame, subframe_step)
58
+ frame = frame.long() # signal may in GPU or CPU
59
+ frame = frame.contiguous().view(-1)
60
+
61
+ result = signal.new_zeros(*outer_dimensions, output_subframes, subframe_length)
62
+ result.index_add_(-2, frame, subframe_signal)
63
+ result = result.view(*outer_dimensions, -1)
64
+ return result
65
+
66
+
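# A tiny worked example of overlap_and_add above: three frames of length 4 are
# recombined with a hop (frame_step) of 2, i.e. 50% overlap, so the output length
# is frame_step * (frames - 1) + frame_length = 2 * 2 + 4 = 8.
import torch

frames = torch.ones(3, 4)                     # [frames, frame_length]
out = overlap_and_add(frames, frame_step=2)
print(out)                                    # tensor([1., 1., 2., 2., 2., 2., 1., 1.])
print(out.shape)                              # torch.Size([8])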
67
+ class ConvTasNet(nn.Module):
68
+ @capture_init
69
+ def __init__(self,
70
+ sources,
71
+ N=256,
72
+ L=20,
73
+ B=256,
74
+ H=512,
75
+ P=3,
76
+ X=8,
77
+ R=4,
78
+ audio_channels=2,
79
+ norm_type="gLN",
80
+ causal=False,
81
+ mask_nonlinear='relu',
82
+ samplerate=44100,
83
+ segment_length=44100 * 2 * 4):
84
+ """
85
+ Args:
86
+ sources: list of sources
87
+ N: Number of filters in autoencoder
88
+ L: Length of the filters (in samples)
89
+ B: Number of channels in bottleneck 1 × 1-conv block
90
+ H: Number of channels in convolutional blocks
91
+ P: Kernel size in convolutional blocks
92
+ X: Number of convolutional blocks in each repeat
93
+ R: Number of repeats
94
+ norm_type: BN, gLN, cLN
95
+ causal: causal or non-causal
96
+ mask_nonlinear: use which non-linear function to generate mask
97
+ """
98
+ super(ConvTasNet, self).__init__()
99
+ # Hyper-parameter
100
+ self.sources = sources
101
+ self.C = len(sources)
102
+ self.N, self.L, self.B, self.H, self.P, self.X, self.R = N, L, B, H, P, X, R
103
+ self.norm_type = norm_type
104
+ self.causal = causal
105
+ self.mask_nonlinear = mask_nonlinear
106
+ self.audio_channels = audio_channels
107
+ self.samplerate = samplerate
108
+ self.segment_length = segment_length
109
+ # Components
110
+ self.encoder = Encoder(L, N, audio_channels)
111
+ self.separator = TemporalConvNet(
112
+ N, B, H, P, X, R, self.C, norm_type, causal, mask_nonlinear)
113
+ self.decoder = Decoder(N, L, audio_channels)
114
+ # init
115
+ for p in self.parameters():
116
+ if p.dim() > 1:
117
+ nn.init.xavier_normal_(p)
118
+
119
+ def valid_length(self, length):
120
+ return length
121
+
122
+ def forward(self, mixture):
123
+ """
124
+ Args:
125
+ mixture: [M, T], M is batch size, T is #samples
126
+ Returns:
127
+ est_source: [M, C, T]
128
+ """
129
+ mixture_w = self.encoder(mixture)
130
+ est_mask = self.separator(mixture_w)
131
+ est_source = self.decoder(mixture_w, est_mask)
132
+
133
+ # T changed after conv1d in encoder, fix it here
134
+ T_origin = mixture.size(-1)
135
+ T_conv = est_source.size(-1)
136
+ est_source = F.pad(est_source, (0, T_origin - T_conv))
137
+ return est_source
138
+
139
+
140
+ class Encoder(nn.Module):
141
+ """Estimation of the nonnegative mixture weight by a 1-D conv layer.
142
+ """
143
+ def __init__(self, L, N, audio_channels):
144
+ super(Encoder, self).__init__()
145
+ # Hyper-parameter
146
+ self.L, self.N = L, N
147
+ # Components
148
+ # 50% overlap
149
+ self.conv1d_U = nn.Conv1d(audio_channels, N, kernel_size=L, stride=L // 2, bias=False)
150
+
151
+ def forward(self, mixture):
152
+ """
153
+ Args:
154
+ mixture: [M, T], M is batch size, T is #samples
155
+ Returns:
156
+ mixture_w: [M, N, K], where K = (T-L)/(L/2)+1 = 2T/L-1
157
+ """
158
+ mixture_w = F.relu(self.conv1d_U(mixture)) # [M, N, K]
159
+ return mixture_w
160
+
161
+
162
+ class Decoder(nn.Module):
163
+ def __init__(self, N, L, audio_channels):
164
+ super(Decoder, self).__init__()
165
+ # Hyper-parameter
166
+ self.N, self.L = N, L
167
+ self.audio_channels = audio_channels
168
+ # Components
169
+ self.basis_signals = nn.Linear(N, audio_channels * L, bias=False)
170
+
171
+ def forward(self, mixture_w, est_mask):
172
+ """
173
+ Args:
174
+ mixture_w: [M, N, K]
175
+ est_mask: [M, C, N, K]
176
+ Returns:
177
+ est_source: [M, C, T]
178
+ """
179
+ # D = W * M
180
+ source_w = torch.unsqueeze(mixture_w, 1) * est_mask # [M, C, N, K]
181
+ source_w = torch.transpose(source_w, 2, 3) # [M, C, K, N]
182
+ # S = DV
183
+ est_source = self.basis_signals(source_w) # [M, C, K, ac * L]
184
+ m, c, k, _ = est_source.size()
185
+ est_source = est_source.view(m, c, k, self.audio_channels, -1).transpose(2, 3).contiguous()
186
+ est_source = overlap_and_add(est_source, self.L // 2) # M x C x ac x T
187
+ return est_source
188
+
189
+
190
+ class TemporalConvNet(nn.Module):
191
+ def __init__(self, N, B, H, P, X, R, C, norm_type="gLN", causal=False, mask_nonlinear='relu'):
192
+ """
193
+ Args:
194
+ N: Number of filters in autoencoder
195
+ B: Number of channels in bottleneck 1 × 1-conv block
196
+ H: Number of channels in convolutional blocks
197
+ P: Kernel size in convolutional blocks
198
+ X: Number of convolutional blocks in each repeat
199
+ R: Number of repeats
200
+ C: Number of speakers
201
+ norm_type: BN, gLN, cLN
202
+ causal: causal or non-causal
203
+ mask_nonlinear: use which non-linear function to generate mask
204
+ """
205
+ super(TemporalConvNet, self).__init__()
206
+ # Hyper-parameter
207
+ self.C = C
208
+ self.mask_nonlinear = mask_nonlinear
209
+ # Components
210
+ # [M, N, K] -> [M, N, K]
211
+ layer_norm = ChannelwiseLayerNorm(N)
212
+ # [M, N, K] -> [M, B, K]
213
+ bottleneck_conv1x1 = nn.Conv1d(N, B, 1, bias=False)
214
+ # [M, B, K] -> [M, B, K]
215
+ repeats = []
216
+ for r in range(R):
217
+ blocks = []
218
+ for x in range(X):
219
+ dilation = 2**x
220
+ padding = (P - 1) * dilation if causal else (P - 1) * dilation // 2
221
+ blocks += [
222
+ TemporalBlock(B,
223
+ H,
224
+ P,
225
+ stride=1,
226
+ padding=padding,
227
+ dilation=dilation,
228
+ norm_type=norm_type,
229
+ causal=causal)
230
+ ]
231
+ repeats += [nn.Sequential(*blocks)]
232
+ temporal_conv_net = nn.Sequential(*repeats)
233
+ # [M, B, K] -> [M, C*N, K]
234
+ mask_conv1x1 = nn.Conv1d(B, C * N, 1, bias=False)
235
+ # Put together
236
+ self.network = nn.Sequential(layer_norm, bottleneck_conv1x1, temporal_conv_net,
237
+ mask_conv1x1)
238
+
239
+ def forward(self, mixture_w):
240
+ """
241
+ Keep this API same with TasNet
242
+ Args:
243
+ mixture_w: [M, N, K], M is batch size
244
+ returns:
245
+ est_mask: [M, C, N, K]
246
+ """
247
+ M, N, K = mixture_w.size()
248
+ score = self.network(mixture_w) # [M, N, K] -> [M, C*N, K]
249
+ score = score.view(M, self.C, N, K) # [M, C*N, K] -> [M, C, N, K]
250
+ if self.mask_nonlinear == 'softmax':
251
+ est_mask = F.softmax(score, dim=1)
252
+ elif self.mask_nonlinear == 'relu':
253
+ est_mask = F.relu(score)
254
+ else:
255
+ raise ValueError("Unsupported mask non-linear function")
256
+ return est_mask
257
+
258
+
259
+ class TemporalBlock(nn.Module):
260
+ def __init__(self,
261
+ in_channels,
262
+ out_channels,
263
+ kernel_size,
264
+ stride,
265
+ padding,
266
+ dilation,
267
+ norm_type="gLN",
268
+ causal=False):
269
+ super(TemporalBlock, self).__init__()
270
+ # [M, B, K] -> [M, H, K]
271
+ conv1x1 = nn.Conv1d(in_channels, out_channels, 1, bias=False)
272
+ prelu = nn.PReLU()
273
+ norm = chose_norm(norm_type, out_channels)
274
+ # [M, H, K] -> [M, B, K]
275
+ dsconv = DepthwiseSeparableConv(out_channels, in_channels, kernel_size, stride, padding,
276
+ dilation, norm_type, causal)
277
+ # Put together
278
+ self.net = nn.Sequential(conv1x1, prelu, norm, dsconv)
279
+
280
+ def forward(self, x):
281
+ """
282
+ Args:
283
+ x: [M, B, K]
284
+ Returns:
285
+ [M, B, K]
286
+ """
287
+ residual = x
288
+ out = self.net(x)
289
+ # TODO: when P = 3 here works fine, but when P = 2 maybe need to pad?
290
+ return out + residual # look like w/o F.relu is better than w/ F.relu
291
+ # return F.relu(out + residual)
292
+
293
+
294
+ class DepthwiseSeparableConv(nn.Module):
295
+ def __init__(self,
296
+ in_channels,
297
+ out_channels,
298
+ kernel_size,
299
+ stride,
300
+ padding,
301
+ dilation,
302
+ norm_type="gLN",
303
+ causal=False):
304
+ super(DepthwiseSeparableConv, self).__init__()
305
+ # Use `groups` option to implement depthwise convolution
306
+ # [M, H, K] -> [M, H, K]
307
+ depthwise_conv = nn.Conv1d(in_channels,
308
+ in_channels,
309
+ kernel_size,
310
+ stride=stride,
311
+ padding=padding,
312
+ dilation=dilation,
313
+ groups=in_channels,
314
+ bias=False)
315
+ if causal:
316
+ chomp = Chomp1d(padding)
317
+ prelu = nn.PReLU()
318
+ norm = chose_norm(norm_type, in_channels)
319
+ # [M, H, K] -> [M, B, K]
320
+ pointwise_conv = nn.Conv1d(in_channels, out_channels, 1, bias=False)
321
+ # Put together
322
+ if causal:
323
+ self.net = nn.Sequential(depthwise_conv, chomp, prelu, norm, pointwise_conv)
324
+ else:
325
+ self.net = nn.Sequential(depthwise_conv, prelu, norm, pointwise_conv)
326
+
327
+ def forward(self, x):
328
+ """
329
+ Args:
330
+ x: [M, H, K]
331
+ Returns:
332
+ result: [M, B, K]
333
+ """
334
+ return self.net(x)
335
+
336
+
337
+ class Chomp1d(nn.Module):
338
+ """To ensure the output length is the same as the input.
339
+ """
340
+ def __init__(self, chomp_size):
341
+ super(Chomp1d, self).__init__()
342
+ self.chomp_size = chomp_size
343
+
344
+ def forward(self, x):
345
+ """
346
+ Args:
347
+ x: [M, H, Kpad]
348
+ Returns:
349
+ [M, H, K]
350
+ """
351
+ return x[:, :, :-self.chomp_size].contiguous()
352
+
353
+
354
+ def chose_norm(norm_type, channel_size):
355
+    """The input of normalization will be (M, C, K), where M is batch size,
356
+ C is channel size and K is sequence length.
357
+ """
358
+ if norm_type == "gLN":
359
+ return GlobalLayerNorm(channel_size)
360
+ elif norm_type == "cLN":
361
+ return ChannelwiseLayerNorm(channel_size)
362
+ elif norm_type == "id":
363
+ return nn.Identity()
364
+ else: # norm_type == "BN":
365
+        # Given input (M, C, K), nn.BatchNorm1d(C) will accumulate statistics
366
+ # along M and K, so this BN usage is right.
367
+ return nn.BatchNorm1d(channel_size)
368
+
369
+
370
+ # TODO: Use nn.LayerNorm to impl cLN to speed up
371
+ class ChannelwiseLayerNorm(nn.Module):
372
+ """Channel-wise Layer Normalization (cLN)"""
373
+ def __init__(self, channel_size):
374
+ super(ChannelwiseLayerNorm, self).__init__()
375
+ self.gamma = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1]
376
+ self.beta = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1]
377
+ self.reset_parameters()
378
+
379
+ def reset_parameters(self):
380
+ self.gamma.data.fill_(1)
381
+ self.beta.data.zero_()
382
+
383
+ def forward(self, y):
384
+ """
385
+ Args:
386
+ y: [M, N, K], M is batch size, N is channel size, K is length
387
+ Returns:
388
+ cLN_y: [M, N, K]
389
+ """
390
+ mean = torch.mean(y, dim=1, keepdim=True) # [M, 1, K]
391
+ var = torch.var(y, dim=1, keepdim=True, unbiased=False) # [M, 1, K]
392
+ cLN_y = self.gamma * (y - mean) / torch.pow(var + EPS, 0.5) + self.beta
393
+ return cLN_y
394
+
395
+
396
+ class GlobalLayerNorm(nn.Module):
397
+ """Global Layer Normalization (gLN)"""
398
+ def __init__(self, channel_size):
399
+ super(GlobalLayerNorm, self).__init__()
400
+ self.gamma = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1]
401
+ self.beta = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1]
402
+ self.reset_parameters()
403
+
404
+ def reset_parameters(self):
405
+ self.gamma.data.fill_(1)
406
+ self.beta.data.zero_()
407
+
408
+ def forward(self, y):
409
+ """
410
+ Args:
411
+ y: [M, N, K], M is batch size, N is channel size, K is length
412
+ Returns:
413
+ gLN_y: [M, N, K]
414
+ """
415
+ # TODO: in torch 1.0, torch.mean() support dim list
416
+ mean = y.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True) # [M, 1, 1]
417
+ var = (torch.pow(y - mean, 2)).mean(dim=1, keepdim=True).mean(dim=2, keepdim=True)
418
+ gLN_y = self.gamma * (y - mean) / torch.pow(var + EPS, 0.5) + self.beta
419
+ return gLN_y
420
+
421
+
422
+ if __name__ == "__main__":
423
+ torch.manual_seed(123)
424
+ M, N, L, T = 2, 3, 4, 12
425
+ K = 2 * T // L - 1
426
+ B, H, P, X, R, C, norm_type, causal = 2, 3, 3, 3, 2, 2, "gLN", False
427
+ mixture = torch.randint(3, (M, T))
428
+ # test Encoder
429
+ encoder = Encoder(L, N)
430
+ encoder.conv1d_U.weight.data = torch.randint(2, encoder.conv1d_U.weight.size())
431
+ mixture_w = encoder(mixture)
432
+ print('mixture', mixture)
433
+ print('U', encoder.conv1d_U.weight)
434
+ print('mixture_w', mixture_w)
435
+ print('mixture_w size', mixture_w.size())
436
+
437
+ # test TemporalConvNet
438
+ separator = TemporalConvNet(N, B, H, P, X, R, C, norm_type=norm_type, causal=causal)
439
+ est_mask = separator(mixture_w)
440
+ print('est_mask', est_mask)
441
+
442
+ # test Decoder
443
+ decoder = Decoder(N, L)
444
+ est_mask = torch.randint(2, (B, K, C, N))
445
+ est_source = decoder(mixture_w, est_mask)
446
+ print('est_source', est_source)
447
+
448
+ # test Conv-TasNet
449
+ conv_tasnet = ConvTasNet(N, L, B, H, P, X, R, C, norm_type=norm_type)
450
+ est_source = conv_tasnet(mixture)
451
+ print('est_source', est_source)
452
+ print('est_source size', est_source.size())
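# A hedged usage sketch for the multichannel ConvTasNet defined in
# demucs/tasnet_v2.py. The source names and the one-second stereo input are
# illustrative assumptions, not values taken from this repository.
import torch
from demucs.tasnet_v2 import ConvTasNet

sources = ["drums", "bass", "other", "vocals"]
model = ConvTasNet(sources, audio_channels=2)
mix = torch.randn(1, 2, 44100)                # [batch, audio_channels, samples]
with torch.no_grad():
    separated = model(mix)
print(separated.shape)                        # torch.Size([1, 4, 2, 44100])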
demucs/transformer.py ADDED
@@ -0,0 +1,839 @@
1
+ # Copyright (c) 2019-present, Meta, Inc.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # First author is Simon Rouard.
7
+
8
+ import random
9
+ import typing as tp
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ import numpy as np
15
+ import math
16
+ from einops import rearrange
17
+
18
+
19
+ def create_sin_embedding(
20
+ length: int, dim: int, shift: int = 0, device="cpu", max_period=10000
21
+ ):
22
+ # We aim for TBC format
23
+ assert dim % 2 == 0
24
+ pos = shift + torch.arange(length, device=device).view(-1, 1, 1)
25
+ half_dim = dim // 2
26
+ adim = torch.arange(dim // 2, device=device).view(1, 1, -1)
27
+ phase = pos / (max_period ** (adim / (half_dim - 1)))
28
+ return torch.cat(
29
+ [
30
+ torch.cos(phase),
31
+ torch.sin(phase),
32
+ ],
33
+ dim=-1,
34
+ )
35
+
36
+
37
+ def create_2d_sin_embedding(d_model, height, width, device="cpu", max_period=10000):
38
+ """
39
+ :param d_model: dimension of the model
40
+ :param height: height of the positions
41
+ :param width: width of the positions
42
+ :return: d_model*height*width position matrix
43
+ """
44
+ if d_model % 4 != 0:
45
+ raise ValueError(
46
+ "Cannot use sin/cos positional encoding with "
47
+ "odd dimension (got dim={:d})".format(d_model)
48
+ )
49
+ pe = torch.zeros(d_model, height, width)
50
+ # Each dimension use half of d_model
51
+ d_model = int(d_model / 2)
52
+ div_term = torch.exp(
53
+ torch.arange(0.0, d_model, 2) * -(math.log(max_period) / d_model)
54
+ )
55
+ pos_w = torch.arange(0.0, width).unsqueeze(1)
56
+ pos_h = torch.arange(0.0, height).unsqueeze(1)
57
+ pe[0:d_model:2, :, :] = (
58
+ torch.sin(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
59
+ )
60
+ pe[1:d_model:2, :, :] = (
61
+ torch.cos(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
62
+ )
63
+ pe[d_model::2, :, :] = (
64
+ torch.sin(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
65
+ )
66
+ pe[d_model + 1:: 2, :, :] = (
67
+ torch.cos(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
68
+ )
69
+
70
+ return pe[None, :].to(device)
71
+
72
+
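# Shape sketch for the two positional-embedding helpers above; the sizes are
# arbitrary illustrative values.
emb_1d = create_sin_embedding(length=100, dim=64)                 # TBC layout
print(emb_1d.shape)                                               # torch.Size([100, 1, 64])
emb_2d = create_2d_sin_embedding(d_model=64, height=32, width=10)
print(emb_2d.shape)                                               # torch.Size([1, 64, 32, 10])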
73
+ def create_sin_embedding_cape(
74
+ length: int,
75
+ dim: int,
76
+ batch_size: int,
77
+ mean_normalize: bool,
78
+ augment: bool, # True during training
79
+ max_global_shift: float = 0.0, # delta max
80
+ max_local_shift: float = 0.0, # epsilon max
81
+ max_scale: float = 1.0,
82
+ device: str = "cpu",
83
+ max_period: float = 10000.0,
84
+ ):
85
+ # We aim for TBC format
86
+ assert dim % 2 == 0
87
+ pos = 1.0 * torch.arange(length).view(-1, 1, 1) # (length, 1, 1)
88
+ pos = pos.repeat(1, batch_size, 1) # (length, batch_size, 1)
89
+ if mean_normalize:
90
+ pos -= torch.nanmean(pos, dim=0, keepdim=True)
91
+
92
+ if augment:
93
+ delta = np.random.uniform(
94
+ -max_global_shift, +max_global_shift, size=[1, batch_size, 1]
95
+ )
96
+ delta_local = np.random.uniform(
97
+ -max_local_shift, +max_local_shift, size=[length, batch_size, 1]
98
+ )
99
+ log_lambdas = np.random.uniform(
100
+ -np.log(max_scale), +np.log(max_scale), size=[1, batch_size, 1]
101
+ )
102
+ pos = (pos + delta + delta_local) * np.exp(log_lambdas)
103
+
104
+ pos = pos.to(device)
105
+
106
+ half_dim = dim // 2
107
+ adim = torch.arange(dim // 2, device=device).view(1, 1, -1)
108
+ phase = pos / (max_period ** (adim / (half_dim - 1)))
109
+ return torch.cat(
110
+ [
111
+ torch.cos(phase),
112
+ torch.sin(phase),
113
+ ],
114
+ dim=-1,
115
+ ).float()
116
+
117
+
118
+ def get_causal_mask(length):
119
+ pos = torch.arange(length)
120
+ return pos > pos[:, None]
121
+
122
+
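# get_causal_mask marks future positions with True, matching the boolean
# attn_mask convention in which True means "do not attend". For example:
print(get_causal_mask(3))
# tensor([[False,  True,  True],
#         [False, False,  True],
#         [False, False, False]])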
123
+ def get_elementary_mask(
124
+ T1,
125
+ T2,
126
+ mask_type,
127
+ sparse_attn_window,
128
+ global_window,
129
+ mask_random_seed,
130
+ sparsity,
131
+ device,
132
+ ):
133
+ """
134
+ When the input of the Decoder has length T1 and the output T2
135
+ The mask matrix has shape (T2, T1)
136
+ """
137
+ assert mask_type in ["diag", "jmask", "random", "global"]
138
+
139
+ if mask_type == "global":
140
+ mask = torch.zeros(T2, T1, dtype=torch.bool)
141
+ mask[:, :global_window] = True
142
+ line_window = int(global_window * T2 / T1)
143
+ mask[:line_window, :] = True
144
+
145
+ if mask_type == "diag":
146
+
147
+ mask = torch.zeros(T2, T1, dtype=torch.bool)
148
+ rows = torch.arange(T2)[:, None]
149
+ cols = (
150
+ (T1 / T2 * rows + torch.arange(-sparse_attn_window, sparse_attn_window + 1))
151
+ .long()
152
+ .clamp(0, T1 - 1)
153
+ )
154
+ mask.scatter_(1, cols, torch.ones(1, dtype=torch.bool).expand_as(cols))
155
+
156
+ elif mask_type == "jmask":
157
+ mask = torch.zeros(T2 + 2, T1 + 2, dtype=torch.bool)
158
+ rows = torch.arange(T2 + 2)[:, None]
159
+ t = torch.arange(0, int((2 * T1) ** 0.5 + 1))
160
+ t = (t * (t + 1) / 2).int()
161
+ t = torch.cat([-t.flip(0)[:-1], t])
162
+ cols = (T1 / T2 * rows + t).long().clamp(0, T1 + 1)
163
+ mask.scatter_(1, cols, torch.ones(1, dtype=torch.bool).expand_as(cols))
164
+ mask = mask[1:-1, 1:-1]
165
+
166
+ elif mask_type == "random":
167
+ gene = torch.Generator(device=device)
168
+ gene.manual_seed(mask_random_seed)
169
+ mask = (
170
+ torch.rand(T1 * T2, generator=gene, device=device).reshape(T2, T1)
171
+ > sparsity
172
+ )
173
+
174
+ mask = mask.to(device)
175
+ return mask
176
+
177
+
178
+ def get_mask(
179
+ T1,
180
+ T2,
181
+ mask_type,
182
+ sparse_attn_window,
183
+ global_window,
184
+ mask_random_seed,
185
+ sparsity,
186
+ device,
187
+ ):
188
+ """
189
+ Return a SparseCSRTensor mask that is a combination of elementary masks
190
+ mask_type can be a combination of multiple masks: for instance "diag_jmask_random"
191
+ """
192
+ from xformers.sparse import SparseCSRTensor
193
+ # create a list
194
+ mask_types = mask_type.split("_")
195
+
196
+ all_masks = [
197
+ get_elementary_mask(
198
+ T1,
199
+ T2,
200
+ mask,
201
+ sparse_attn_window,
202
+ global_window,
203
+ mask_random_seed,
204
+ sparsity,
205
+ device,
206
+ )
207
+ for mask in mask_types
208
+ ]
209
+
210
+ final_mask = torch.stack(all_masks).sum(axis=0) > 0
211
+
212
+ return SparseCSRTensor.from_dense(final_mask[None])
213
+
214
+
215
+ class ScaledEmbedding(nn.Module):
216
+ def __init__(
217
+ self,
218
+ num_embeddings: int,
219
+ embedding_dim: int,
220
+ scale: float = 1.0,
221
+ boost: float = 3.0,
222
+ ):
223
+ super().__init__()
224
+ self.embedding = nn.Embedding(num_embeddings, embedding_dim)
225
+ self.embedding.weight.data *= scale / boost
226
+ self.boost = boost
227
+
228
+ @property
229
+ def weight(self):
230
+ return self.embedding.weight * self.boost
231
+
232
+ def forward(self, x):
233
+ return self.embedding(x) * self.boost
234
+
235
+
236
+ class LayerScale(nn.Module):
237
+ """Layer scale from [Touvron et al 2021] (https://arxiv.org/pdf/2103.17239.pdf).
238
+    This rescales residual outputs diagonally, close to 0 initially, then learnt.
239
+ """
240
+
241
+ def __init__(self, channels: int, init: float = 0, channel_last=False):
242
+ """
243
+ channel_last = False corresponds to (B, C, T) tensors
244
+ channel_last = True corresponds to (T, B, C) tensors
245
+ """
246
+ super().__init__()
247
+ self.channel_last = channel_last
248
+ self.scale = nn.Parameter(torch.zeros(channels, requires_grad=True))
249
+ self.scale.data[:] = init
250
+
251
+ def forward(self, x):
252
+ if self.channel_last:
253
+ return self.scale * x
254
+ else:
255
+ return self.scale[:, None] * x
256
+
257
+
258
+ class MyGroupNorm(nn.GroupNorm):
259
+ def __init__(self, *args, **kwargs):
260
+ super().__init__(*args, **kwargs)
261
+
262
+ def forward(self, x):
263
+ """
264
+ x: (B, T, C)
265
+ if num_groups=1: Normalisation on all T and C together for each B
266
+ """
267
+ x = x.transpose(1, 2)
268
+ return super().forward(x).transpose(1, 2)
269
+
270
+
271
+ class MyTransformerEncoderLayer(nn.TransformerEncoderLayer):
272
+ def __init__(
273
+ self,
274
+ d_model,
275
+ nhead,
276
+ dim_feedforward=2048,
277
+ dropout=0.1,
278
+ activation=F.relu,
279
+ group_norm=0,
280
+ norm_first=False,
281
+ norm_out=False,
282
+ layer_norm_eps=1e-5,
283
+ layer_scale=False,
284
+ init_values=1e-4,
285
+ device=None,
286
+ dtype=None,
287
+ sparse=False,
288
+ mask_type="diag",
289
+ mask_random_seed=42,
290
+ sparse_attn_window=500,
291
+ global_window=50,
292
+ auto_sparsity=False,
293
+ sparsity=0.95,
294
+ batch_first=False,
295
+ ):
296
+ factory_kwargs = {"device": device, "dtype": dtype}
297
+ super().__init__(
298
+ d_model=d_model,
299
+ nhead=nhead,
300
+ dim_feedforward=dim_feedforward,
301
+ dropout=dropout,
302
+ activation=activation,
303
+ layer_norm_eps=layer_norm_eps,
304
+ batch_first=batch_first,
305
+ norm_first=norm_first,
306
+ device=device,
307
+ dtype=dtype,
308
+ )
309
+ self.sparse = sparse
310
+ self.auto_sparsity = auto_sparsity
311
+ if sparse:
312
+ if not auto_sparsity:
313
+ self.mask_type = mask_type
314
+ self.sparse_attn_window = sparse_attn_window
315
+ self.global_window = global_window
316
+ self.sparsity = sparsity
317
+ if group_norm:
318
+ self.norm1 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
319
+ self.norm2 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
320
+
321
+ self.norm_out = None
322
+ if self.norm_first & norm_out:
323
+ self.norm_out = MyGroupNorm(num_groups=int(norm_out), num_channels=d_model)
324
+ self.gamma_1 = (
325
+ LayerScale(d_model, init_values, True) if layer_scale else nn.Identity()
326
+ )
327
+ self.gamma_2 = (
328
+ LayerScale(d_model, init_values, True) if layer_scale else nn.Identity()
329
+ )
330
+
331
+ if sparse:
332
+ self.self_attn = MultiheadAttention(
333
+ d_model, nhead, dropout=dropout, batch_first=batch_first,
334
+ auto_sparsity=sparsity if auto_sparsity else 0,
335
+ )
336
+ self.__setattr__("src_mask", torch.zeros(1, 1))
337
+ self.mask_random_seed = mask_random_seed
338
+
339
+ def forward(self, src, src_mask=None, src_key_padding_mask=None):
340
+ """
341
+ if batch_first = False, src shape is (T, B, C)
342
+ the case where batch_first=True is not covered
343
+ """
344
+ device = src.device
345
+ x = src
346
+ T, B, C = x.shape
347
+ if self.sparse and not self.auto_sparsity:
348
+ assert src_mask is None
349
+ src_mask = self.src_mask
350
+ if src_mask.shape[-1] != T:
351
+ src_mask = get_mask(
352
+ T,
353
+ T,
354
+ self.mask_type,
355
+ self.sparse_attn_window,
356
+ self.global_window,
357
+ self.mask_random_seed,
358
+ self.sparsity,
359
+ device,
360
+ )
361
+ self.__setattr__("src_mask", src_mask)
362
+
363
+ if self.norm_first:
364
+ x = x + self.gamma_1(
365
+ self._sa_block(self.norm1(x), src_mask, src_key_padding_mask)
366
+ )
367
+ x = x + self.gamma_2(self._ff_block(self.norm2(x)))
368
+
369
+ if self.norm_out:
370
+ x = self.norm_out(x)
371
+ else:
372
+ x = self.norm1(
373
+ x + self.gamma_1(self._sa_block(x, src_mask, src_key_padding_mask))
374
+ )
375
+ x = self.norm2(x + self.gamma_2(self._ff_block(x)))
376
+
377
+ return x
378
+
379
+
380
+ class CrossTransformerEncoderLayer(nn.Module):
381
+ def __init__(
382
+ self,
383
+ d_model: int,
384
+ nhead: int,
385
+ dim_feedforward: int = 2048,
386
+ dropout: float = 0.1,
387
+ activation=F.relu,
388
+ layer_norm_eps: float = 1e-5,
389
+ layer_scale: bool = False,
390
+ init_values: float = 1e-4,
391
+ norm_first: bool = False,
392
+ group_norm: bool = False,
393
+ norm_out: bool = False,
394
+ sparse=False,
395
+ mask_type="diag",
396
+ mask_random_seed=42,
397
+ sparse_attn_window=500,
398
+ global_window=50,
399
+ sparsity=0.95,
400
+ auto_sparsity=None,
401
+ device=None,
402
+ dtype=None,
403
+ batch_first=False,
404
+ ):
405
+ factory_kwargs = {"device": device, "dtype": dtype}
406
+ super().__init__()
407
+
408
+ self.sparse = sparse
409
+ self.auto_sparsity = auto_sparsity
410
+ if sparse:
411
+ if not auto_sparsity:
412
+ self.mask_type = mask_type
413
+ self.sparse_attn_window = sparse_attn_window
414
+ self.global_window = global_window
415
+ self.sparsity = sparsity
416
+
417
+ self.cross_attn: nn.Module
418
+ self.cross_attn = nn.MultiheadAttention(
419
+ d_model, nhead, dropout=dropout, batch_first=batch_first)
420
+ # Implementation of Feedforward model
421
+ self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs)
422
+ self.dropout = nn.Dropout(dropout)
423
+ self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs)
424
+
425
+ self.norm_first = norm_first
426
+ self.norm1: nn.Module
427
+ self.norm2: nn.Module
428
+ self.norm3: nn.Module
429
+ if group_norm:
430
+ self.norm1 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
431
+ self.norm2 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
432
+ self.norm3 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
433
+ else:
434
+ self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
435
+ self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
436
+ self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
437
+
438
+ self.norm_out = None
439
+ if self.norm_first & norm_out:
440
+ self.norm_out = MyGroupNorm(num_groups=int(norm_out), num_channels=d_model)
441
+
442
+ self.gamma_1 = (
443
+ LayerScale(d_model, init_values, True) if layer_scale else nn.Identity()
444
+ )
445
+ self.gamma_2 = (
446
+ LayerScale(d_model, init_values, True) if layer_scale else nn.Identity()
447
+ )
448
+
449
+ self.dropout1 = nn.Dropout(dropout)
450
+ self.dropout2 = nn.Dropout(dropout)
451
+
452
+ # Legacy string support for activation function.
453
+ if isinstance(activation, str):
454
+ self.activation = self._get_activation_fn(activation)
455
+ else:
456
+ self.activation = activation
457
+
458
+ if sparse:
459
+ self.cross_attn = MultiheadAttention(
460
+ d_model, nhead, dropout=dropout, batch_first=batch_first,
461
+ auto_sparsity=sparsity if auto_sparsity else 0)
462
+ if not auto_sparsity:
463
+ self.__setattr__("mask", torch.zeros(1, 1))
464
+ self.mask_random_seed = mask_random_seed
465
+
466
+ def forward(self, q, k, mask=None):
467
+ """
468
+ Args:
469
+ q: tensor of shape (T, B, C)
470
+ k: tensor of shape (S, B, C)
471
+ mask: tensor of shape (T, S)
472
+
473
+ """
474
+ device = q.device
475
+ T, B, C = q.shape
476
+ S, B, C = k.shape
477
+ if self.sparse and not self.auto_sparsity:
478
+ assert mask is None
479
+ mask = self.mask
480
+ if mask.shape[-1] != S or mask.shape[-2] != T:
481
+ mask = get_mask(
482
+ S,
483
+ T,
484
+ self.mask_type,
485
+ self.sparse_attn_window,
486
+ self.global_window,
487
+ self.mask_random_seed,
488
+ self.sparsity,
489
+ device,
490
+ )
491
+ self.__setattr__("mask", mask)
492
+
493
+ if self.norm_first:
494
+ x = q + self.gamma_1(self._ca_block(self.norm1(q), self.norm2(k), mask))
495
+ x = x + self.gamma_2(self._ff_block(self.norm3(x)))
496
+ if self.norm_out:
497
+ x = self.norm_out(x)
498
+ else:
499
+ x = self.norm1(q + self.gamma_1(self._ca_block(q, k, mask)))
500
+ x = self.norm2(x + self.gamma_2(self._ff_block(x)))
501
+
502
+ return x
503
+
504
+    # cross-attention block
505
+ def _ca_block(self, q, k, attn_mask=None):
506
+ x = self.cross_attn(q, k, k, attn_mask=attn_mask, need_weights=False)[0]
507
+ return self.dropout1(x)
508
+
509
+ # feed forward block
510
+ def _ff_block(self, x):
511
+ x = self.linear2(self.dropout(self.activation(self.linear1(x))))
512
+ return self.dropout2(x)
513
+
514
+ def _get_activation_fn(self, activation):
515
+ if activation == "relu":
516
+ return F.relu
517
+ elif activation == "gelu":
518
+ return F.gelu
519
+
520
+ raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
521
+
522
+
523
+ # ----------------- MULTI-BLOCKS MODELS: -----------------------
524
+
525
+
526
+ class CrossTransformerEncoder(nn.Module):
527
+ def __init__(
528
+ self,
529
+ dim: int,
530
+ emb: str = "sin",
531
+ hidden_scale: float = 4.0,
532
+ num_heads: int = 8,
533
+ num_layers: int = 6,
534
+ cross_first: bool = False,
535
+ dropout: float = 0.0,
536
+ max_positions: int = 1000,
537
+ norm_in: bool = True,
538
+ norm_in_group: bool = False,
539
+ group_norm: int = False,
540
+ norm_first: bool = False,
541
+ norm_out: bool = False,
542
+ max_period: float = 10000.0,
543
+ weight_decay: float = 0.0,
544
+ lr: tp.Optional[float] = None,
545
+ layer_scale: bool = False,
546
+ gelu: bool = True,
547
+ sin_random_shift: int = 0,
548
+ weight_pos_embed: float = 1.0,
549
+ cape_mean_normalize: bool = True,
550
+ cape_augment: bool = True,
551
+ cape_glob_loc_scale: list = [5000.0, 1.0, 1.4],
552
+ sparse_self_attn: bool = False,
553
+ sparse_cross_attn: bool = False,
554
+ mask_type: str = "diag",
555
+ mask_random_seed: int = 42,
556
+ sparse_attn_window: int = 500,
557
+ global_window: int = 50,
558
+ auto_sparsity: bool = False,
559
+ sparsity: float = 0.95,
560
+ ):
561
+ super().__init__()
562
+ """
563
+ """
564
+ assert dim % num_heads == 0
565
+
566
+ hidden_dim = int(dim * hidden_scale)
567
+
568
+ self.num_layers = num_layers
569
+ # classic parity = 1 means that if idx%2 == 1 there is a
570
+ # classical encoder else there is a cross encoder
571
+ self.classic_parity = 1 if cross_first else 0
572
+ self.emb = emb
573
+ self.max_period = max_period
574
+ self.weight_decay = weight_decay
575
+ self.weight_pos_embed = weight_pos_embed
576
+ self.sin_random_shift = sin_random_shift
577
+ if emb == "cape":
578
+ self.cape_mean_normalize = cape_mean_normalize
579
+ self.cape_augment = cape_augment
580
+ self.cape_glob_loc_scale = cape_glob_loc_scale
581
+ if emb == "scaled":
582
+ self.position_embeddings = ScaledEmbedding(max_positions, dim, scale=0.2)
583
+
584
+ self.lr = lr
585
+
586
+ activation: tp.Any = F.gelu if gelu else F.relu
587
+
588
+ self.norm_in: nn.Module
589
+ self.norm_in_t: nn.Module
590
+ if norm_in:
591
+ self.norm_in = nn.LayerNorm(dim)
592
+ self.norm_in_t = nn.LayerNorm(dim)
593
+ elif norm_in_group:
594
+ self.norm_in = MyGroupNorm(int(norm_in_group), dim)
595
+ self.norm_in_t = MyGroupNorm(int(norm_in_group), dim)
596
+ else:
597
+ self.norm_in = nn.Identity()
598
+ self.norm_in_t = nn.Identity()
599
+
600
+ # spectrogram layers
601
+ self.layers = nn.ModuleList()
602
+ # temporal layers
603
+ self.layers_t = nn.ModuleList()
604
+
605
+ kwargs_common = {
606
+ "d_model": dim,
607
+ "nhead": num_heads,
608
+ "dim_feedforward": hidden_dim,
609
+ "dropout": dropout,
610
+ "activation": activation,
611
+ "group_norm": group_norm,
612
+ "norm_first": norm_first,
613
+ "norm_out": norm_out,
614
+ "layer_scale": layer_scale,
615
+ "mask_type": mask_type,
616
+ "mask_random_seed": mask_random_seed,
617
+ "sparse_attn_window": sparse_attn_window,
618
+ "global_window": global_window,
619
+ "sparsity": sparsity,
620
+ "auto_sparsity": auto_sparsity,
621
+ "batch_first": True,
622
+ }
623
+
624
+ kwargs_classic_encoder = dict(kwargs_common)
625
+ kwargs_classic_encoder.update({
626
+ "sparse": sparse_self_attn,
627
+ })
628
+ kwargs_cross_encoder = dict(kwargs_common)
629
+ kwargs_cross_encoder.update({
630
+ "sparse": sparse_cross_attn,
631
+ })
632
+
633
+ for idx in range(num_layers):
634
+ if idx % 2 == self.classic_parity:
635
+
636
+ self.layers.append(MyTransformerEncoderLayer(**kwargs_classic_encoder))
637
+ self.layers_t.append(
638
+ MyTransformerEncoderLayer(**kwargs_classic_encoder)
639
+ )
640
+
641
+ else:
642
+ self.layers.append(CrossTransformerEncoderLayer(**kwargs_cross_encoder))
643
+
644
+ self.layers_t.append(
645
+ CrossTransformerEncoderLayer(**kwargs_cross_encoder)
646
+ )
647
+
648
+ def forward(self, x, xt):
649
+ B, C, Fr, T1 = x.shape
650
+ pos_emb_2d = create_2d_sin_embedding(
651
+ C, Fr, T1, x.device, self.max_period
652
+ ) # (1, C, Fr, T1)
653
+ pos_emb_2d = rearrange(pos_emb_2d, "b c fr t1 -> b (t1 fr) c")
654
+ x = rearrange(x, "b c fr t1 -> b (t1 fr) c")
655
+ x = self.norm_in(x)
656
+ x = x + self.weight_pos_embed * pos_emb_2d
657
+
658
+ B, C, T2 = xt.shape
659
+ xt = rearrange(xt, "b c t2 -> b t2 c") # now T2, B, C
660
+ pos_emb = self._get_pos_embedding(T2, B, C, x.device)
661
+ pos_emb = rearrange(pos_emb, "t2 b c -> b t2 c")
662
+ xt = self.norm_in_t(xt)
663
+ xt = xt + self.weight_pos_embed * pos_emb
664
+
665
+ for idx in range(self.num_layers):
666
+ if idx % 2 == self.classic_parity:
667
+ x = self.layers[idx](x)
668
+ xt = self.layers_t[idx](xt)
669
+ else:
670
+ old_x = x
671
+ x = self.layers[idx](x, xt)
672
+ xt = self.layers_t[idx](xt, old_x)
673
+
674
+ x = rearrange(x, "b (t1 fr) c -> b c fr t1", t1=T1)
675
+ xt = rearrange(xt, "b t2 c -> b c t2")
676
+ return x, xt
677
+
678
+ def _get_pos_embedding(self, T, B, C, device):
679
+ if self.emb == "sin":
680
+ shift = random.randrange(self.sin_random_shift + 1)
681
+ pos_emb = create_sin_embedding(
682
+ T, C, shift=shift, device=device, max_period=self.max_period
683
+ )
684
+ elif self.emb == "cape":
685
+ if self.training:
686
+ pos_emb = create_sin_embedding_cape(
687
+ T,
688
+ C,
689
+ B,
690
+ device=device,
691
+ max_period=self.max_period,
692
+ mean_normalize=self.cape_mean_normalize,
693
+ augment=self.cape_augment,
694
+ max_global_shift=self.cape_glob_loc_scale[0],
695
+ max_local_shift=self.cape_glob_loc_scale[1],
696
+ max_scale=self.cape_glob_loc_scale[2],
697
+ )
698
+ else:
699
+ pos_emb = create_sin_embedding_cape(
700
+ T,
701
+ C,
702
+ B,
703
+ device=device,
704
+ max_period=self.max_period,
705
+ mean_normalize=self.cape_mean_normalize,
706
+ augment=False,
707
+ )
708
+
709
+ elif self.emb == "scaled":
710
+ pos = torch.arange(T, device=device)
711
+ pos_emb = self.position_embeddings(pos)[:, None]
712
+
713
+ return pos_emb
714
+
715
+ def make_optim_group(self):
716
+ group = {"params": list(self.parameters()), "weight_decay": self.weight_decay}
717
+ if self.lr is not None:
718
+ group["lr"] = self.lr
719
+ return group
720
+
721
+
722
+ # Attention Modules
723
+
724
+
725
+ class MultiheadAttention(nn.Module):
726
+ def __init__(
727
+ self,
728
+ embed_dim,
729
+ num_heads,
730
+ dropout=0.0,
731
+ bias=True,
732
+ add_bias_kv=False,
733
+ add_zero_attn=False,
734
+ kdim=None,
735
+ vdim=None,
736
+ batch_first=False,
737
+ auto_sparsity=None,
738
+ ):
739
+ super().__init__()
740
+ assert auto_sparsity is not None, "sanity check"
741
+ self.num_heads = num_heads
742
+ self.q = torch.nn.Linear(embed_dim, embed_dim, bias=bias)
743
+ self.k = torch.nn.Linear(embed_dim, embed_dim, bias=bias)
744
+ self.v = torch.nn.Linear(embed_dim, embed_dim, bias=bias)
745
+ self.attn_drop = torch.nn.Dropout(dropout)
746
+ self.proj = torch.nn.Linear(embed_dim, embed_dim, bias)
747
+ self.proj_drop = torch.nn.Dropout(dropout)
748
+ self.batch_first = batch_first
749
+ self.auto_sparsity = auto_sparsity
750
+
751
+ def forward(
752
+ self,
753
+ query,
754
+ key,
755
+ value,
756
+ key_padding_mask=None,
757
+ need_weights=True,
758
+ attn_mask=None,
759
+ average_attn_weights=True,
760
+ ):
761
+
762
+ if not self.batch_first: # N, B, C
763
+ query = query.permute(1, 0, 2) # B, N_q, C
764
+ key = key.permute(1, 0, 2) # B, N_k, C
765
+ value = value.permute(1, 0, 2) # B, N_k, C
766
+ B, N_q, C = query.shape
767
+ B, N_k, C = key.shape
768
+
769
+ q = (
770
+ self.q(query)
771
+ .reshape(B, N_q, self.num_heads, C // self.num_heads)
772
+ .permute(0, 2, 1, 3)
773
+ )
774
+ q = q.flatten(0, 1)
775
+ k = (
776
+ self.k(key)
777
+ .reshape(B, N_k, self.num_heads, C // self.num_heads)
778
+ .permute(0, 2, 1, 3)
779
+ )
780
+ k = k.flatten(0, 1)
781
+ v = (
782
+ self.v(value)
783
+ .reshape(B, N_k, self.num_heads, C // self.num_heads)
784
+ .permute(0, 2, 1, 3)
785
+ )
786
+ v = v.flatten(0, 1)
787
+
788
+ if self.auto_sparsity:
789
+ assert attn_mask is None
790
+ x = dynamic_sparse_attention(q, k, v, sparsity=self.auto_sparsity)
791
+ else:
792
+ x = scaled_dot_product_attention(q, k, v, attn_mask, dropout=self.attn_drop)
793
+ x = x.reshape(B, self.num_heads, N_q, C // self.num_heads)
794
+
795
+ x = x.transpose(1, 2).reshape(B, N_q, C)
796
+ x = self.proj(x)
797
+ x = self.proj_drop(x)
798
+ if not self.batch_first:
799
+ x = x.permute(1, 0, 2)
800
+ return x, None
801
+
802
+
803
+ def scaled_query_key_softmax(q, k, att_mask):
804
+ from xformers.ops import masked_matmul
805
+ q = q / (k.size(-1)) ** 0.5
806
+ att = masked_matmul(q, k.transpose(-2, -1), att_mask)
807
+ att = torch.nn.functional.softmax(att, -1)
808
+ return att
809
+
810
+
811
+ def scaled_dot_product_attention(q, k, v, att_mask, dropout):
812
+ att = scaled_query_key_softmax(q, k, att_mask=att_mask)
813
+ att = dropout(att)
814
+ y = att @ v
815
+ return y
816
+
817
+
818
+ def _compute_buckets(x, R):
819
+ qq = torch.einsum('btf,bfhi->bhti', x, R)
820
+ qq = torch.cat([qq, -qq], dim=-1)
821
+ buckets = qq.argmax(dim=-1)
822
+
823
+ return buckets.permute(0, 2, 1).byte().contiguous()
824
+
825
+
826
+ def dynamic_sparse_attention(query, key, value, sparsity, infer_sparsity=True, attn_bias=None):
827
+ # assert False, "The code for the custom sparse kernel is not ready for release yet."
828
+ from xformers.ops import find_locations, sparse_memory_efficient_attention
829
+ n_hashes = 32
830
+ proj_size = 4
831
+ query, key, value = [x.contiguous() for x in [query, key, value]]
832
+ with torch.no_grad():
833
+ R = torch.randn(1, query.shape[-1], n_hashes, proj_size // 2, device=query.device)
834
+ bucket_query = _compute_buckets(query, R)
835
+ bucket_key = _compute_buckets(key, R)
836
+ row_offsets, column_indices = find_locations(
837
+ bucket_query, bucket_key, sparsity, infer_sparsity)
838
+ return sparse_memory_efficient_attention(
839
+ query, key, value, row_offsets, column_indices, attn_bias)
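# A hedged shape sketch for CrossTransformerEncoder: it interleaves per-branch
# self-attention layers with cross-attention layers between the spectrogram
# branch x of shape (B, C, Fr, T1) and the waveform branch xt of shape (B, C, T2),
# returning tensors with the same two shapes. The sizes below are illustrative
# assumptions, not values used by this repository.
import torch
from demucs.transformer import CrossTransformerEncoder

enc = CrossTransformerEncoder(dim=64, num_heads=8, num_layers=2)
x = torch.randn(1, 64, 16, 20)        # (B, C, Fr, T1) spectrogram branch
xt = torch.randn(1, 64, 300)          # (B, C, T2) waveform branch
with torch.no_grad():
    y, yt = enc(x, xt)
print(y.shape, yt.shape)              # torch.Size([1, 64, 16, 20]) torch.Size([1, 64, 300])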
demucs/utils.py ADDED
@@ -0,0 +1,502 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from collections import defaultdict
8
+ from contextlib import contextmanager
9
+ import math
10
+ import os
11
+ import tempfile
12
+ import typing as tp
13
+
14
+ import errno
15
+ import functools
16
+ import hashlib
17
+ import inspect
18
+ import io
19
+ import os
20
+ import random
21
+ import socket
22
+ import tempfile
23
+ import warnings
24
+ import zlib
25
+ import tkinter as tk
26
+
27
+ from diffq import UniformQuantizer, DiffQuantizer
28
+ import torch as th
29
+ import tqdm
30
+ from torch import distributed
31
+ from torch.nn import functional as F
32
+
33
+ import torch
34
+
35
+ def unfold(a, kernel_size, stride):
36
+ """Given input of size [*OT, T], output Tensor of size [*OT, F, K]
37
+ with K the kernel size, by extracting frames with the given stride.
38
+
39
+    This will pad the input so that `F = ceil(T / stride)`.
40
+
41
+ see https://github.com/pytorch/pytorch/issues/60466
42
+ """
43
+ *shape, length = a.shape
44
+ n_frames = math.ceil(length / stride)
45
+ tgt_length = (n_frames - 1) * stride + kernel_size
46
+ a = F.pad(a, (0, tgt_length - length))
47
+ strides = list(a.stride())
48
+ assert strides[-1] == 1, 'data should be contiguous'
49
+ strides = strides[:-1] + [stride, 1]
50
+ return a.as_strided([*shape, n_frames, kernel_size], strides)
51
+
52
+
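# A tiny illustration of unfold above: a length-6 signal is cut into
# ceil(6 / 2) = 3 overlapping frames of size 4 with stride 2, and the input is
# zero-padded on the right so that the last frame is complete.
import torch

a = torch.arange(6.)
print(unfold(a, kernel_size=4, stride=2))
# tensor([[0., 1., 2., 3.],
#         [2., 3., 4., 5.],
#         [4., 5., 0., 0.]])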
53
+ def center_trim(tensor: torch.Tensor, reference: tp.Union[torch.Tensor, int]):
54
+ """
55
+ Center trim `tensor` with respect to `reference`, along the last dimension.
56
+ `reference` can also be a number, representing the length to trim to.
57
+ If the size difference != 0 mod 2, the extra sample is removed on the right side.
58
+ """
59
+ ref_size: int
60
+ if isinstance(reference, torch.Tensor):
61
+ ref_size = reference.size(-1)
62
+ else:
63
+ ref_size = reference
64
+ delta = tensor.size(-1) - ref_size
65
+ if delta < 0:
66
+ raise ValueError("tensor must be larger than reference. " f"Delta is {delta}.")
67
+ if delta:
68
+ tensor = tensor[..., delta // 2:-(delta - delta // 2)]
69
+ return tensor
70
+
71
+
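# center_trim in action: trimming a length-10 tensor to length 7 removes one
# sample on the left and two on the right, i.e. the odd extra sample is dropped
# from the right side as the docstring states.
import torch

x = torch.arange(10.)
print(center_trim(x, 7))     # tensor([1., 2., 3., 4., 5., 6., 7.])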
72
+ def pull_metric(history: tp.List[dict], name: str):
73
+ out = []
74
+ for metrics in history:
75
+ metric = metrics
76
+ for part in name.split("."):
77
+ metric = metric[part]
78
+ out.append(metric)
79
+ return out
80
+
81
+
82
+ def EMA(beta: float = 1):
83
+ """
84
+ Exponential Moving Average callback.
85
+    Returns a single function that can be called to repeatedly update the EMA
86
+ with a dict of metrics. The callback will return
87
+ the new averaged dict of metrics.
88
+
89
+ Note that for `beta=1`, this is just plain averaging.
90
+ """
91
+ fix: tp.Dict[str, float] = defaultdict(float)
92
+ total: tp.Dict[str, float] = defaultdict(float)
93
+
94
+ def _update(metrics: dict, weight: float = 1) -> dict:
95
+ nonlocal total, fix
96
+ for key, value in metrics.items():
97
+ total[key] = total[key] * beta + weight * float(value)
98
+ fix[key] = fix[key] * beta + weight
99
+ return {key: tot / fix[key] for key, tot in total.items()}
100
+ return _update
101
+
102
+
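# With the default beta=1, the EMA callback above is a plain running average of
# the metric dictionaries passed to it.
update = EMA()
print(update({"loss": 2.0}))     # {'loss': 2.0}
print(update({"loss": 1.0}))     # {'loss': 1.5}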
103
+ def sizeof_fmt(num: float, suffix: str = 'B'):
104
+ """
105
+ Given `num` bytes, return human readable size.
106
+ Taken from https://stackoverflow.com/a/1094933
107
+ """
108
+ for unit in ['', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi']:
109
+ if abs(num) < 1024.0:
110
+ return "%3.1f%s%s" % (num, unit, suffix)
111
+ num /= 1024.0
112
+ return "%.1f%s%s" % (num, 'Yi', suffix)
113
+
114
+
115
+ @contextmanager
116
+ def temp_filenames(count: int, delete=True):
117
+ names = []
118
+ try:
119
+ for _ in range(count):
120
+ names.append(tempfile.NamedTemporaryFile(delete=False).name)
121
+ yield names
122
+ finally:
123
+ if delete:
124
+ for name in names:
125
+ os.unlink(name)
126
+
127
+ def average_metric(metric, count=1.):
128
+ """
129
+ Average `metric` which should be a float across all hosts. `count` should be
130
+ the weight for this particular host (i.e. number of examples).
131
+ """
132
+ metric = th.tensor([count, count * metric], dtype=th.float32, device='cuda')
133
+ distributed.all_reduce(metric, op=distributed.ReduceOp.SUM)
134
+ return metric[1].item() / metric[0].item()
135
+
136
+
137
+ def free_port(host='', low=20000, high=40000):
138
+ """
139
+ Return a port number that is most likely free.
140
+ This could suffer from a race condition although
141
+ it should be quite rare.
142
+ """
143
+ sock = socket.socket()
144
+ while True:
145
+ port = random.randint(low, high)
146
+ try:
147
+ sock.bind((host, port))
148
+ except OSError as error:
149
+ if error.errno == errno.EADDRINUSE:
150
+ continue
151
+ raise
152
+ return port
153
+
154
+
155
+ def sizeof_fmt(num, suffix='B'):
156
+ """
157
+ Given `num` bytes, return human readable size.
158
+ Taken from https://stackoverflow.com/a/1094933
159
+ """
160
+ for unit in ['', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi']:
161
+ if abs(num) < 1024.0:
162
+ return "%3.1f%s%s" % (num, unit, suffix)
163
+ num /= 1024.0
164
+ return "%.1f%s%s" % (num, 'Yi', suffix)
165
+
166
+
167
+ def human_seconds(seconds, display='.2f'):
168
+ """
169
+ Given `seconds` seconds, return human readable duration.
170
+ """
171
+ value = seconds * 1e6
172
+ ratios = [1e3, 1e3, 60, 60, 24]
173
+ names = ['us', 'ms', 's', 'min', 'hrs', 'days']
174
+ last = names.pop(0)
175
+ for name, ratio in zip(names, ratios):
176
+ if value / ratio < 0.3:
177
+ break
178
+ value /= ratio
179
+ last = name
180
+ return f"{format(value, display)} {last}"
181
+
182
+
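And a quick sanity check of the duration formatting (a sketch, not part of the commit; it assumes `human_seconds` is in scope):

```python
print(human_seconds(0.0002))    # '200.00 us'
print(human_seconds(90))        # '1.50 min'
print(human_seconds(3600 * 5))  # '5.00 hrs'
```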
183
+ class TensorChunk:
184
+ def __init__(self, tensor, offset=0, length=None):
185
+ total_length = tensor.shape[-1]
186
+ assert offset >= 0
187
+ assert offset < total_length
188
+
189
+ if length is None:
190
+ length = total_length - offset
191
+ else:
192
+ length = min(total_length - offset, length)
193
+
194
+ self.tensor = tensor
195
+ self.offset = offset
196
+ self.length = length
197
+ self.device = tensor.device
198
+
199
+ @property
200
+ def shape(self):
201
+ shape = list(self.tensor.shape)
202
+ shape[-1] = self.length
203
+ return shape
204
+
205
+ def padded(self, target_length):
206
+ delta = target_length - self.length
207
+ total_length = self.tensor.shape[-1]
208
+ assert delta >= 0
209
+
210
+ start = self.offset - delta // 2
211
+ end = start + target_length
212
+
213
+ correct_start = max(0, start)
214
+ correct_end = min(total_length, end)
215
+
216
+ pad_left = correct_start - start
217
+ pad_right = end - correct_end
218
+
219
+ out = F.pad(self.tensor[..., correct_start:correct_end], (pad_left, pad_right))
220
+ assert out.shape[-1] == target_length
221
+ return out
222
+
223
+
224
+ def tensor_chunk(tensor_or_chunk):
225
+ if isinstance(tensor_or_chunk, TensorChunk):
226
+ return tensor_or_chunk
227
+ else:
228
+ assert isinstance(tensor_or_chunk, th.Tensor)
229
+ return TensorChunk(tensor_or_chunk)
230
+
231
+
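To make the chunking and padding behavior concrete, a small sketch (assuming `TensorChunk` from this module is in scope):

```python
import torch as th

mix = th.arange(10).float().view(1, 10)        # (channels=1, length=10)
chunk = TensorChunk(mix, offset=8, length=2)   # the last two samples
print(chunk.shape)                             # [1, 2]
print(chunk.padded(6))                         # tensor([[6., 7., 8., 9., 0., 0.]])
# padded() centers the chunk in the requested window and zero-pads whatever
# falls outside the underlying tensor (here, the right edge).
```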
232
+ def apply_model_v1(model, mix, shifts=None, split=False, progress=False, set_progress_bar=None):
233
+ """
234
+ Apply model to a given mixture.
235
+
236
+ Args:
237
+ shifts (int): if > 0, will shift in time `mix` by a random amount between 0 and 0.5 sec
238
+ and apply the opposite shift to the output. This is repeated `shifts` times and
239
+ all predictions are averaged. This effectively makes the model time equivariant
240
+ and improves SDR by up to 0.2 points.
241
+ split (bool): if True, the input will be broken down into 8-second extracts
242
+ and predictions will be performed individually on each and concatenated.
243
+ Useful for models with a large memory footprint like Tasnet.
244
+ progress (bool): if True, show a progress bar (requires split=True)
245
+ """
246
+
247
+ channels, length = mix.size()
248
+ device = mix.device
249
+ progress_value = 0
250
+
251
+ if split:
252
+ out = th.zeros(4, channels, length, device=device)
253
+ shift = model.samplerate * 10
254
+ offsets = range(0, length, shift)
255
+ scale = 10
256
+ if progress:
257
+ offsets = tqdm.tqdm(offsets, unit_scale=scale, ncols=120, unit='seconds')
258
+ for offset in offsets:
259
+ chunk = mix[..., offset:offset + shift]
260
+ if set_progress_bar:
261
+ progress_value += 1
262
+ set_progress_bar(0.1, (0.8/len(offsets)*progress_value))
263
+ chunk_out = apply_model_v1(model, chunk, shifts=shifts, set_progress_bar=set_progress_bar)
264
+ else:
265
+ chunk_out = apply_model_v1(model, chunk, shifts=shifts)
266
+ out[..., offset:offset + shift] = chunk_out
267
+ offset += shift
268
+ return out
269
+ elif shifts:
270
+ max_shift = int(model.samplerate / 2)
271
+ mix = F.pad(mix, (max_shift, max_shift))
272
+ offsets = list(range(max_shift))
273
+ random.shuffle(offsets)
274
+ out = 0
275
+ for offset in offsets[:shifts]:
276
+ shifted = mix[..., offset:offset + length + max_shift]
277
+ if set_progress_bar:
278
+ shifted_out = apply_model_v1(model, shifted, set_progress_bar=set_progress_bar)
279
+ else:
280
+ shifted_out = apply_model_v1(model, shifted)
281
+ out += shifted_out[..., max_shift - offset:max_shift - offset + length]
282
+ out /= shifts
283
+ return out
284
+ else:
285
+ valid_length = model.valid_length(length)
286
+ delta = valid_length - length
287
+ padded = F.pad(mix, (delta // 2, delta - delta // 2))
288
+ with th.no_grad():
289
+ out = model(padded.unsqueeze(0))[0]
290
+ return center_trim(out, mix)
291
+
292
+ def apply_model_v2(model, mix, shifts=None, split=False,
293
+ overlap=0.25, transition_power=1., progress=False, set_progress_bar=None):
294
+ """
295
+ Apply model to a given mixture.
296
+
297
+ Args:
298
+ shifts (int): if > 0, will shift in time `mix` by a random amount between 0 and 0.5 sec
299
+ and apply the opposite shift to the output. This is repeated `shifts` times and
300
+ all predictions are averaged. This effectively makes the model time equivariant
301
+ and improves SDR by up to 0.2 points.
302
+ split (bool): if True, the input will be broken down into 8-second extracts
303
+ and predictions will be performed individually on each and concatenated.
304
+ Useful for models with a large memory footprint like Tasnet.
305
+ progress (bool): if True, show a progress bar (requires split=True)
306
+ """
307
+
308
+ assert transition_power >= 1, "transition_power < 1 leads to weird behavior."
309
+ device = mix.device
310
+ channels, length = mix.shape
311
+ progress_value = 0
312
+
313
+ if split:
314
+ out = th.zeros(len(model.sources), channels, length, device=device)
315
+ sum_weight = th.zeros(length, device=device)
316
+ segment = model.segment_length
317
+ stride = int((1 - overlap) * segment)
318
+ offsets = range(0, length, stride)
319
+ scale = stride / model.samplerate
320
+ if progress:
321
+ offsets = tqdm.tqdm(offsets, unit_scale=scale, ncols=120, unit='seconds')
322
+ # We start from a triangle shaped weight, with maximal weight in the middle
323
+ # of the segment. Then we normalize and take to the power `transition_power`.
324
+ # Large values of transition power will lead to sharper transitions.
325
+ weight = th.cat([th.arange(1, segment // 2 + 1),
326
+ th.arange(segment - segment // 2, 0, -1)]).to(device)
327
+ assert len(weight) == segment
328
+ # If the overlap < 50%, this will translate to linear transition when
329
+ # transition_power is 1.
330
+ weight = (weight / weight.max())**transition_power
331
+ for offset in offsets:
332
+ chunk = TensorChunk(mix, offset, segment)
333
+ if set_progress_bar:
334
+ progress_value += 1
335
+ set_progress_bar(0.1, (0.8/len(offsets)*progress_value))
336
+ chunk_out = apply_model_v2(model, chunk, shifts=shifts, set_progress_bar=set_progress_bar)
337
+ else:
338
+ chunk_out = apply_model_v2(model, chunk, shifts=shifts)
339
+ chunk_length = chunk_out.shape[-1]
340
+ out[..., offset:offset + segment] += weight[:chunk_length] * chunk_out
341
+ sum_weight[offset:offset + segment] += weight[:chunk_length]
342
+ offset += segment
343
+ assert sum_weight.min() > 0
344
+ out /= sum_weight
345
+ return out
346
+ elif shifts:
347
+ max_shift = int(0.5 * model.samplerate)
348
+ mix = tensor_chunk(mix)
349
+ padded_mix = mix.padded(length + 2 * max_shift)
350
+ out = 0
351
+ for _ in range(shifts):
352
+ offset = random.randint(0, max_shift)
353
+ shifted = TensorChunk(padded_mix, offset, length + max_shift - offset)
354
+
355
+ if set_progress_bar:
356
+ progress_value += 1
357
+ shifted_out = apply_model_v2(model, shifted, set_progress_bar=set_progress_bar)
358
+ else:
359
+ shifted_out = apply_model_v2(model, shifted)
360
+ out += shifted_out[..., max_shift - offset:]
361
+ out /= shifts
362
+ return out
363
+ else:
364
+ valid_length = model.valid_length(length)
365
+ mix = tensor_chunk(mix)
366
+ padded_mix = mix.padded(valid_length)
367
+ with th.no_grad():
368
+ out = model(padded_mix.unsqueeze(0))[0]
369
+ return center_trim(out, length)
370
+
371
+
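For intuition on the overlap-add weighting used in the `split` branch above, here is a simplified standalone sketch (it drops the max-normalization and `transition_power`, which don't change the idea, and pretends the model output is all ones):

```python
import torch as th

segment, length, overlap = 8, 16, 0.25
stride = int((1 - overlap) * segment)                                # 6
weight = th.cat([th.arange(1, segment // 2 + 1),
                 th.arange(segment - segment // 2, 0, -1)]).float()  # [1,2,3,4,4,3,2,1]
out = th.zeros(length)
sum_weight = th.zeros(length)
for offset in range(0, length, stride):
    chunk_len = min(segment, length - offset)
    out[offset:offset + chunk_len] += weight[:chunk_len] * 1.0       # "model output"
    sum_weight[offset:offset + chunk_len] += weight[:chunk_len]
out /= sum_weight
assert th.allclose(out, th.ones(length))  # overlapping chunks blend back to the original scale
```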
372
+ @contextmanager
373
+ def temp_filenames(count, delete=True):
374
+ names = []
375
+ try:
376
+ for _ in range(count):
377
+ names.append(tempfile.NamedTemporaryFile(delete=False).name)
378
+ yield names
379
+ finally:
380
+ if delete:
381
+ for name in names:
382
+ os.unlink(name)
383
+
384
+
385
+ def get_quantizer(model, args, optimizer=None):
386
+ quantizer = None
387
+ if args.diffq:
388
+ quantizer = DiffQuantizer(
389
+ model, min_size=args.q_min_size, group_size=8)
390
+ if optimizer is not None:
391
+ quantizer.setup_optimizer(optimizer)
392
+ elif args.qat:
393
+ quantizer = UniformQuantizer(
394
+ model, bits=args.qat, min_size=args.q_min_size)
395
+ return quantizer
396
+
397
+
398
+ def load_model(path, strict=False):
399
+ with warnings.catch_warnings():
400
+ warnings.simplefilter("ignore")
401
+ load_from = path
402
+ package = th.load(load_from, 'cpu')
403
+
404
+ klass = package["klass"]
405
+ args = package["args"]
406
+ kwargs = package["kwargs"]
407
+
408
+ if strict:
409
+ model = klass(*args, **kwargs)
410
+ else:
411
+ sig = inspect.signature(klass)
412
+ for key in list(kwargs):
413
+ if key not in sig.parameters:
414
+ warnings.warn("Dropping nonexistent parameter " + key)
415
+ del kwargs[key]
416
+ model = klass(*args, **kwargs)
417
+
418
+ state = package["state"]
419
+ training_args = package["training_args"]
420
+ quantizer = get_quantizer(model, training_args)
421
+
422
+ set_state(model, quantizer, state)
423
+ return model
424
+
425
+
426
+ def get_state(model, quantizer):
427
+ if quantizer is None:
428
+ state = {k: p.data.to('cpu') for k, p in model.state_dict().items()}
429
+ else:
430
+ state = quantizer.get_quantized_state()
431
+ buf = io.BytesIO()
432
+ th.save(state, buf)
433
+ state = {'compressed': zlib.compress(buf.getvalue())}
434
+ return state
435
+
436
+
437
+ def set_state(model, quantizer, state):
438
+ if quantizer is None:
439
+ model.load_state_dict(state)
440
+ else:
441
+ buf = io.BytesIO(zlib.decompress(state["compressed"]))
442
+ state = th.load(buf, "cpu")
443
+ quantizer.restore_quantized_state(state)
444
+
445
+ return state
446
+
447
+
448
+ def save_state(state, path):
449
+ buf = io.BytesIO()
450
+ th.save(state, buf)
451
+ sig = hashlib.sha256(buf.getvalue()).hexdigest()[:8]
452
+
453
+ path = path.parent / (path.stem + "-" + sig + path.suffix)
454
+ path.write_bytes(buf.getvalue())
455
+
456
+
457
+ def save_model(model, quantizer, training_args, path):
458
+ args, kwargs = model._init_args_kwargs
459
+ klass = model.__class__
460
+
461
+ state = get_state(model, quantizer)
462
+
463
+ save_to = path
464
+ package = {
465
+ 'klass': klass,
466
+ 'args': args,
467
+ 'kwargs': kwargs,
468
+ 'state': state,
469
+ 'training_args': training_args,
470
+ }
471
+ th.save(package, save_to)
472
+
473
+
474
+ def capture_init(init):
475
+ @functools.wraps(init)
476
+ def __init__(self, *args, **kwargs):
477
+ self._init_args_kwargs = (args, kwargs)
478
+ init(self, *args, **kwargs)
479
+
480
+ return __init__
481
+
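The decorator above is what `save_model`/`load_model` rely on: it records the constructor arguments on the instance so the class can be re-instantiated from a checkpoint. A minimal sketch with a hypothetical module:

```python
import torch.nn as nn

class Dummy(nn.Module):          # hypothetical example class, not part of the commit
    @capture_init
    def __init__(self, channels=2, depth=4):
        super().__init__()
        self.proj = nn.Linear(channels, depth)

model = Dummy(channels=3)
print(model._init_args_kwargs)   # ((), {'channels': 3}) -- enough to rebuild Dummy later
```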
482
+ class DummyPoolExecutor:
483
+ class DummyResult:
484
+ def __init__(self, func, *args, **kwargs):
485
+ self.func = func
486
+ self.args = args
487
+ self.kwargs = kwargs
488
+
489
+ def result(self):
490
+ return self.func(*self.args, **self.kwargs)
491
+
492
+ def __init__(self, workers=0):
493
+ pass
494
+
495
+ def submit(self, func, *args, **kwargs):
496
+ return DummyPoolExecutor.DummyResult(func, *args, **kwargs)
497
+
498
+ def __enter__(self):
499
+ return self
500
+
501
+ def __exit__(self, exc_type, exc_value, exc_tb):
502
+ return
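`DummyPoolExecutor` mimics a small slice of the `concurrent.futures` executor interface (submit, result, context manager) while running everything synchronously, so it can stand in where no real worker pool is wanted. A usage sketch:

```python
with DummyPoolExecutor() as pool:
    future = pool.submit(pow, 2, 10)
    print(future.result())   # 1024 -- the call only happens when .result() is invoked
```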
gui_data/__pycache__/app_size_values.cpython-310.pyc ADDED
Binary file (8.52 kB).
 
gui_data/__pycache__/constants.cpython-310.pyc ADDED
Binary file (63.6 kB).
 
gui_data/__pycache__/error_handling.cpython-310.pyc ADDED
Binary file (5.67 kB).
 
gui_data/__pycache__/old_data_check.cpython-310.pyc ADDED
Binary file (1.01 kB).
 
gui_data/app_size_values.py ADDED
@@ -0,0 +1,371 @@
1
+ import os
2
+ import platform
3
+ from screeninfo import get_monitors
4
+ from PIL import Image
5
+ from PIL import ImageTk
6
+
7
+ OPERATING_SYSTEM = platform.system()
8
+
9
+ def get_screen_height():
10
+ monitors = get_monitors()
11
+ if len(monitors) == 0:
12
+ raise Exception("Failed to get screen height")
13
+ return monitors[0].height, monitors[0].width
14
+
15
+ def scale_values(value):
16
+ if not SCALE_WIN_SIZE == 1920:
17
+ ratio = SCALE_WIN_SIZE/1920 # Approx. 1.3333 for 2K
18
+ return value * ratio
19
+ else:
20
+ return value
21
+
22
+ SCREEN_HIGHT, SCREEN_WIDTH = get_screen_height()
23
+ SCALE_WIN_SIZE = 1920
24
+
25
+ SCREEN_SIZE_VALUES = {
26
+ "normal": {
27
+ "credits_img":(100, 100),
28
+ ## App Size
29
+ 'IMAGE_HEIGHT': 140,
30
+ 'FILEPATHS_HEIGHT': 75,
31
+ 'OPTIONS_HEIGHT': 262,
32
+ 'CONVERSIONBUTTON_HEIGHT': 30,
33
+ 'COMMAND_HEIGHT': 141,
34
+ 'PROGRESS_HEIGHT': 25,
35
+ 'PADDING': 7,
36
+ 'WIDTH': 680
37
+ },
38
+ "small": {
39
+ "credits_img":(50, 50),
40
+ 'IMAGE_HEIGHT': 140,
41
+ 'FILEPATHS_HEIGHT': 75,
42
+ 'OPTIONS_HEIGHT': 262,
43
+ 'CONVERSIONBUTTON_HEIGHT': 30,
44
+ 'COMMAND_HEIGHT': 80,
45
+ 'PROGRESS_HEIGHT': 25,
46
+ 'PADDING': 5,
47
+ 'WIDTH': 680
48
+ },
49
+ "medium": {
50
+ "credits_img":(50, 50),
51
+ ## App Size
52
+ 'IMAGE_HEIGHT': 140,
53
+ 'FILEPATHS_HEIGHT': 75,
54
+ 'OPTIONS_HEIGHT': 262,
55
+ 'CONVERSIONBUTTON_HEIGHT': 30,
56
+ 'COMMAND_HEIGHT': 115,
57
+ 'PROGRESS_HEIGHT': 25,
58
+ 'PADDING': 7,
59
+ 'WIDTH': 680
60
+ },
61
+ }
62
+
63
+ try:
64
+ if SCREEN_HIGHT >= 900:
65
+ determined_size = SCREEN_SIZE_VALUES["normal"]
66
+ elif SCREEN_HIGHT <= 720:
67
+ determined_size = SCREEN_SIZE_VALUES["small"]
68
+ else:
69
+ determined_size = SCREEN_SIZE_VALUES["medium"]
70
+ except:
71
+ determined_size = SCREEN_SIZE_VALUES["normal"]
72
+
73
+ image_scale_1, image_scale_2 = 20, 30
74
+
75
+ class ImagePath():
76
+ def __init__(self, base_path):
77
+ img_path = os.path.join(base_path, 'gui_data', 'img')
78
+ credits_path = os.path.join(img_path, 'credits.png')
79
+ donate_path = os.path.join(img_path, 'donate.png')
80
+ download_path = os.path.join(img_path, 'download.png')
81
+ efile_path = os.path.join(img_path, 'File.png')
82
+ help_path = os.path.join(img_path, 'help.png')
83
+ key_path = os.path.join(img_path, 'key.png')
84
+ stop_path = os.path.join(img_path, 'stop.png')
85
+ play_path = os.path.join(img_path, 'play.png')
86
+ pause_path = os.path.join(img_path, 'pause.png')
87
+ up_img_path = os.path.join(img_path, "up.png")
88
+ down_img_path = os.path.join(img_path, "down.png")
89
+ left_img_path = os.path.join(img_path, "left.png")
90
+ right_img_path = os.path.join(img_path, "right.png")
91
+ clear_img_path = os.path.join(img_path, "clear.png")
92
+ copy_img_path = os.path.join(img_path, "copy.png")
93
+ self.banner_path = os.path.join(img_path, 'UVR-banner.png')
94
+
95
+ self.efile_img = self.open_image(path=efile_path,size=(image_scale_1, image_scale_1))
96
+ self.stop_img = self.open_image(path=stop_path, size=(image_scale_1, image_scale_1))
97
+ self.play_img = self.open_image(path=play_path, size=(image_scale_1, image_scale_1))
98
+ self.pause_img = self.open_image(path=pause_path, size=(image_scale_1, image_scale_1))
99
+ self.help_img = self.open_image(path=help_path, size=(image_scale_1, image_scale_1))
100
+ self.download_img = self.open_image(path=download_path, size=(image_scale_2, image_scale_2))
101
+ self.donate_img = self.open_image(path=donate_path, size=(image_scale_2, image_scale_2))
102
+ self.key_img = self.open_image(path=key_path, size=(image_scale_2, image_scale_2))
103
+ self.up_img = self.open_image(path=up_img_path, size=(image_scale_2, image_scale_2))
104
+ self.down_img = self.open_image(path=down_img_path, size=(image_scale_2, image_scale_2))
105
+ self.left_img = self.open_image(path=left_img_path, size=(image_scale_2, image_scale_2))
106
+ self.right_img = self.open_image(path=right_img_path, size=(image_scale_2, image_scale_2))
107
+ self.clear_img = self.open_image(path=clear_img_path, size=(image_scale_2, image_scale_2))
108
+ self.copy_img = self.open_image(path=copy_img_path, size=(image_scale_2, image_scale_2))
109
+ self.credits_img = self.open_image(path=credits_path, size=determined_size["credits_img"])
110
+
111
+ def open_image(self, path: str, size: tuple = None, keep_aspect: bool = True, rotate: int = 0) -> ImageTk.PhotoImage:
112
+ """
113
+ Open the image on the path and apply given settings\n
114
+ Parameters:
115
+ path(str):
116
+ Absolute path of the image
117
+ size(tuple):
118
+ first value - width
119
+ second value - height
120
+ keep_aspect(bool):
121
+ keep aspect ratio of image and resize
122
+ to maximum possible width and height
123
+ (maxima are given by size)
124
+ rotate(int):
125
+ clockwise rotation of image
126
+ Returns(ImageTk.PhotoImage):
127
+ Image of path
128
+ """
129
+ img = Image.open(path).convert(mode='RGBA')
130
+ ratio = img.height/img.width
131
+ img = img.rotate(angle=-rotate)
132
+ if size is not None:
133
+ size = (int(size[0]), int(size[1]))
134
+ if keep_aspect:
135
+ img = img.resize((size[0], int(size[0] * ratio)), Image.ANTIALIAS)
136
+ else:
137
+ img = img.resize(size, Image.ANTIALIAS)
138
+
139
+ return ImageTk.PhotoImage(img)
140
+
141
+ #All Sizes Below Calibrated to 1080p!
142
+
143
+ if OPERATING_SYSTEM=="Darwin":
144
+ FONT_SIZE_F1 = 13
145
+ FONT_SIZE_F2 = 11
146
+ FONT_SIZE_F3 = 12
147
+ FONT_SIZE_0 = 9
148
+ FONT_SIZE_1 = 11
149
+ FONT_SIZE_2 = 12
150
+ FONT_SIZE_3 = 13
151
+ FONT_SIZE_4 = 14
152
+ FONT_SIZE_5 = 15
153
+ FONT_SIZE_6 = 17
154
+ HELP_HINT_CHECKBOX_WIDTH = 13
155
+ MDX_CHECKBOXS_WIDTH = 14
156
+ VR_CHECKBOXS_WIDTH = 14
157
+ ENSEMBLE_CHECKBOXS_WIDTH = 18
158
+ DEMUCS_CHECKBOXS_WIDTH = 14
159
+ DEMUCS_PRE_CHECKBOXS_WIDTH = 20
160
+ GEN_SETTINGS_WIDTH = 17
161
+ MENU_COMBOBOX_WIDTH = 16
162
+ MENU_OPTION_WIDTH = 12
163
+ READ_ONLY_COMBO_WIDTH = 35
164
+ SETTINGS_BUT_WIDTH = 19
165
+ VR_BUT_WIDTH = 16
166
+ SET_MENUS_CHECK_WIDTH = 12
167
+ COMBO_WIDTH = 14
168
+ SET_VOC_SPLIT_CHECK_WIDTH = 21
169
+ elif OPERATING_SYSTEM=="Linux":
170
+ HELP_HINT_CHECKBOX_WIDTH = 15
171
+ MDX_CHECKBOXS_WIDTH = 16
172
+ VR_CHECKBOXS_WIDTH = 16
173
+ ENSEMBLE_CHECKBOXS_WIDTH = 20
174
+ DEMUCS_CHECKBOXS_WIDTH = 16
175
+ DEMUCS_PRE_CHECKBOXS_WIDTH = 24
176
+ GEN_SETTINGS_WIDTH = 20
177
+ MENU_COMBOBOX_WIDTH = 18
178
+ MENU_OPTION_WIDTH = 12
179
+ READ_ONLY_COMBO_WIDTH = 40
180
+ SETTINGS_BUT_WIDTH = 23
181
+ VR_BUT_WIDTH = 18
182
+ SET_MENUS_CHECK_WIDTH = 13
183
+ COMBO_WIDTH = 16
184
+ SET_VOC_SPLIT_CHECK_WIDTH = 25
185
+ FONT_SIZE_F1 = 10
186
+ FONT_SIZE_F2 = 8
187
+ FONT_SIZE_F3 = 9
188
+ FONT_SIZE_0 = 7
189
+ FONT_SIZE_1 = 8
190
+ FONT_SIZE_2 = 9
191
+ FONT_SIZE_3 = 10
192
+ FONT_SIZE_4 = 11
193
+ FONT_SIZE_5 = 13
194
+ FONT_SIZE_6 = 15
195
+ elif OPERATING_SYSTEM=="Windows":
196
+ HELP_HINT_CHECKBOX_WIDTH = 15
197
+ MDX_CHECKBOXS_WIDTH = 14
198
+ VR_CHECKBOXS_WIDTH = 14
199
+ ENSEMBLE_CHECKBOXS_WIDTH = 20
200
+ DEMUCS_CHECKBOXS_WIDTH = 14
201
+ DEMUCS_PRE_CHECKBOXS_WIDTH = 20
202
+ GEN_SETTINGS_WIDTH = 18
203
+ MENU_COMBOBOX_WIDTH = 16
204
+ MENU_OPTION_WIDTH = 12
205
+ READ_ONLY_COMBO_WIDTH = 35
206
+ SETTINGS_BUT_WIDTH = 20
207
+ VR_BUT_WIDTH = 16
208
+ SET_MENUS_CHECK_WIDTH = 13
209
+ COMBO_WIDTH = 14
210
+ SET_VOC_SPLIT_CHECK_WIDTH = 23
211
+ FONT_SIZE_F1 = 10
212
+ FONT_SIZE_F2 = 8
213
+ FONT_SIZE_F3 = 9
214
+ FONT_SIZE_0 = 7
215
+ FONT_SIZE_1 = 8
216
+ FONT_SIZE_2 = 9
217
+ FONT_SIZE_3 = 10
218
+ FONT_SIZE_4 = 11
219
+ FONT_SIZE_5 = 13
220
+ FONT_SIZE_6 = 15
221
+
222
+ #Main Size Values:
223
+ IMAGE_HEIGHT = determined_size["IMAGE_HEIGHT"]
224
+ FILEPATHS_HEIGHT = determined_size["FILEPATHS_HEIGHT"]
225
+ OPTIONS_HEIGHT = determined_size["OPTIONS_HEIGHT"]
226
+ CONVERSIONBUTTON_HEIGHT = determined_size["CONVERSIONBUTTON_HEIGHT"]
227
+ COMMAND_HEIGHT = determined_size["COMMAND_HEIGHT"]
228
+ PROGRESS_HEIGHT = determined_size["PROGRESS_HEIGHT"]
229
+ PADDING = determined_size["PADDING"]
230
+ WIDTH = determined_size["WIDTH"]
231
+
232
+ # IMAGE_HEIGHT = 140
233
+ # FILEPATHS_HEIGHT = 75
234
+ # OPTIONS_HEIGHT = 262
235
+ # CONVERSIONBUTTON_HEIGHT = 30
236
+ # COMMAND_HEIGHT = 141
237
+ # PROGRESS_HEIGHT = 25
238
+ # PADDING = 7
239
+ # WIDTH = 680
240
+
241
+ MENU_PADDING_1 = 3
242
+ MENU_PADDING_2 = 10
243
+ MENU_PADDING_3 = 15
244
+ MENU_PADDING_4 = 3
245
+
246
+ #Main Frame Sizes
247
+ X_CONVERSION_BUTTON_1080P = 50
248
+ WIDTH_CONVERSION_BUTTON_1080P = -100
249
+ HEIGHT_GENERIC_BUTTON_1080P = 35
250
+ X_STOP_BUTTON_1080P = -10 - 35
251
+ X_SETTINGS_BUTTON_1080P = -670
252
+ X_PROGRESSBAR_1080P = 25
253
+ WIDTH_PROGRESSBAR_1080P = -50
254
+ X_CONSOLE_FRAME_1080P = 15
255
+ WIDTH_CONSOLE_FRAME_1080P = -30
256
+ HO_S = 7
257
+
258
+ #File Frame Sizes
259
+ FILEPATHS_FRAME_X = 10
260
+ FILEPATHS_FRAME_Y = 155
261
+ FILEPATHS_FRAME_WIDTH = -20
262
+ MUSICFILE_BUTTON_X = 0
263
+ MUSICFILE_BUTTON_Y = 5
264
+ MUSICFILE_BUTTON_WIDTH = 0
265
+ MUSICFILE_BUTTON_HEIGHT = -5
266
+ MUSICFILE_ENTRY_X = 7.5
267
+ MUSICFILE_ENTRY_WIDTH = -50
268
+ MUSICFILE_ENTRY_HEIGHT = -5
269
+ MUSICFILE_OPEN_X = -45
270
+ MUSICFILE_OPEN_Y = 160
271
+ MUSICFILE_OPEN_WIDTH = 35
272
+ MUSICFILE_OPEN_HEIGHT = 33
273
+ SAVETO_BUTTON_X = 0
274
+ SAVETO_BUTTON_Y = 5
275
+ SAVETO_BUTTON_WIDTH = 0
276
+ SAVETO_BUTTON_HEIGHT = -5
277
+ SAVETO_ENTRY_X = 7.5
278
+ OPEN_BUTTON_X = 427.1
279
+ OPEN_BUTTON_WIDTH = -427.4
280
+ SAVETO_ENTRY_WIDTH = -50
281
+ SAVETO_ENTRY_HEIGHT = -5
282
+ SAVETO_OPEN_X = -45
283
+ SAVETO_OPEN_Y = 197.5
284
+ SAVETO_OPEN_WIDTH = 35
285
+ SAVETO_OPEN_HEIGHT = 32
286
+
287
+ #Main Option menu
288
+ OPTIONS_FRAME_X = 10
289
+ OPTIONS_FRAME_Y = 250
290
+ OPTIONS_FRAME_WIDTH = -20
291
+ FILEONE_LABEL_X = -28
292
+ FILEONE_LABEL_WIDTH = -38
293
+ FILETWO_LABEL_X = -32
294
+ FILETWO_LABEL_WIDTH = -20
295
+ TIME_WINDOW_LABEL_X = -43
296
+ TIME_WINDOW_LABEL_WIDTH = 0
297
+ INTRO_ANALYSIS_LABEL_X = -83
298
+ INTRO_ANALYSIS_LABEL_WIDTH = -50
299
+ INTRO_ANALYSIS_OPTION_X = -68
300
+ DB_ANALYSIS_LABEL_X = 62
301
+ DB_ANALYSIS_LABEL_WIDTH = -34
302
+ DB_ANALYSIS_OPTION_X = 86
303
+ WAV_TYPE_SET_LABEL_X = -43
304
+ WAV_TYPE_SET_LABEL_WIDTH = 0
305
+ ENTRY_WIDTH = 222
306
+
307
+ # Constants for the ensemble_listbox_Frame
308
+ ENSEMBLE_LISTBOX_FRAME_X = -25
309
+ ENSEMBLE_LISTBOX_FRAME_Y = -20
310
+ ENSEMBLE_LISTBOX_FRAME_WIDTH = 0
311
+ ENSEMBLE_LISTBOX_FRAME_HEIGHT = 67
312
+
313
+ # Constants for the ensemble_listbox_scroll
314
+ ENSEMBLE_LISTBOX_SCROLL_X = 195
315
+ ENSEMBLE_LISTBOX_SCROLL_Y = -20
316
+ ENSEMBLE_LISTBOX_SCROLL_WIDTH = -48
317
+ ENSEMBLE_LISTBOX_SCROLL_HEIGHT = 69
318
+
319
+ # Constants for Radio Buttons
320
+ RADIOBUTTON_X_WAV = 457
321
+ RADIOBUTTON_X_FLAC = 300
322
+ RADIOBUTTON_X_MP3 = 143
323
+ RADIOBUTTON_Y = -5
324
+ RADIOBUTTON_WIDTH = 0
325
+ RADIOBUTTON_HEIGHT = 6
326
+ MAIN_ROW_Y_1 = -15
327
+ MAIN_ROW_Y_2 = -17
328
+ MAIN_ROW_X_1 = -4
329
+ MAIN_ROW_X_2 = 21
330
+ MAIN_ROW_2_Y_1 = -15
331
+ MAIN_ROW_2_Y_2 = -17
332
+ MAIN_ROW_2_X_1 = -28
333
+ MAIN_ROW_2_X_2 = 1
334
+ LOW_MENU_Y_1 = 18
335
+ LOW_MENU_Y_2 = 16
336
+ SUB_ENT_ROW_X = -2
337
+ MAIN_ROW_WIDTH = -53
338
+ MAIN_ROW_ALIGN_WIDTH = -86
339
+ CHECK_BOX_Y = 0
340
+ CHECK_BOX_X = 20
341
+ CHECK_BOX_WIDTH = -49
342
+ CHECK_BOX_HEIGHT = 2
343
+ LEFT_ROW_WIDTH = -10
344
+ LABEL_HEIGHT = -5
345
+ OPTION_HEIGHT = 8
346
+ LABEL_X_OFFSET = -28
347
+ LABEL_WIDTH = -38
348
+ ENTRY_WIDTH = 179.5
349
+ ENTRY_OPEN_BUTT_WIDTH = -185
350
+ ENTRY_OPEN_BUTT_X_OFF = 405
351
+ UPDATE_LABEL_WIDTH = 35 if OPERATING_SYSTEM == 'Linux' else 32
352
+
353
+ HEIGHT_CONSOLE_FRAME_1080P = COMMAND_HEIGHT + HO_S
354
+ LOW_MENU_Y = LOW_MENU_Y_1, LOW_MENU_Y_2
355
+ MAIN_ROW_Y = MAIN_ROW_Y_1, MAIN_ROW_Y_2
356
+ MAIN_ROW_X = MAIN_ROW_X_1, MAIN_ROW_X_2
357
+ MAIN_ROW_2_Y = MAIN_ROW_2_Y_1, MAIN_ROW_2_Y_2
358
+ MAIN_ROW_2_X = MAIN_ROW_2_X_1, MAIN_ROW_2_X_2
359
+
360
+ LABEL_Y = MAIN_ROW_Y[0]
361
+ ENTRY_Y = MAIN_ROW_Y[1]
362
+
363
+ BUTTON_Y_1080P = IMAGE_HEIGHT + FILEPATHS_HEIGHT + OPTIONS_HEIGHT - 8 + PADDING*2
364
+ HEIGHT_PROGRESSBAR_1080P = PROGRESS_HEIGHT
365
+ Y_OFFSET_PROGRESS_BAR_1080P = IMAGE_HEIGHT + FILEPATHS_HEIGHT + OPTIONS_HEIGHT + CONVERSIONBUTTON_HEIGHT + COMMAND_HEIGHT + PADDING*4
366
+ Y_OFFSET_CONSOLE_FRAME_1080P = IMAGE_HEIGHT + FILEPATHS_HEIGHT + OPTIONS_HEIGHT + CONVERSIONBUTTON_HEIGHT + PADDING + X_PROGRESSBAR_1080P
367
+
368
+ LABEL_Y_OFFSET = MAIN_ROW_Y[0]
369
+ ENTRY_X_OFFSET = SUB_ENT_ROW_X
370
+ ENTRY_Y_OFFSET = MAIN_ROW_Y[1]
371
+ OPTION_WIDTH = MAIN_ROW_ALIGN_WIDTH
gui_data/change_log.txt ADDED
@@ -0,0 +1,93 @@
1
+ Most Recent Changes:
2
+
3
+ ~ Fixed Download Center model list issue.
4
+ ~ Fixed audio clip in ensemble mode.
5
+ ~ Fixed output model name issue in ensemble mode.
6
+ ~ Added "Batch Mode" for MDX-Net to increase performance.
7
+ ~ Batch Mode is more memory efficient.
8
+ ~ Batch Mode produces the best output, regardless of batch size.
9
+ ~ Added Batch Mode for VR Architecture.
10
+ ~ Added Mixer Mode for Demucs.
11
+ ~ This option may improve separation for some 4-stem models.
12
+
13
+ Fixes & Changes going from UVR v5.4 to v5.5:
14
+
15
+ ~ The progress bar is now fully synced up with every process in the application.
16
+ ~ Fixed low-resolution icon.
17
+ ~ Added the ability to download models manually if the application can't connect
18
+ to the internet.
19
+ ~ Drag-n-drop is functional across all OS platforms.
20
+ ~ Resolved mp3 tag issue in MacOS version.
21
+
22
+ Performance:
23
+
24
+ ~ Model load times are faster.
25
+ ~ Importing/exporting audio files is faster.
26
+
27
+ MacOS M1 Notes:
28
+
29
+ ~ The GPU Conversion checkbox will enable MPS for GPU acceleration. However,
30
+ only the VR Architecture models are currently compatible with it.
31
+
32
+ New Options:
33
+
34
+ ~ Select Saved Settings option - Allows the user to save the current settings
35
+ of the whole application. You can also load a saved setting or reset them to
36
+ the default.
37
+ ~ Right-click menu - Allows for quick access to important options.
38
+ ~ Help Hints option - When enabled, users can hover over options to see a pop-up
39
+ text that describes that option. The right-clicking option also allows copying
40
+ the "Help Hint" text.
41
+ ~ Secondary Model Mode - This option is an expanded version of the "Demucs Model"
42
+ option that was only available to MDX-Net. Except now, this option is available
43
+ in all three AI Networks and for any stem. Any model can now be Secondary, and
44
+ the user can choose the amount of influence it has on the final result.
45
+ ~ Robust caching for ensemble mode, allowing for much faster processing times.
46
+ ~ Clicking the "Input" field will pop up a window allowing the user to review the selected audio inputs. Within this menu, users can:
47
+ ~ Remove inputs.
48
+ ~ Verify inputs.
49
+ ~ Create samples of chosen inputs.
50
+ ~ "Sample Mode" option - Allows the user to process only part of a track to sample
51
+ settings or a model without running a full conversion.
52
+ ~ The number in the parentheses is the current number of seconds the generated
53
+ sample will be.
54
+ ~ You can choose the number of seconds to extract from the track in the "Additional
55
+ Settings" menu.
56
+
57
+ VR Architecture:
58
+
59
+ ~ Ability to toggle "High-End Processing."
60
+ ~ Ability to change the post-processing threshold.
61
+ ~ Support for the latest VR architecture.
62
+ ~ Crop Size and Batch Size are specifically for models using the latest
63
+ architecture only.
64
+
65
+ MDX-NET:
66
+
67
+ ~ Denoise Output option results in cleaner results,
68
+ but the processing time will be longer. This option has replaced Noise Reduction.
69
+ ~ Spectral Inversion option uses spectral inversion techniques for a
70
+ cleaner secondary stem result. This option may slow down the audio export process.
71
+ ~ Secondary stem now has the same frequency cut-off as the main stem.
72
+
73
+ Demucs:
74
+
75
+ ~ Demucs v4 models are now supported, including the 6-stem model.
76
+ ~ Ability to combine remaining stems instead of inverting selected stem with the
77
+ mixture only when a user does not select "All Stems".
78
+ ~ A Pre-process model that allows the user to run an inference through a robust
79
+ vocal or instrumental model and separate the remaining stems from its generated
80
+ instrumental mix. This option can significantly reduce vocal bleed in other
81
+ Demucs-generated non-vocal stems.
82
+ ~ The Pre-process model is intended for Demucs separations for all stems except
83
+ vocals and instrumentals.
84
+
85
+ Ensemble Mode:
86
+
87
+ ~ Ensemble Mode has been extended to include the following:
88
+ ~ Averaging is a new algorithm that averages the final results.
89
+ ~ Unlimited models in the ensemble.
90
+ ~ Ability to save different ensembles.
91
+ ~ Ability to ensemble outputs for all individual stem types.
92
+ ~ Ability to choose unique ensemble algorithms.
93
+ ~ Ability to ensemble all 4 Demucs stems at once.
gui_data/complete_chime.wav ADDED
Binary file (357 kB).
 
gui_data/constants.py ADDED
@@ -0,0 +1,1584 @@
1
+ import platform
2
+
3
+ #Platform Details
4
+ OPERATING_SYSTEM = platform.system()
5
+ SYSTEM_ARCH = platform.platform()
6
+ SYSTEM_PROC = platform.processor()
7
+ ARM = 'arm'
8
+
9
+ is_macos = False
10
+
11
+ CPU = 'cpu'
12
+ CUDA_DEVICE = 'cuda'
13
+ DIRECTML_DEVICE = "privateuseone"
14
+
15
+ #MAIN_FONT_NAME = "Century Gothic"
16
+ OPT_SEPARATOR_SAVE = '─'*25
17
+ BG_COLOR = '#0e0e0f'
18
+ FG_COLOR = '#13849f'
19
+
20
+ #Model Types
21
+ VR_ARCH_TYPE = 'VR Arc'
22
+ MDX_ARCH_TYPE = 'MDX-Net'
23
+ DEMUCS_ARCH_TYPE = 'Demucs'
24
+ VR_ARCH_PM = 'VR Architecture'
25
+ ENSEMBLE_MODE = 'Ensemble Mode'
26
+ ENSEMBLE_STEM_CHECK = 'Ensemble Stem'
27
+ SECONDARY_MODEL = 'Secondary Model'
28
+ DEMUCS_6_STEM_MODEL = 'htdemucs_6s'
29
+ DEFAULT = "Default"
30
+ ALIGNMENT_TOOL = 'Alignment Tool Options'
31
+
32
+ SINGLE_FILE = 'SINGLE_FILE'
33
+ MULTIPLE_FILE = 'MULTI_FILE'
34
+ MAIN_MULTIPLE_FILE = 'MAIN_MULTI_FILE'
35
+ CHOOSE_EXPORT_FIR = 'CHOOSE_EXPORT_FIR'
36
+
37
+ DUAL = "dual"
38
+ FOUR_STEM = "fourstem"
39
+ ANY_STEM = "Any Stem"
40
+
41
+ DEMUCS_V3_ARCH_TYPE = 'Demucs v3'
42
+ DEMUCS_V4_ARCH_TYPE = 'Demucs v4'
43
+ DEMUCS_NEWER_ARCH_TYPES = [DEMUCS_V3_ARCH_TYPE, DEMUCS_V4_ARCH_TYPE]
44
+
45
+ DEMUCS_V1 = 'v1'
46
+ DEMUCS_V2 = 'v2'
47
+ DEMUCS_V3 = 'v3'
48
+ DEMUCS_V4 = 'v4'
49
+
50
+ DEMUCS_V1_TAG = 'v1 | '
51
+ DEMUCS_V2_TAG = 'v2 | '
52
+ DEMUCS_V3_TAG = 'v3 | '
53
+ DEMUCS_V4_TAG = 'v4 | '
54
+ DEMUCS_NEWER_TAGS = [DEMUCS_V3_TAG, DEMUCS_V4_TAG]
55
+
56
+ DEMUCS_VERSION_MAPPER = {
57
+ DEMUCS_V1:DEMUCS_V1_TAG,
58
+ DEMUCS_V2:DEMUCS_V2_TAG,
59
+ DEMUCS_V3:DEMUCS_V3_TAG,
60
+ DEMUCS_V4:DEMUCS_V4_TAG}
61
+
62
+ #Download Center
63
+ DOWNLOAD_FAILED = 'Download Failed'
64
+ DOWNLOAD_STOPPED = 'Download Stopped'
65
+ DOWNLOAD_COMPLETE = 'Download Complete'
66
+ DOWNLOAD_UPDATE_COMPLETE = 'Update Download Complete'
67
+ SETTINGS_MENU_EXIT = 'exit'
68
+ NO_CONNECTION = 'No Internet Connection'
69
+ VIP_SELECTION = 'VIP:'
70
+ DEVELOPER_SELECTION = 'VIP:'
71
+ NO_NEW_MODELS = 'All Available Models Downloaded'
72
+ ENSEMBLE_PARTITION = ': '
73
+ NO_MODEL = 'No Model Selected'
74
+ CHOOSE_MODEL = 'Choose Model'
75
+ SINGLE_DOWNLOAD = 'Downloading Item 1/1...'
76
+ DOWNLOADING_ITEM = 'Downloading Item'
77
+ FILE_EXISTS = 'File already exists!'
78
+ DOWNLOADING_UPDATE = 'Downloading Update...'
79
+ DOWNLOAD_MORE = 'Download More Models'
80
+ IS_KARAOKEE = "is_karaoke"
81
+ IS_BV_MODEL = "is_bv_model"
82
+ IS_BV_MODEL_REBAL = "is_bv_model_rebalanced"
83
+ INPUT_STEM_NAME = 'Input Stem Name'
84
+
85
+ #Menu Options
86
+
87
+ AUTO_SELECT = 'Auto'
88
+
89
+ #LINKS
90
+ DOWNLOAD_CHECKS = "https://raw.githubusercontent.com/TRvlvr/application_data/main/filelists/download_checks.json"
91
+ MDX_MODEL_DATA_LINK = "https://raw.githubusercontent.com/TRvlvr/application_data/main/mdx_model_data/model_data_new.json"
92
+ VR_MODEL_DATA_LINK = "https://raw.githubusercontent.com/TRvlvr/application_data/main/vr_model_data/model_data_new.json"
93
+ MDX23_CONFIG_CHECKS = "https://raw.githubusercontent.com/TRvlvr/application_data/main/mdx_model_data/mdx_c_configs/"
94
+ BULLETIN_CHECK = "https://raw.githubusercontent.com/TRvlvr/application_data/main/bulletin.txt"
95
+
96
+ DEMUCS_MODEL_NAME_DATA_LINK = "https://raw.githubusercontent.com/TRvlvr/application_data/main/demucs_model_data/model_name_mapper.json"
97
+ MDX_MODEL_NAME_DATA_LINK = "https://raw.githubusercontent.com/TRvlvr/application_data/main/mdx_model_data/model_name_mapper.json"
98
+
99
+ DONATE_LINK_BMAC = "https://www.buymeacoffee.com/uvr5"
100
+ DONATE_LINK_PATREON = "https://www.patreon.com/uvr"
101
+
102
+ #DOWNLOAD REPOS
103
+ NORMAL_REPO = "https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/"
104
+ UPDATE_REPO = "https://github.com/TRvlvr/model_repo/releases/download/uvr_update_patches/"
105
+
106
+ UPDATE_MAC_ARM_REPO = "https://github.com/Anjok07/ultimatevocalremovergui/releases/download/v5.6/Ultimate_Vocal_Remover_v5_6_MacOS_arm64.dmg"
107
+ UPDATE_MAC_X86_64_REPO = "https://github.com/Anjok07/ultimatevocalremovergui/releases/download/v5.6/Ultimate_Vocal_Remover_v5_6_MacOS_x86_64.dmg"
108
+ UPDATE_LINUX_REPO = "https://github.com/Anjok07/ultimatevocalremovergui#linux-installation"
109
+
110
+ ISSUE_LINK = 'https://github.com/Anjok07/ultimatevocalremovergui/issues/new'
111
+ VIP_REPO = b'\xf3\xc2W\x19\x1foI)\xc2\xa9\xcc\xb67(Z\xf5',\
112
+ b'gAAAAABjQAIQ-NpNMMxMedpKHHb7ze_nqB05hw0YhbOy3pFzuzDrfqumn8_qvraxEoUpZC5ZXC0gGvfDxFMqyq9VWbYKlA67SUFI_wZB6QoVyGI581vs7kaGfUqlXHIdDS6tQ_U-BfjbEAK9EU_74-R2zXjz8Xzekw=='
113
+ NO_CODE = 'incorrect_code'
114
+
115
+ #Extensions
116
+ ONNX = '.onnx'
117
+ CKPT = '.ckpt'
118
+ CKPT_C = '.ckptc'
119
+ YAML = '.yaml'
120
+ PTH = '.pth'
121
+ TH_EXT = '.th'
122
+ JSON = '.json'
123
+
124
+ #GUI Buttons
125
+ START_PROCESSING = 'Start Processing'
126
+ WAIT_PROCESSING = 'Please wait...'
127
+ STOP_PROCESSING = 'Halting process, please wait...'
128
+ LOADING_MODELS = 'Loading models...'
129
+
130
+ #---Messages and Logs----
131
+
132
+ MISSING_MODEL = 'missing'
133
+ MODEL_PRESENT = 'present'
134
+
135
+ ALL_STEMS = 'All Stems'
136
+ VOCAL_STEM = 'Vocals'
137
+ INST_STEM = 'Instrumental'
138
+ OTHER_STEM = 'Other'
139
+ BASS_STEM = 'Bass'
140
+ DRUM_STEM = 'Drums'
141
+ GUITAR_STEM = 'Guitar'
142
+ PIANO_STEM = 'Piano'
143
+ SYNTH_STEM = 'Synthesizer'
144
+ STRINGS_STEM = 'Strings'
145
+ WOODWINDS_STEM = 'Woodwinds'
146
+ BRASS_STEM = 'Brass'
147
+ WIND_INST_STEM = 'Wind Inst'
148
+ NO_OTHER_STEM = 'No Other'
149
+ NO_BASS_STEM = 'No Bass'
150
+ NO_DRUM_STEM = 'No Drums'
151
+ NO_GUITAR_STEM = 'No Guitar'
152
+ NO_PIANO_STEM = 'No Piano'
153
+ NO_SYNTH_STEM = 'No Synthesizer'
154
+ NO_STRINGS_STEM = 'No Strings'
155
+ NO_WOODWINDS_STEM = 'No Woodwinds'
156
+ NO_WIND_INST_STEM = 'No Wind Inst'
157
+ NO_BRASS_STEM = 'No Brass'
158
+ PRIMARY_STEM = 'Primary Stem'
159
+ SECONDARY_STEM = 'Secondary Stem'
160
+ LEAD_VOCAL_STEM = 'lead_only'
161
+ BV_VOCAL_STEM = 'backing_only'
162
+ LEAD_VOCAL_STEM_I = 'with_lead_vocals'
163
+ BV_VOCAL_STEM_I = 'with_backing_vocals'
164
+ LEAD_VOCAL_STEM_LABEL = 'Lead Vocals'
165
+ BV_VOCAL_STEM_LABEL = 'Backing Vocals'
166
+
167
+ VOCAL_STEM_ONLY = f'{VOCAL_STEM} Only'
168
+ INST_STEM_ONLY = f'{INST_STEM} Only'
169
+ PRIMARY_STEM_ONLY = f'{PRIMARY_STEM} Only'
170
+
171
+ IS_SAVE_INST_ONLY = f'save_only_inst'
172
+ IS_SAVE_VOC_ONLY = f'save_only_voc'
173
+
174
+ DEVERB_MAPPER = {'Main Vocals Only':VOCAL_STEM,
175
+ 'Lead Vocals Only':LEAD_VOCAL_STEM_LABEL,
176
+ 'Backing Vocals Only':BV_VOCAL_STEM_LABEL,
177
+ 'All Vocal Types':'ALL'}
178
+
179
+ BALANCE_VALUES = [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
180
+
181
+ #Other Constants
182
+ DEMUCS_2_SOURCE = ["instrumental", "vocals"]
183
+ DEMUCS_4_SOURCE = ["drums", "bass", "other", "vocals"]
184
+
185
+ DEMUCS_2_SOURCE_MAPPER = {
186
+ INST_STEM: 0,
187
+ VOCAL_STEM: 1}
188
+
189
+ DEMUCS_4_SOURCE_MAPPER = {
190
+ BASS_STEM: 0,
191
+ DRUM_STEM: 1,
192
+ OTHER_STEM: 2,
193
+ VOCAL_STEM: 3}
194
+
195
+ DEMUCS_6_SOURCE_MAPPER = {
196
+ BASS_STEM:0,
197
+ DRUM_STEM:1,
198
+ OTHER_STEM:2,
199
+ VOCAL_STEM:3,
200
+ GUITAR_STEM:4,
201
+ PIANO_STEM:5}
202
+
203
+ DEMUCS_4_SOURCE_LIST = [BASS_STEM, DRUM_STEM, OTHER_STEM, VOCAL_STEM]
204
+ DEMUCS_6_SOURCE_LIST = [BASS_STEM, DRUM_STEM, OTHER_STEM, VOCAL_STEM, GUITAR_STEM, PIANO_STEM]
205
+
206
+ DEMUCS_UVR_MODEL = 'UVR_Model'
207
+
208
+ CHOOSE_STEM_PAIR = 'Choose Stem Pair'
209
+
210
+ STEM_SET_MENU = (VOCAL_STEM,
211
+ INST_STEM,
212
+ OTHER_STEM,
213
+ BASS_STEM,
214
+ DRUM_STEM,
215
+ GUITAR_STEM,
216
+ PIANO_STEM,
217
+ SYNTH_STEM,
218
+ STRINGS_STEM,
219
+ WOODWINDS_STEM,
220
+ BRASS_STEM,
221
+ WIND_INST_STEM)
222
+
223
+ STEM_SET_MENU_ONLY = list(STEM_SET_MENU) + [OPT_SEPARATOR_SAVE, INPUT_STEM_NAME]
224
+
225
+ STEM_SET_MENU_2 = (
226
+ OTHER_STEM,
227
+ BASS_STEM,
228
+ DRUM_STEM,
229
+ GUITAR_STEM,
230
+ PIANO_STEM,
231
+ SYNTH_STEM,
232
+ STRINGS_STEM,
233
+ WOODWINDS_STEM,
234
+ BRASS_STEM,
235
+ WIND_INST_STEM,
236
+ "Noise",
237
+ "Reverb")
238
+
239
+ STEM_PAIR_MAPPER = {
240
+ VOCAL_STEM: INST_STEM,
241
+ INST_STEM: VOCAL_STEM,
242
+ LEAD_VOCAL_STEM: BV_VOCAL_STEM,
243
+ BV_VOCAL_STEM: LEAD_VOCAL_STEM,
244
+ PRIMARY_STEM: SECONDARY_STEM}
245
+
246
+ STEM_PAIR_MAPPER_FULL = {
247
+ VOCAL_STEM: INST_STEM,
248
+ INST_STEM: VOCAL_STEM,
249
+ OTHER_STEM: NO_OTHER_STEM,
250
+ BASS_STEM: NO_BASS_STEM,
251
+ DRUM_STEM: NO_DRUM_STEM,
252
+ GUITAR_STEM: NO_GUITAR_STEM,
253
+ PIANO_STEM: NO_PIANO_STEM,
254
+ SYNTH_STEM: NO_SYNTH_STEM,
255
+ STRINGS_STEM: NO_STRINGS_STEM,
256
+ WOODWINDS_STEM: NO_WOODWINDS_STEM,
257
+ BRASS_STEM: NO_BRASS_STEM,
258
+ WIND_INST_STEM: NO_WIND_INST_STEM,
259
+ NO_OTHER_STEM: OTHER_STEM,
260
+ NO_BASS_STEM: BASS_STEM,
261
+ NO_DRUM_STEM: DRUM_STEM,
262
+ NO_GUITAR_STEM: GUITAR_STEM,
263
+ NO_PIANO_STEM: PIANO_STEM,
264
+ NO_SYNTH_STEM: SYNTH_STEM,
265
+ NO_STRINGS_STEM: STRINGS_STEM,
266
+ NO_WOODWINDS_STEM: WOODWINDS_STEM,
267
+ NO_BRASS_STEM: BRASS_STEM,
268
+ NO_WIND_INST_STEM: WIND_INST_STEM,
269
+ PRIMARY_STEM: SECONDARY_STEM}
270
+
271
+ NO_STEM = "No "
272
+
273
+ NON_ACCOM_STEMS = (
274
+ VOCAL_STEM,
275
+ OTHER_STEM,
276
+ BASS_STEM,
277
+ DRUM_STEM,
278
+ GUITAR_STEM,
279
+ PIANO_STEM,
280
+ SYNTH_STEM,
281
+ STRINGS_STEM,
282
+ WOODWINDS_STEM,
283
+ BRASS_STEM,
284
+ WIND_INST_STEM)
285
+
286
+ MDX_NET_FREQ_CUT = [VOCAL_STEM, INST_STEM]
287
+
288
+ DEMUCS_4_STEM_OPTIONS = (ALL_STEMS, VOCAL_STEM, OTHER_STEM, BASS_STEM, DRUM_STEM)
289
+ DEMUCS_6_STEM_OPTIONS = (ALL_STEMS, VOCAL_STEM, OTHER_STEM, BASS_STEM, DRUM_STEM, GUITAR_STEM, PIANO_STEM)
290
+ DEMUCS_2_STEM_OPTIONS = (VOCAL_STEM, INST_STEM)
291
+ DEMUCS_4_STEM_CHECK = (OTHER_STEM, BASS_STEM, DRUM_STEM)
292
+
293
+ #Menu Dropdowns
294
+
295
+ VOCAL_PAIR = f'{VOCAL_STEM}/{INST_STEM}'
296
+ INST_PAIR = f'{INST_STEM}/{VOCAL_STEM}'
297
+ OTHER_PAIR = f'{OTHER_STEM}/{NO_OTHER_STEM}'
298
+ DRUM_PAIR = f'{DRUM_STEM}/{NO_DRUM_STEM}'
299
+ BASS_PAIR = f'{BASS_STEM}/{NO_BASS_STEM}'
300
+ FOUR_STEM_ENSEMBLE = '4 Stem Ensemble'
301
+ MULTI_STEM_ENSEMBLE = 'Multi-stem Ensemble'
302
+
303
+ ENSEMBLE_MAIN_STEM = (CHOOSE_STEM_PAIR, VOCAL_PAIR, OTHER_PAIR, DRUM_PAIR, BASS_PAIR, FOUR_STEM_ENSEMBLE, MULTI_STEM_ENSEMBLE)
304
+
305
+ MIN_SPEC = 'Min Spec'
306
+ MAX_SPEC = 'Max Spec'
307
+ AUDIO_AVERAGE = 'Average'
308
+
309
+ MAX_MIN = f'{MAX_SPEC}/{MIN_SPEC}'
310
+ MAX_MAX = f'{MAX_SPEC}/{MAX_SPEC}'
311
+ MAX_AVE = f'{MAX_SPEC}/{AUDIO_AVERAGE}'
312
+ MIN_MAX = f'{MIN_SPEC}/{MAX_SPEC}'
313
+ MIN_MIX = f'{MIN_SPEC}/{MIN_SPEC}'
314
+ MIN_AVE = f'{MIN_SPEC}/{AUDIO_AVERAGE}'
315
+ AVE_MAX = f'{AUDIO_AVERAGE}/{MAX_SPEC}'
316
+ AVE_MIN = f'{AUDIO_AVERAGE}/{MIN_SPEC}'
317
+ AVE_AVE = f'{AUDIO_AVERAGE}/{AUDIO_AVERAGE}'
318
+
319
+ ENSEMBLE_TYPE = (MAX_MIN, MAX_MAX, MAX_AVE, MIN_MAX, MIN_MIX, MIN_AVE, AVE_MAX, AVE_MIN, AVE_AVE)
320
+ ENSEMBLE_TYPE_4_STEM = (MAX_SPEC, MIN_SPEC, AUDIO_AVERAGE)
321
+
322
+ BATCH_MODE = 'Batch Mode'
323
+ BETA_VERSION = 'BETA'
324
+ DEF_OPT = 'Default'
325
+ USER_INPUT = "User Input"
326
+ OPT_SEPARATOR = '─'*65
327
+
328
+ CHUNKS = (AUTO_SELECT, '1', '5', '10', '15', '20',
329
+ '25', '30', '35', '40', '45', '50',
330
+ '55', '60', '65', '70', '75', '80',
331
+ '85', '90', '95', 'Full')
332
+
333
+ BATCH_SIZE = (DEF_OPT, '2', '3', '4', '5',
334
+ '6', '7', '8', '9', '10')
335
+
336
+ VOL_COMPENSATION = (AUTO_SELECT, '1.035', '1.08')
337
+
338
+ MARGIN_SIZE = ('44100', '22050', '11025')
339
+
340
+ AUDIO_TOOLS = 'Audio Tools'
341
+
342
+ MANUAL_ENSEMBLE = 'Manual Ensemble'
343
+ TIME_STRETCH = 'Time Stretch'
344
+ CHANGE_PITCH = 'Change Pitch'
345
+ ALIGN_INPUTS = 'Align Inputs'
346
+ MATCH_INPUTS = 'Matchering'
347
+ COMBINE_INPUTS = 'Combine Inputs'
348
+
349
+ if OPERATING_SYSTEM == 'Windows' or OPERATING_SYSTEM == 'Darwin':
350
+ AUDIO_TOOL_OPTIONS = (MANUAL_ENSEMBLE, TIME_STRETCH, CHANGE_PITCH, ALIGN_INPUTS, MATCH_INPUTS)
351
+ else:
352
+ AUDIO_TOOL_OPTIONS = (MANUAL_ENSEMBLE, ALIGN_INPUTS, MATCH_INPUTS)
353
+
354
+ MANUAL_ENSEMBLE_OPTIONS = (MIN_SPEC, MAX_SPEC, AUDIO_AVERAGE, COMBINE_INPUTS)
355
+
356
+ PROCESS_METHODS = (VR_ARCH_PM, MDX_ARCH_TYPE, DEMUCS_ARCH_TYPE, ENSEMBLE_MODE, AUDIO_TOOLS)
357
+
358
+ DEMUCS_SEGMENTS = (DEF_OPT, '1', '5', '10', '15', '20',
359
+ '25', '30', '35', '40', '45', '50',
360
+ '55', '60', '65', '70', '75', '80',
361
+ '85', '90', '95', '100')
362
+
363
+ DEMUCS_SHIFTS = (0, 1, 2, 3, 4, 5,
364
+ 6, 7, 8, 9, 10, 11,
365
+ 12, 13, 14, 15, 16, 17,
366
+ 18, 19, 20)
367
+ SEMI_DEF = ['0']
368
+ SEMITONE_SEL = (-12,-11,-10,-9,-8,-7,-6,-5,-4,-3,-2,-1,0,1,2,3,4,5,6,7,8,9,10,11,12)
369
+
370
+ NOUT_SEL = (8, 16, 32, 48, 64)
371
+ NOUT_LSTM_SEL = (64, 128)
372
+
373
+ DEMUCS_OVERLAP = (0.25, 0.50, 0.75, 0.99)
374
+ MDX_OVERLAP = (DEF_OPT, 0.25, 0.50, 0.75, 0.99)
375
+ MDX23_OVERLAP = range(2, 51)
376
+ VR_AGGRESSION = range(0, 51)
377
+
378
+ TIME_WINDOW_MAPPER = {
379
+ "None": None,
380
+ "1": [0.0625],
381
+ "2": [0.125],
382
+ "3": [0.25],
383
+ "4": [0.5],
384
+ "5": [0.75],
385
+ "6": [1],
386
+ "7": [2],
387
+ "Shifts: Low": [0.0625, 0.5],
388
+ "Shifts: Medium": [0.0625, 0.125, 0.5],
389
+ "Shifts: High": [0.0625, 0.125, 0.25, 0.5]
390
+ #"Shifts: Very High": [0.0625, 0.125, 0.25, 0.5, 0.75, 1],
391
+ }
392
+
393
+ INTRO_MAPPER = {
394
+ "Default": [10],
395
+ "1": [8],
396
+ "2": [6],
397
+ "3": [4],
398
+ "4": [2],
399
+ "Shifts: Low": [1, 10],
400
+ "Shifts: Medium": [1, 10, 8],
401
+ "Shifts: High": [1, 10, 8, 6, 4]
402
+ }
403
+
404
+ VOLUME_MAPPER = {
405
+ "None": (0, [0]),
406
+ "Low": (-4, range(0, 8)),
407
+ "Medium": (-6, range(0, 12)),
408
+ "High": (-6, [x * 0.5 for x in range(0, 25)]),
409
+ "Very High": (-10, [x * 0.5 for x in range(0, 41)])}
410
+ #"Max": (-10, [x * 0.3 for x in range(0, int(20 / 0.3) + 1)])}
411
+
412
+ PHASE_MAPPER = {
413
+ "None": [0],
414
+ "Shifts Low": [0, 180],
415
+ "Shifts Medium": [0],
416
+ "Shifts High": [0],
417
+ "Shifts Very High": [0],}
418
+
419
+ NONE_P = "None"
420
+ VLOW_P = "Shifts: Very Low"
421
+ LOW_P = "Shifts: Low"
422
+ MED_P = "Shifts: Medium"
423
+ HIGH_P = "Shifts: High"
424
+ VHIGH_P = "Shifts: Very High"
425
+ VMAX_P = "Shifts: Maximum"
426
+
427
+ PHASE_SHIFTS_OPT = {
428
+ NONE_P:190,
429
+ VLOW_P:180,
430
+ LOW_P:90,
431
+ MED_P:45,
432
+ HIGH_P:20,
433
+ VHIGH_P:10,
434
+ VMAX_P:1,}
435
+
436
+ VR_WINDOW = ('320', '512','1024')
437
+ VR_CROP = ('256', '512', '1024')
438
+ POST_PROCESSES_THREASHOLD_VALUES = ('0.1', '0.2', '0.3')
439
+
440
+ MDX_POP_PRO = ('MDX-NET_Noise_Profile_14_kHz', 'MDX-NET_Noise_Profile_17_kHz', 'MDX-NET_Noise_Profile_Full_Band')
441
+ MDX_POP_STEMS = ('Vocals', 'Instrumental', 'Other', 'Drums', 'Bass')
442
+ MDX_POP_NFFT = ('4096', '5120', '6144', '7680', '8192', '16384')
443
+ MDX_POP_DIMF = ('2048', '3072', '4096')
444
+ DENOISE_NONE, DENOISE_S, DENOISE_M = 'None', 'Standard', 'Denoise Model'
445
+ MDX_DENOISE_OPTION = [DENOISE_NONE, DENOISE_S, DENOISE_M]
446
+ MDX_SEGMENTS = list(range(32, 4000+1, 32))
447
+
448
+ SAVE_ENSEMBLE = 'Save Ensemble'
449
+ CLEAR_ENSEMBLE = 'Clear Selection(s)'
450
+ MENU_SEPARATOR = 35*'•'
451
+ CHOOSE_ENSEMBLE_OPTION = 'Choose Option'
452
+ ALL_TYPES = 'ALL'
453
+ INVALID_ENTRY = 'Invalid Input, Please Try Again'
454
+ ENSEMBLE_INPUT_RULE = '1. Only letters, numbers, spaces, and dashes allowed.\n2. No dashes or spaces at the start or end of input.'
455
+ STEM_INPUT_RULE = '1. Only words with no spaces are allowed.\n2. No spaces, numbers, or special characters.'
456
+
457
+ ENSEMBLE_OPTIONS = [OPT_SEPARATOR_SAVE, SAVE_ENSEMBLE, CLEAR_ENSEMBLE]
458
+ ENSEMBLE_CHECK = 'ensemble check'
459
+ KARAOKEE_CHECK = 'kara check'
460
+
461
+ AUTO_PHASE = "Automatic"
462
+ POSITIVE_PHASE = "Positive Phase"
463
+ NEGATIVE_PHASE = "Negative Phase"
464
+ OFF_PHASE = "Native Phase"
465
+
466
+ ALIGN_PHASE_OPTIONS = [AUTO_PHASE, POSITIVE_PHASE, NEGATIVE_PHASE, OFF_PHASE]
467
+
468
+ SELECT_SAVED_ENSEMBLE = 'Select Saved Ensemble'
469
+ SELECT_SAVED_SETTING = 'Select Saved Setting'
470
+ ENSEMBLE_OPTION = "Ensemble Customization Options"
471
+ MDX_OPTION = "Advanced MDX-Net Options"
472
+ DEMUCS_OPTION = "Advanced Demucs Options"
473
+ VR_OPTION = "Advanced VR Options"
474
+ HELP_OPTION = "Open Information Guide"
475
+ ERROR_OPTION = "Open Error Log"
476
+ VERIFY_BEGIN = 'Verifying file '
477
+ SAMPLE_BEGIN = 'Creating Sample '
478
+ MODEL_MISSING_CHECK = 'Model Missing:'
479
+ OPTION_LIST = [VR_OPTION, MDX_OPTION, DEMUCS_OPTION, ENSEMBLE_OPTION, ALIGNMENT_TOOL, HELP_OPTION, ERROR_OPTION]
480
+
481
+ #Menu Strings
482
+ VR_MENU ='VR Menu'
483
+ DEMUCS_MENU ='Demucs Menu'
484
+ MDX_MENU ='MDX-Net Menu'
485
+ ENSEMBLE_MENU ='Ensemble Menu'
486
+ HELP_MENU ='Help Menu'
487
+ ERROR_MENU ='Error Log'
488
+ INPUTS_MENU ='Inputs Menu'
489
+ ALIGN_MENU ='Align Menu'
490
+
491
+ # Audio Player
492
+ PLAYING_SONG = ": Playing"
493
+ PAUSE_SONG = ": Paused"
494
+ STOP_SONG = ": Stopped"
495
+
496
+ SELECTED_VER = 'Selected'
497
+ DETECTED_VER = 'Detected'
498
+
499
+ SAMPLE_MODE_CHECKBOX = lambda v:f'Sample Mode ({v}s)'
500
+ REMOVED_FILES = lambda r, e:f'Audio Input Verification Report:\n\nRemoved Files:\n\n{r}\n\nError Details:\n\n{e}'
501
+ ADVANCED_SETTINGS = (ENSEMBLE_OPTION, MDX_OPTION, DEMUCS_OPTION, VR_OPTION, HELP_OPTION, ERROR_OPTION)
502
+
503
+ WAV = 'WAV'
504
+ FLAC = 'FLAC'
505
+ MP3 = 'MP3'
506
+
507
+ MP3_BIT_RATES = ('96k', '128k', '160k', '224k', '256k', '320k')
508
+ WAV_TYPE = ('PCM_U8', 'PCM_16', 'PCM_24', 'PCM_32', '32-bit Float', '64-bit Float')
509
+ GPU_DEVICE_NUM_OPTS = (DEFAULT, '0', '1', '2', '3', '4', '5', '6', '7', '8')
510
+
511
+ SELECT_SAVED_SET = 'Choose Option'
512
+ SAVE_SETTINGS = 'Save Current Settings'
513
+ RESET_TO_DEFAULT = 'Reset to Default'
514
+ RESET_FULL_TO_DEFAULT = 'Reset to Default'
515
+ RESET_PM_TO_DEFAULT = 'Reset All Application Settings to Default'
516
+
517
+ SAVE_SET_OPTIONS = [OPT_SEPARATOR_SAVE, SAVE_SETTINGS, RESET_TO_DEFAULT]
518
+
519
+ TIME_PITCH = ('1.0', '2.0', '3.0', '4.0')
520
+ TIME_TEXT = '_time_stretched'
521
+ PITCH_TEXT = '_pitch_shifted'
522
+
523
+ #RegEx Input Validation
524
+ REG_PITCH = r'^[-+]?(1[0]|[0-9]([.][0-9]*)?)$'
525
+ REG_TIME = r'^[+]?(1[0]|[0-9]([.][0-9]*)?)$'
526
+ REG_COMPENSATION = r'\b^(1[0]|[0-9]([.][0-9]*)?|Auto|None)$\b'
527
+ REG_THES_POSTPORCESS = r'\b^([0]([.][0-9]{0,6})?)$\b'
528
+ REG_CHUNKS = r'\b^(200|1[0-9][0-9]|[1-9][0-9]?|Auto|Full)$\b'
529
+ REG_CHUNKS_DEMUCS = r'\b^(200|1[0-9][0-9]|[1-9][0-9]?|Auto|Full)$\b'
530
+ REG_MARGIN = r'\b^[0-9]*$\b'
531
+ REG_SEGMENTS = r'\b^(200|1[0-9][0-9]|[1-9][0-9]?|Default)$\b'
532
+ REG_SAVE_INPUT = r'\b^([a-zA-Z0-9 -]{0,25})$\b'
533
+ REG_INPUT_STEM_NAME = r'^(Wind Inst|[a-zA-Z]{1,25})$'
534
+ REG_SEMITONES = r'^-?(20\.00|[01]?\d(\.\d{1,2})?|20)$'
535
+ REG_AGGRESSION = r'^[-+]?[0-9]\d*?$'
536
+ REG_WINDOW = r'\b^[0-9]{0,4}$\b'
537
+ REG_SHIFTS = r'\b^[0-9]*$\b'
538
+ REG_BATCHES = r'\b^([0-9]*?|Default)$\b'
539
+ REG_OVERLAP = r'\b^([0]([.][0-9]{0,6})?|Default)$\b'#r"(Default|[0-9]+(\.[0-9]+)?)"#
540
+ REG_OVERLAP23 = r'\b^([1][0-9]|[2-9][0-9]*|Default)$\b'#r'\b^([2-9][0-9]*?|Default)$\b'
541
+ REG_MDX_SEG = r'\b(?:' + '|'.join([str(num) for num in range(32, 1000001, 32)]) + r')\b'
542
+ REG_ALIGN = r'^[-+]?[0-9]\d*?$'
543
+ REG_VOL_COMP = r'^\d+\.\d{1,9}$'
544
+
545
+ # Sub Menu
546
+ VR_ARCH_SETTING_LOAD = 'Load for VR Arch'
547
+ MDX_SETTING_LOAD = 'Load for MDX-Net'
548
+ DEMUCS_SETTING_LOAD = 'Load for Demucs'
549
+ ALL_ARCH_SETTING_LOAD = 'Load for Full Application'
550
+
551
+ # Mappers
552
+
553
+ DEFAULT_DATA = {
554
+ 'chosen_process_method': MDX_ARCH_TYPE,
555
+ 'vr_model': CHOOSE_MODEL,
556
+ 'aggression_setting': 5,
557
+ 'window_size': 512,
558
+ 'mdx_segment_size': 256,
559
+ 'batch_size': DEF_OPT,
560
+ 'crop_size': 256,
561
+ 'is_tta': False,
562
+ 'is_output_image': False,
563
+ 'is_post_process': False,
564
+ 'is_high_end_process': False,
565
+ 'post_process_threshold': 0.2,
566
+ 'vr_voc_inst_secondary_model': NO_MODEL,
567
+ 'vr_other_secondary_model': NO_MODEL,
568
+ 'vr_bass_secondary_model': NO_MODEL,
569
+ 'vr_drums_secondary_model': NO_MODEL,
570
+ 'vr_is_secondary_model_activate': False,
571
+ 'vr_voc_inst_secondary_model_scale': 0.9,
572
+ 'vr_other_secondary_model_scale': 0.7,
573
+ 'vr_bass_secondary_model_scale': 0.5,
574
+ 'vr_drums_secondary_model_scale': 0.5,
575
+ 'demucs_model': CHOOSE_MODEL,
576
+ 'segment': DEMUCS_SEGMENTS[0],
577
+ 'overlap': DEMUCS_OVERLAP[0],
578
+ 'overlap_mdx': MDX_OVERLAP[0],
579
+ 'overlap_mdx23': '8',
580
+ 'shifts': 2,
581
+ 'chunks_demucs': CHUNKS[0],
582
+ 'margin_demucs': 44100,
583
+ 'is_chunk_demucs': False,
584
+ 'is_chunk_mdxnet': False,
585
+ 'is_primary_stem_only_Demucs': False,
586
+ 'is_secondary_stem_only_Demucs': False,
587
+ 'is_split_mode': True,
588
+ 'is_demucs_combine_stems': True,#
589
+ 'is_mdx23_combine_stems': True,#
590
+ 'demucs_voc_inst_secondary_model': NO_MODEL,
591
+ 'demucs_other_secondary_model': NO_MODEL,
592
+ 'demucs_bass_secondary_model': NO_MODEL,
593
+ 'demucs_drums_secondary_model': NO_MODEL,
594
+ 'demucs_is_secondary_model_activate': False,
595
+ 'demucs_voc_inst_secondary_model_scale': 0.9,
596
+ 'demucs_other_secondary_model_scale': 0.7,
597
+ 'demucs_bass_secondary_model_scale': 0.5,
598
+ 'demucs_drums_secondary_model_scale': 0.5,
599
+ 'demucs_stems': ALL_STEMS,
600
+ 'demucs_pre_proc_model': NO_MODEL,
601
+ 'is_demucs_pre_proc_model_activate': False,
602
+ 'is_demucs_pre_proc_model_inst_mix': False,
603
+ 'mdx_net_model': CHOOSE_MODEL,
604
+ 'chunks': CHUNKS[0],
605
+ 'margin': 44100,
606
+ 'compensate': AUTO_SELECT,
607
+ 'is_denoise': False,#
608
+ 'denoise_option': 'None',#
609
+ 'phase_option': AUTO_PHASE,
610
+ 'phase_shifts': NONE_P,#
611
+ 'is_save_align': False,#,
612
+ 'is_match_frequency_pitch': True,#
613
+ 'is_match_silence': True,#
614
+ 'is_spec_match': False,#
615
+ 'is_mdx_c_seg_def': False,
616
+ 'is_invert_spec': False, #
617
+ 'is_deverb_vocals': False, #
618
+ 'deverb_vocal_opt': 'Main Vocals Only', #
619
+ 'voc_split_save_opt': 'Lead Only', #
620
+ 'is_mixer_mode': False,
621
+ 'mdx_batch_size': DEF_OPT,
622
+ 'mdx_voc_inst_secondary_model': NO_MODEL,
623
+ 'mdx_other_secondary_model': NO_MODEL,
624
+ 'mdx_bass_secondary_model': NO_MODEL,
625
+ 'mdx_drums_secondary_model': NO_MODEL,
626
+ 'mdx_is_secondary_model_activate': False,
627
+ 'mdx_voc_inst_secondary_model_scale': 0.9,
628
+ 'mdx_other_secondary_model_scale': 0.7,
629
+ 'mdx_bass_secondary_model_scale': 0.5,
630
+ 'mdx_drums_secondary_model_scale': 0.5,
631
+ 'mdx_stems': ALL_STEMS,
632
+ 'is_save_all_outputs_ensemble': True,
633
+ 'is_append_ensemble_name': False,
634
+ 'chosen_audio_tool': AUDIO_TOOL_OPTIONS[0],
635
+ 'choose_algorithm': MANUAL_ENSEMBLE_OPTIONS[0],
636
+ 'time_stretch_rate': 2.0,
637
+ 'pitch_rate': 2.0,
638
+ 'is_time_correction': True,
639
+ 'is_gpu_conversion': False,
640
+ 'is_primary_stem_only': False,
641
+ 'is_secondary_stem_only': False,
642
+ 'is_testing_audio': False,#
643
+ 'is_auto_update_model_params': True,#
644
+ 'is_add_model_name': False,
645
+ 'is_accept_any_input': False,
646
+ 'is_task_complete': False,
647
+ 'is_normalization': False,
648
+ 'is_use_opencl': False,
649
+ 'is_wav_ensemble': False,
650
+ 'is_create_model_folder': False,
651
+ 'mp3_bit_set': '320k',#
652
+ 'semitone_shift': '0',#
653
+ 'save_format': WAV,
654
+ 'wav_type_set': 'PCM_16',
655
+ 'device_set': DEFAULT,
656
+ 'user_code': '',
657
+ 'export_path': '',
658
+ 'input_paths': [],
659
+ 'lastDir': None,
660
+ 'time_window': "3",
661
+ 'intro_analysis': DEFAULT,
662
+ 'db_analysis': "Medium",
663
+ 'fileOneEntry': '',
664
+ 'fileOneEntry_Full': '',
665
+ 'fileTwoEntry': '',
666
+ 'fileTwoEntry_Full': '',
667
+ 'DualBatch_inputPaths': [],
668
+ 'model_hash_table': {},
669
+ 'help_hints_var': True,
670
+ 'set_vocal_splitter': NO_MODEL,
671
+ 'is_set_vocal_splitter': False,#
672
+ 'is_save_inst_set_vocal_splitter': False,#
673
+ 'model_sample_mode': False,
674
+ 'model_sample_mode_duration': 30
675
+ }
676
+
677
+ SETTING_CHECK = ('vr_model',
678
+ 'aggression_setting',
679
+ 'window_size',
680
+ 'mdx_segment_size',
681
+ 'batch_size',
682
+ 'crop_size',
683
+ 'is_tta',
684
+ 'is_output_image',
685
+ 'is_post_process',
686
+ 'is_high_end_process',
687
+ 'post_process_threshold',
688
+ 'vr_voc_inst_secondary_model',
689
+ 'vr_other_secondary_model',
690
+ 'vr_bass_secondary_model',
691
+ 'vr_drums_secondary_model',
692
+ 'vr_is_secondary_model_activate',
693
+ 'vr_voc_inst_secondary_model_scale',
694
+ 'vr_other_secondary_model_scale',
695
+ 'vr_bass_secondary_model_scale',
696
+ 'vr_drums_secondary_model_scale',
697
+ 'demucs_model',
698
+ 'segment',
699
+ 'overlap',
700
+ 'overlap_mdx',
701
+ 'shifts',
702
+ 'chunks_demucs',
703
+ 'margin_demucs',
704
+ 'is_chunk_demucs',
705
+ 'is_primary_stem_only_Demucs',
706
+ 'is_secondary_stem_only_Demucs',
707
+ 'is_split_mode',
708
+ 'is_demucs_combine_stems',#
709
+ 'is_mdx23_combine_stems',#
710
+ 'demucs_voc_inst_secondary_model',
711
+ 'demucs_other_secondary_model',
712
+ 'demucs_bass_secondary_model',
713
+ 'demucs_drums_secondary_model',
714
+ 'demucs_is_secondary_model_activate',
715
+ 'demucs_voc_inst_secondary_model_scale',
716
+ 'demucs_other_secondary_model_scale',
717
+ 'demucs_bass_secondary_model_scale',
718
+ 'demucs_drums_secondary_model_scale',
719
+ 'demucs_stems',
720
+ 'mdx_net_model',
721
+ 'chunks',
722
+ 'margin',
723
+ 'compensate',
724
+ 'is_denoise',#
725
+ 'denoise_option',#
726
+ 'phase_option',#
727
+ 'phase_shifts',#
728
+ 'is_save_align',#,
729
+ 'is_match_silence',
730
+ 'is_spec_match',#,
731
+ 'is_match_frequency_pitch',#
732
+ 'is_mdx_c_seg_def',
733
+ 'is_invert_spec',#
734
+ 'is_deverb_vocals',#
735
+ 'deverb_vocal_opt',#
736
+ 'voc_split_save_opt',#
737
+ 'mdx_batch_size',
738
+ 'mdx_voc_inst_secondary_model',
739
+ 'mdx_other_secondary_model',
740
+ 'mdx_bass_secondary_model',
741
+ 'mdx_drums_secondary_model',
742
+ 'mdx_is_secondary_model_activate',
743
+ 'mdx_voc_inst_secondary_model_scale',
744
+ 'mdx_other_secondary_model_scale',
745
+ 'mdx_bass_secondary_model_scale',
746
+ 'mdx_drums_secondary_model_scale',
747
+ 'is_save_all_outputs_ensemble',
748
+ 'is_append_ensemble_name',
749
+ 'chosen_audio_tool',
750
+ 'choose_algorithm',
751
+ 'time_stretch_rate',
752
+ 'pitch_rate',
753
+ 'is_time_correction',
754
+ 'is_primary_stem_only',
755
+ 'is_secondary_stem_only',
756
+ 'is_testing_audio',#
757
+ 'is_auto_update_model_params',#
758
+ 'is_add_model_name',
759
+ "is_accept_any_input",
760
+ 'is_task_complete',
761
+ 'is_create_model_folder',
762
+ 'mp3_bit_set',#
763
+ 'semitone_shift',#
764
+ 'save_format',
765
+ 'wav_type_set',
766
+ 'device_set',
767
+ 'user_code',
768
+ 'is_gpu_conversion',
769
+ 'is_normalization',
770
+ 'is_use_opencl',
771
+ 'is_wav_ensemble',
772
+ 'help_hints_var',
773
+ 'set_vocal_splitter',
774
+ 'is_set_vocal_splitter',#
775
+ 'is_save_inst_set_vocal_splitter',#
776
+ 'model_sample_mode',
777
+ 'model_sample_mode_duration',
778
+ 'time_window',
779
+ 'intro_analysis',
780
+ 'db_analysis',
781
+ 'fileOneEntry',
782
+ 'fileOneEntry_Full',
783
+ 'fileTwoEntry',
784
+ 'fileTwoEntry_Full',
785
+ 'DualBatch_inputPaths'
786
+ )
787
+
788
+ NEW_LINES = "\n\n"
789
+ NEW_LINE = "\n"
790
+ NO_LINE = ''
791
+
792
+ FFMPEG_EXT = (".aac", ".aiff", ".alac" ,".flac", ".FLAC", ".mov", ".mp4", ".MP4",
793
+ ".m4a", ".M4A", ".mp2", ".mp3", "MP3", ".mpc", ".mpc8",
794
+ ".mpeg", ".ogg", ".OGG", ".tta", ".wav", ".wave", ".WAV", ".WAVE", ".wma", ".webm", ".eac3", ".mkv", ".opus", ".OPUS")
795
+
796
+ FFMPEG_MORE_EXT = (".aa", ".aac", ".ac3", ".aiff", ".alac", ".avi", ".f4v",".flac", ".flic", ".flv",
797
+ ".m4v",".mlv", ".mov", ".mp4", ".m4a", ".mp2", ".mp3", ".mp4", ".mpc", ".mpc8",
798
+ ".mpeg", ".ogg", ".tta", ".tty", ".vcd", ".wav", ".wma")
799
+ ANY_EXT = ""
800
+
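# --- Illustrative sketch (not part of the commit): one way a tuple like
# --- FFMPEG_EXT above could be used to filter incoming paths. The helper name
# --- `is_supported_input` is hypothetical and shown only for clarity.
import os

def is_supported_input(path: str, extensions=FFMPEG_EXT) -> bool:
    """Return True if the file's extension appears in the supported tuple."""
    # Compare case-insensitively so ".MP3" and ".mp3" are treated the same.
    ext = os.path.splitext(path)[1]
    return ext.lower() in {e.lower() for e in extensions}

# Example: supported = [p for p in input_paths if is_supported_input(p)]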
801
+ # Secondary Menu Constants
802
+
803
+ VOCAL_PAIR_PLACEMENT = 1, 2, 3, 4
804
+ OTHER_PAIR_PLACEMENT = 5, 6, 7, 8
805
+ BASS_PAIR_PLACEMENT = 9, 10, 11, 12
806
+ DRUMS_PAIR_PLACEMENT = 13, 14, 15, 16
807
+
808
+ # Drag n Drop String Checks
809
+
810
+ DOUBLE_BRACKET = "} {"
811
+ RIGHT_BRACKET = "}"
812
+ LEFT_BRACKET = "{"
813
+ #DND CONSTS
814
+
815
+ MAC_DND_CHECK = ('/Users/',
816
+ '/Applications/',
817
+ '/Library/',
818
+ '/System/')
819
+ LINUX_DND_CHECK = ('/home/',
820
+ '/usr/')
821
+ WINDOWS_DND_CHECK = ('A:', 'B:', 'C:', 'D:', 'E:', 'F:', 'G:', 'H:', 'I:', 'J:', 'K:', 'L:', 'M:', 'N:', 'O:', 'P:', 'Q:', 'R:', 'S:', 'T:', 'U:', 'V:', 'W:', 'X:', 'Y:', 'Z:')
822
+
823
+ WOOD_INST_MODEL_HASH = '0ec76fd9e65f81d8b4fbd13af4826ed8'
824
+ WOOD_INST_PARAMS = {
825
+ "vr_model_param": "4band_v3",
826
+ "primary_stem": NO_WIND_INST_STEM
827
+ }
828
+
829
+ READ_ONLY = 'readonly'
830
+
831
+ FILE_1 = 'file1'
832
+ FILE_2 = 'file2'
833
+
834
+ FILE_1_LB = 'file1_lb'
835
+ FILE_2_LB = 'file1_2b'
836
+ BATCH_MODE_DUAL = " : Batch Mode"
837
+
838
+ CODEC_DICT = {
839
+ 'PCM_U8': {"sample_width": 1, "codec": None}, # 8-bit unsigned PCM
840
+ 'PCM_16': {"sample_width": 2, "codec": None}, # 16-bit signed PCM
841
+ 'PCM_24': {"sample_width": 3, "codec": None}, # 24-bit signed PCM
842
+ 'PCM_32': {"sample_width": 4, "codec": None}, # 32-bit signed PCM
843
+ 'FLOAT32': {"sample_width": None, "codec": "pcm_f32le"}, # 32-bit float
844
+ 'FLOAT64': {"sample_width": None, "codec": "pcm_f64le"} # 64-bit float
845
+ }
846
+
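# --- Illustrative sketch (not part of the commit): mapping the wav type values
# --- used in CODEC_DICT above onto soundfile subtypes for export. The SF_SUBTYPES
# --- table and the `write_wav` helper are assumptions made for this example only.
import soundfile as sf

SF_SUBTYPES = {
    'PCM_U8': 'PCM_U8', 'PCM_16': 'PCM_16', 'PCM_24': 'PCM_24',
    'PCM_32': 'PCM_32', 'FLOAT32': 'FLOAT', 'FLOAT64': 'DOUBLE',
}

def write_wav(path, audio, sample_rate, wav_type='PCM_16'):
    """Write `audio` (a numpy array) as a WAV file at the chosen bit depth."""
    sf.write(path, audio, sample_rate, subtype=SF_SUBTYPES[wav_type])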
847
+
848
+ # Manual Downloads
849
+ VR_PLACEMENT_TEXT = 'Place models in \"models/VR_Models\" directory.'
850
+ MDX_PLACEMENT_TEXT = 'Place models in \"models/MDX_Net_Models\" directory.'
851
+ DEMUCS_PLACEMENT_TEXT = 'Place models in \"models/Demucs_Models\" directory.'
852
+ DEMUCS_V3_V4_PLACEMENT_TEXT = 'Place items in \"models/Demucs_Models/v3_v4_repo\" directory.'
853
+ MDX_23_NAME = "MDX23C Model"
854
+
855
+ # License info
856
+ if OPERATING_SYSTEM=="Darwin":
857
+ is_macos = True
858
+ LICENSE_OS_SPECIFIC_TEXT = '• This application is intended for those running macOS Catalina and above.\n' +\
859
+ '• Application functionality for systems running macOS Mojave or lower is not guaranteed.\n' +\
860
+ '• Application functionality for older or budget Mac systems is not guaranteed.\n\n'
861
+ elif OPERATING_SYSTEM=="Linux":
862
+ LICENSE_OS_SPECIFIC_TEXT = '• This application is intended for those running Linux Ubuntu 18.04+.\n' +\
863
+ '• Application functionality for systems running other Linux platforms is not guaranteed.\n' +\
864
+ '• Application functionality for older or budget systems is not guaranteed.\n\n'
865
+ elif OPERATING_SYSTEM=="Windows":
866
+ LICENSE_OS_SPECIFIC_TEXT = '• This application is intended for those running Windows 10 or higher.\n' +\
867
+ '• Application functionality for systems running Windows 7 or lower is not guaranteed.\n' +\
868
+ 'Application functionality for Intel Pentium & Celeron CPU systems is not guaranteed.\n\n'
869
+
870
+ LICENSE_TEXT = lambda a, p:f'Current Application Version: Ultimate Vocal Remover {a}\n' +\
871
+ f'Current Patch Version: {p}\n\n' +\
872
+ 'Copyright (c) 2022 Ultimate Vocal Remover\n\n' +\
873
+ 'UVR is free and open-source, but MIT licensed. Please credit us if you use our\n' +\
874
+ f'models or code for projects unrelated to UVR.\n\n{LICENSE_OS_SPECIFIC_TEXT}' +\
875
+ 'This bundle contains the UVR interface, Python, PyTorch, and other\n' +\
876
+ 'dependencies needed to run the application effectively.\n\n' +\
877
+ 'Website Links: This application, System or Service(s) may contain links to\n' +\
878
+ 'other websites and downloads, and they are solely provided to you as an\n' +\
879
+ 'additional convenience. You understand and acknowledge that by clicking\n' +\
880
+ 'or activating such links you are accessing a site or service outside of\n' +\
881
+ 'this application, and that we do not screen, review, approve, or otherwise\n' +\
882
+ 'endorse any content or information contained in these linked websites.\n' +\
883
+ 'You acknowledge and agree that we, our affiliates and partners are not\n' +\
884
+ 'responsible for the contents of any of these linked websites, including\n' +\
885
+ 'the accuracy or availability of information provided by the linked websites,\n' +\
886
+ 'and we make no representations or warranties regarding your use of\n' +\
887
+ 'the linked websites.\n\n' +\
888
+ 'This application is MIT Licensed\n\n' +\
889
+ 'Permission is hereby granted, free of charge, to any person obtaining a copy\n' +\
890
+ 'of this software and associated documentation files (the "Software"), to deal\n' +\
891
+ 'in the Software without restriction, including without limitation the rights\n' +\
892
+ 'to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n' +\
893
+ 'copies of the Software, and to permit persons to whom the Software is\n' +\
894
+ 'furnished to do so, subject to the following conditions:\n\n' +\
895
+ 'The above copyright notice and this permission notice shall be included in all\n' +\
896
+ 'copies or substantial portions of the Software.\n\n' +\
897
+ 'THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n' +\
898
+ 'IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n' +\
899
+ 'FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n' +\
900
+ 'AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n' +\
901
+ 'LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n' +\
902
+ 'OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\n' +\
903
+ 'SOFTWARE.'
904
+
905
+ # Message Box Text
906
+ INVALID_INPUT = 'Invalid Input', 'The input is invalid.\n\nPlease verify the input still exists or is valid and try again.'
907
+ INVALID_EXPORT = 'Invalid Export Directory', 'You have selected an invalid export directory.\n\nPlease make sure the selected directory still exists.'
908
+ INVALID_ENSEMBLE = 'Not Enough Models', 'You must select 2 or more models to run ensemble.'
909
+ INVALID_MODEL = 'No Model Chosen', 'You must select a model to continue.'
910
+ MISSING_MODEL = 'Model Missing', 'The selected model is missing or not valid.'
911
+ ERROR_OCCURED = 'Error Occurred', '\n\nWould you like to open the error log for more details?\n'
912
+ PROCESS_COMPLETE = '\nProcess complete\n'
913
+ PROCESS_COMPLETE_2 = 'Process complete\n'
914
+
915
+ # GUI Text Constants
916
+ BACK_TO_MAIN_MENU = 'Back to Main Menu'
917
+
918
+ # Help Hint Text
919
+ INTERNAL_MODEL_ATT = 'This is an internal model setting. \n\n***Avoid changing it unless you\'re certain about it!***'
920
+ STOP_HELP = 'Stops ongoing tasks.\n• A confirmation pop-up will appear before stopping.'
921
+ SETTINGS_HELP = 'Accesses the main settings and the "Download Center."'
922
+ COMMAND_TEXT_HELP = 'Shows the status and progress of ongoing tasks.'
923
+ SAVE_CURRENT_SETTINGS_HELP = 'Load or save the app\'s settings.'
924
+ PITCH_SHIFT_HELP = ('Choose the pitch for processing tracks:\n\n'
925
+ '• Whole numbers indicate semitones.\n'
926
+ '• Using higher pitches may cut the upper bandwidth, even in high-quality models.\n'
927
+ '• Upping the pitch can be better for tracks with deeper vocals.\n'
928
+ '• Dropping the pitch may take more processing time but works well for tracks with high-pitched vocals.')
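# --- Illustrative sketch (not part of the commit): shifting a track by a whole
# --- number of semitones, as described in PITCH_SHIFT_HELP above. librosa is
# --- used purely as an example; the application's own pipeline may differ.
import librosa

def shift_semitones(y, sr, semitones: float):
    """Return `y` pitch-shifted by `semitones` (positive = up, negative = down)."""
    return librosa.effects.pitch_shift(y, sr=sr, n_steps=semitones)

# Example: y, sr = librosa.load("song.wav", sr=44100); y_up = shift_semitones(y, sr, 2)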
929
+ AGGRESSION_SETTING_HELP = ('Adjust the intensity of primary stem extraction:\n\n'
930
+ '• It ranges from -100 to 100.\n'
931
+ '• Bigger values mean deeper extractions.\n'
932
+ '• Typically, it\'s set to 5 for vocals & instrumentals. \n'
933
+ '• Values beyond 5 might muddy the sound for non-vocal models.')
934
+ WINDOW_SIZE_HELP = ('Select window size to balance quality and speed:\n\n'
935
+ '• 1024 - Quick but lesser quality.\n'
936
+ '• 512 - Medium speed and quality.\n'
937
+ '• 320 - Takes longer but may offer better quality.')
938
+ MDX_SEGMENT_SIZE_HELP = ('Pick a segment size to balance speed, resource use, and quality:\n'
939
+ '• Smaller sizes consume less resources.\n'
940
+ '• Bigger sizes consume more resources, but may provide better results.\n'
941
+ '• Default size is 256. Quality can change based on your pick.')
942
+ DEMUCS_STEMS_HELP = ('Select a stem for extraction with the chosen model:\n\n'
943
+ '• All Stems - Extracts all available stems.\n'
944
+ '• Vocals - Only the "vocals" stem.\n'
945
+ '• Other - Only the "other" stem.\n'
946
+ '• Bass - Only the "bass" stem.\n'
947
+ '• Drums - Only the "drums" stem.')
948
+ SEGMENT_HELP = ('Adjust segments to manage RAM or V-RAM usage:\n\n'
949
+ '• Smaller sizes consume less resources.\n'
950
+ '• Bigger sizes consume more resources, but may provide better results.\n'
951
+ '• "Default" picks the optimal size.')
952
+
953
+ ENSEMBLE_MAIN_STEM_HELP = (
954
+ 'Select the stem type for ensembling:\n\n'
955
+
956
+ f'• {VOCAL_PAIR}:\n'
957
+ ' - Primary Stem: Vocals\n'
958
+ ' - Secondary Stem: Instrumental (mixture minus vocals)\n\n'
959
+
960
+ f'• {OTHER_PAIR}:\n'
961
+ ' - Primary Stem: Other\n'
962
+ ' - Secondary Stem: No Other (mixture minus "other")\n\n'
963
+
964
+ f'• {BASS_PAIR}:\n'
965
+ ' - Primary Stem: Bass\n'
966
+ ' - Secondary Stem: No Bass (mixture minus bass)\n\n'
967
+
968
+ f'• {DRUM_PAIR}:\n'
969
+ ' - Primary Stem: Drums\n'
970
+ ' - Secondary Stem: No Drums (mixture minus drums)\n\n'
971
+
972
+ f'• {FOUR_STEM_ENSEMBLE}:\n'
973
+ ' - Gathers all 4-stem Demucs models and ensembles all outputs.\n\n'
974
+
975
+ f'• {MULTI_STEM_ENSEMBLE}:\n'
976
+ ' - The "Jungle Ensemble" gathers all models and ensembles any related outputs.'
977
+ )
978
+
979
+ ENSEMBLE_TYPE_HELP = (
980
+ 'Choose the ensemble algorithm for generating the final output:\n\n'
981
+
982
+ f'• {MAX_MIN}:\n'
983
+ ' - Primary stem processed with "Max Spec" algorithm.\n'
984
+ ' - Secondary stem processed with "Min Spec" algorithm.\n\n'
985
+
986
+ 'Note: For the "4 Stem Ensemble" option, only one algorithm will be displayed.\n\n'
987
+
988
+ 'Algorithm Details:\n'
989
+
990
+ f'• {MAX_SPEC}:\n'
991
+ ' - Produces the highest possible output.\n'
992
+ ' - Ideal for vocal stems for a fuller sound, but might introduce unwanted artifacts.\n'
993
+ ' - Works well with instrumental stems, but avoid using VR Arch models in the ensemble.\n\n'
994
+
995
+ f'• {MIN_SPEC}:\n'
996
+ ' - Produces the lowest possible output.\n'
997
+ ' - Ideal for instrumental stems for a cleaner result. Might result in a "muddy" sound.\n\n'
998
+
999
+ f'• {AUDIO_AVERAGE}:\n'
1000
+ ' - Averages all results together for the final output.'
1001
+ )
1002
+
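# --- Illustrative sketch (not part of the commit): the spirit of the "Max Spec",
# --- "Min Spec", and "Average" algorithms described above, using plain numpy on
# --- complex STFTs. The project's real ensembling code lives elsewhere; this is a toy.
import numpy as np

def ensemble_spectrograms(specs, algorithm='Max Spec'):
    """Combine same-shaped complex STFTs element-wise by magnitude."""
    stack = np.stack(specs)                # (n_models, freq, time)
    mags = np.abs(stack)
    if algorithm == 'Max Spec':
        idx = np.argmax(mags, axis=0)      # keep the loudest bin per (freq, time)
    elif algorithm == 'Min Spec':
        idx = np.argmin(mags, axis=0)      # keep the quietest bin per (freq, time)
    else:
        return stack.mean(axis=0)          # simple average fallback
    return np.take_along_axis(stack, idx[None, ...], axis=0)[0]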
1003
+ ENSEMBLE_LISTBOX_HELP = (
1004
+ 'Displays all available models for the chosen main stem pair.'
1005
+ )
1006
+
1007
+ if OPERATING_SYSTEM == 'darwin':
1008
+ IS_GPU_CONVERSION_HELP = (
1009
+ '• Use GPU for Processing (if available):\n'
1010
+ ' - If checked, the application will attempt to use your GPU for faster processing.\n'
1011
+ ' - If a GPU is not detected, it will default to CPU processing.\n'
1012
+ ' - GPU processing for MacOS only works with VR Arch models.\n\n'
1013
+ '• Please Note:\n'
1014
+ ' - CPU processing is significantly slower than GPU processing.\n'
1015
+ ' - Only Macs with M1 chips can be used for GPU processing.'
1016
+ )
1017
+ else:
1018
+ IS_GPU_CONVERSION_HELP = (
1019
+ '• Use GPU for Processing (if available):\n'
1020
+ ' - If checked, the application will attempt to use your GPU for faster processing.\n'
1021
+ ' - If a GPU is not detected, it will default to CPU processing.\n\n'
1022
+ '• Please Note:\n'
1023
+ ' - CPU processing is significantly slower than GPU processing.\n'
1024
+ ' - Only Nvidia GPUs can be used for GPU processing.'
1025
+ )
1026
+
1027
+ IS_TIME_CORRECTION_HELP = ('When checked, the output will retain the original BPM of the input.')
1028
+ SAVE_STEM_ONLY_HELP = 'Allows the user to save only the selected stem.'
1029
+ IS_NORMALIZATION_HELP = 'Normalizes output to prevent clipping.'
1030
+ IS_CUDA_SELECT_HELP = "If you have more than one GPU, you can pick which one to use for processing."
1031
+ CROP_SIZE_HELP = '**Only compatible with select models!**\n\nSetting should match the training crop-size value. Leave as is if unsure.'
1032
+ IS_TTA_HELP = ('This option performs Test-Time-Augmentation to improve the separation quality.\n\n'
1033
+ 'Note: Having this selected will increase the time it takes to complete a conversion.')
1034
+ IS_POST_PROCESS_HELP = ('This option can potentially identify leftover instrumental artifacts within the vocal outputs. \nThis option may improve the separation of some songs.\n\n' +\
1035
+ 'Note: Selecting this option can adversely affect the conversion process, depending on the track. Because of this, it is only recommended as a last resort.')
1036
+ IS_HIGH_END_PROCESS_HELP = 'The application will mirror the missing frequency range of the output.'
1037
+ SHIFTS_HELP = ('Performs multiple predictions with random shifts of the input and averages them.\n\n'
1038
+ '• The higher the number of shifts, the longer the prediction will take. \n- Not recommended unless you have a GPU.')
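# --- Illustrative sketch (not part of the commit): the shift-and-average idea
# --- described in SHIFTS_HELP above. `separate_fn` is a hypothetical callable
# --- standing in for one model pass; real implementations usually pad rather than roll.
import numpy as np

def predict_with_shifts(mix, separate_fn, shifts=2, max_shift=22050, seed=0):
    """Average several predictions made on randomly time-shifted copies of `mix`."""
    rng = np.random.default_rng(seed)
    out = None
    for _ in range(max(1, shifts)):
        offset = int(rng.integers(0, max_shift))
        shifted = np.roll(mix, offset, axis=-1)                   # shift the input
        pred = np.roll(separate_fn(shifted), -offset, axis=-1)    # undo the shift
        out = pred if out is None else out + pred
    return out / max(1, shifts)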
1039
+ OVERLAP_HELP = ('• This option controls the amount of overlap between prediction windows.\n'
1040
+ ' - Higher values can provide better results, but will lead to longer processing times.\n'
1041
+ ' - You can choose between 0.001-0.999')
1042
+ MDX_OVERLAP_HELP = ('• This option controls the amount of overlap between prediction windows.\n'
1043
+ ' - Higher values can provide better results, but will lead to longer processing times.\n'
1044
+ ' - For Non-MDX23C models: You can choose between 0.001-0.999')
1045
+ OVERLAP_23_HELP = ('• This option controls the amount of overlap between prediction windows.\n'
1046
+ ' - Higher values can provide better results, but will lead to longer processing times.')
1047
+ IS_SEGMENT_DEFAULT_HELP = '• The segment size is set based on the value provided in a chosen model\'s associated \nconfig file (yaml).'
1048
+ IS_SPLIT_MODE_HELP = '• Enables \"Segments\". \n• Deselecting this option is only recommended for those with powerful PCs.'
1049
+ IS_DEMUCS_COMBINE_STEMS_HELP = 'The application will create the secondary stem by combining the remaining stems \ninstead of inverting the primary stem with the mixture.'
1050
+ COMPENSATE_HELP = 'Compensates the audio of the primary stems to allow for a better secondary stem.'
1051
+ IS_DENOISE_HELP = ('• Standard: This setting reduces the noise created by MDX-Net models.\n'
1052
+ ' - This option only reduces noise in non-MDX23 models.\n'
1053
+ '• Denoise Model: This setting employs a special denoise model to eliminate noise produced by any MDX-Net model.\n'
1054
+ ' - This option works on all MDX-Net models.\n'
1055
+ ' - You must have the "UVR-DeNoise-Lite" VR Arch model installed to use this option.\n'
1056
+ '• Please Note: Both options will increase separation time.')
1057
+
1058
+ VOC_SPLIT_MODEL_SELECT_HELP = '• Select a model from the list of lead and backing vocal models to run through vocal stems automatically.'
1059
+ IS_VOC_SPLIT_INST_SAVE_SELECT_HELP = '• When activated, you will receive extra instrumental outputs that include: one with just the lead vocals and another with only the backing vocals.'
1060
+ IS_VOC_SPLIT_MODEL_SELECT_HELP = ('• When activated, this option auto-processes generated vocal stems, using either a karaoke model to remove lead vocals or another to remove backing vocals.\n'
1061
+ ' - This option splits the vocal track into two separate parts: lead vocals and backing vocals, providing two extra vocal outputs.\n'
1062
+ ' - The results will be organized in the same way, whether you use a karaoke model or a background vocal model.\n'
1063
+ ' - This option does not work in ensemble mode at this time.')
1064
+ IS_DEVERB_OPT_HELP = ('• Select the vocal type you wish to deverb automatically.\n'
1065
+ ' - Example: Choosing "Lead Vocals Only" will only remove reverb from a lead vocal stem.')
1066
+ IS_DEVERB_VOC_HELP = ('• This option removes reverb from a vocal stem.\n'
1067
+ ' - You must have the "UVR-DeEcho-DeReverb" VR Arch model installed to use this option.\n'
1068
+ ' - This option does not work in ensemble mode at this time.')
1069
+ IS_FREQUENCY_MATCH_HELP = 'Matches the frequency cut-off of the primary stem to that of the secondary stem.'
1070
+ CLEAR_CACHE_HELP = 'Clears settings for unrecognized models chosen by the user.'
1071
+ IS_SAVE_ALL_OUTPUTS_ENSEMBLE_HELP = 'If enabled, all individual ensemble-generated outputs are retained.'
1072
+ IS_APPEND_ENSEMBLE_NAME_HELP = 'When enabled, the ensemble name is added to the final output.'
1073
+ IS_WAV_ENSEMBLE_HELP = (
1074
+ 'Processes ensemble algorithms with waveforms instead of spectrograms when activated:\n'
1075
+ '• Might lead to increased distortion.\n'
1076
+ '• Waveform ensembling is faster than spectrogram ensembling.'
1077
+ )
1078
+ DONATE_HELP = 'Opens official UVR "Buy Me a Coffee" external link for project donations!'
1079
+ IS_INVERT_SPEC_HELP = (
1080
+ 'Potentially enhances the secondary stem quality:\n'
1081
+ '• Inverts primary stem using spectrograms, instead of waveforms.\n'
1082
+ '• Slightly slower inversion method.'
1083
+ )
1084
+ IS_TESTING_AUDIO_HELP = 'Appends a 10-digit number to saved files to avoid accidental overwrites.'
1085
+ IS_MODEL_TESTING_AUDIO_HELP = 'Appends the model name to outputs for comparison across different models.'
1086
+ IS_ACCEPT_ANY_INPUT_HELP = (
1087
+ 'Allows all types of inputs when enabled, even non-audio formats.\n'
1088
+ 'For experimental use only. Not recommended for regular use.'
1089
+ )
1090
+ IS_TASK_COMPLETE_HELP = 'Plays a chime upon process completion or failure when activated.'
1091
+ DELETE_YOUR_SETTINGS_HELP = (
1092
+ 'Contains your saved settings. Confirmation will be requested before deleting a selected setting.'
1093
+ )
1094
+ SET_STEM_NAME_HELP = 'Select the primary stem for the given model.'
1095
+ IS_CREATE_MODEL_FOLDER_HELP = ('Two new directories will be generated for the outputs in the export directory after each conversion.\n\n'
1096
+ '• Example: \n'
1097
+ '─ Export Directory\n'
1098
+ ' └── First Directory (Named after the model)\n'
1099
+ ' └── Second Directory (Named after the track)\n'
1100
+ ' └── Output File(s)')
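# --- Illustrative sketch (not part of the commit): building the nested export
# --- layout described in IS_CREATE_MODEL_FOLDER_HELP above (model folder, then
# --- track folder). The helper name and example values are hypothetical.
import os

def make_output_dir(export_path: str, model_name: str, track_name: str) -> str:
    out_dir = os.path.join(export_path, model_name, track_name)
    os.makedirs(out_dir, exist_ok=True)   # create both directory levels if needed
    return out_dir

# Example: make_output_dir("/exports", "Some Model Name", "my_song")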
1101
+ MDX_DIM_T_SET_HELP = INTERNAL_MODEL_ATT
1102
+ MDX_DIM_F_SET_HELP = INTERNAL_MODEL_ATT
1103
+
1104
+ MDX_N_FFT_SCALE_SET_HELP = 'Specify the N_FFT size used during model training.'
1105
+ POPUP_COMPENSATE_HELP = (
1106
+ f'Select the appropriate volume compensation for the chosen model.\n'
1107
+ f'Reminder: {COMPENSATE_HELP}'
1108
+ )
1109
+ VR_MODEL_PARAM_HELP = 'Select the required parameters to run the chosen model.'
1110
+ CHOSEN_ENSEMBLE_HELP = (
1111
+ 'Default Ensemble Selections:\n'
1112
+ '• Save the current ensemble configuration.\n'
1113
+ '• Clear all selected models.\n'
1114
+ 'Note: You can also select previously saved ensembles.'
1115
+ )
1116
+ CHOSEN_PROCESS_METHOD_HELP = (
1117
+ 'Choose a Processing Method:\n'
1118
+ 'Select from various AI networks and algorithms to process your track:\n'
1119
+ '\n'
1120
+ '• VR Architecture: Uses magnitude spectrograms for source separation.\n'
1121
+ '• MDX-Net: Employs a Hybrid Spectrogram network for source separation.\n'
1122
+ '• Demucs v3: Also utilizes a Hybrid Spectrogram network for source separation.\n'
1123
+ '• Ensemble Mode: Combine results from multiple models and networks for optimal results.\n'
1124
+ '• Audio Tools: Additional utilities for added convenience.'
1125
+ )
1126
+
1127
+ INPUT_FOLDER_ENTRY_HELP = (
1128
+ 'Select Input:\n'
1129
+ 'Choose the audio file(s) you want to process.'
1130
+ )
1131
+ INPUT_FOLDER_ENTRY_HELP_2 = (
1132
+ 'Input Option Menu:\n'
1133
+ 'Click to access the input option menu.'
1134
+ )
1135
+ OUTPUT_FOLDER_ENTRY_HELP = (
1136
+ 'Select Output:\n'
1137
+ 'Choose the directory where the processed files will be saved.'
1138
+ )
1139
+ INPUT_FOLDER_BUTTON_HELP = (
1140
+ 'Open Input Folder Button:\n'
1141
+ 'Open the directory containing the selected input audio file(s).'
1142
+ )
1143
+ OUTPUT_FOLDER_BUTTON_HELP = (
1144
+ 'Open Output Folder Button:\n'
1145
+ 'Open the selected output folder.'
1146
+ )
1147
+ CHOOSE_MODEL_HELP = (
1148
+ 'Each processing method has its own set of options and models.\n'
1149
+ 'Choose the model associated with the selected processing method here.'
1150
+ )
1151
+ FORMAT_SETTING_HELP = 'Save Outputs As: '
1152
+ SECONDARY_MODEL_ACTIVATE_HELP = (
1153
+ 'When enabled, the application will perform an additional inference using the selected model(s) above.'
1154
+ )
1155
+ SECONDARY_MODEL_HELP = (
1156
+ 'Choose the Secondary Model:\n'
1157
+ 'Select the secondary model associated with the stem you want to process with the current method.'
1158
+ )
1159
+
1160
+ INPUT_SEC_FIELDS_HELP = (
1161
+ 'Right click here to choose your inputs!'
1162
+ )
1163
+
1164
+ SECONDARY_MODEL_SCALE_HELP = ('The scale determines how the final audio outputs will be averaged between the primary and secondary models.\n\nFor example:\n\n'
1165
+ '• 10% - 10 percent of the main model result will be factored into the final result.\n'
1166
+ '• 50% - The results from the main and secondary models will be averaged evenly.\n'
1167
+ '• 90% - 90 percent of the main model result will be factored into the final result.')
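# --- Illustrative sketch (not part of the commit): the scale described above as a
# --- simple weighted average of two model outputs of equal length (numpy arrays).
import numpy as np

def blend_outputs(primary, secondary, primary_scale=0.9):
    """0.9 -> mostly the primary model, 0.5 -> an even average, 0.1 -> mostly secondary."""
    return primary * primary_scale + secondary * (1.0 - primary_scale)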
1168
+ PRE_PROC_MODEL_ACTIVATE_HELP = (
1169
+ 'When enabled, the application will use the selected model to isolate the instrumental stem.\n'
1170
+ 'Subsequently, all non-vocal stems will be extracted from this generated instrumental.\n'
1171
+ '\n'
1172
+ 'Key Points:\n'
1173
+ '• This feature can significantly reduce vocal bleed in non-vocal stems.\n'
1174
+ '• Available exclusively in the Demucs tool.\n'
1175
+ '• Compatible only with non-vocal and non-instrumental stem outputs.\n'
1176
+ '• Expect an increase in total processing time.\n'
1177
+ '• Only the VR or MDX-Net Vocal Instrumental/Vocals models can be chosen for this process.'
1178
+ )
1179
+
1180
+ AUDIO_TOOLS_HELP = (
1181
+ 'Select from various audio tools to process your track:\n'
1182
+ '\n'
1183
+ '• Manual Ensemble: Requires 2 or more selected files as inputs. This allows tracks to be processed using the algorithms from Ensemble Mode.\n'
1184
+ '• Time Stretch: Adjust the playback speed of the selected inputs to be faster or slower.\n'
1185
+ '• Change Pitch: Modify the pitch of the selected inputs.\n'
1186
+ '• Align Inputs: Choose 2 audio files and the application will align them and provide the difference in alignment.\n'
1187
+ ' - This tool provides similar functionality to "Utagoe."\n'
1188
+ ' - Primary Audio: This is usually a mixture.\n'
1189
+ ' - Secondary Audio: This is usually an instrumental.\n'
1190
+ '• Matchering: Choose 2 audio files. The matchering algorithm will master the target audio to have the same RMS, FR, peak amplitude, and stereo width as the reference audio.'
1191
+ )
1192
+
1193
+ PRE_PROC_MODEL_INST_MIX_HELP = 'When enabled, the application will generate a third output without the selected stem and vocals.'
1194
+ MODEL_SAMPLE_MODE_HELP = ('Allows the user to process only part of a track to sample settings or a model without running a full conversion.\n\nNotes:\n\n'
1195
+ '• The number in the parentheses is the current number of seconds the generated sample will be.\n'
1196
+ '• You can choose the number of seconds to extract from the track in the \"Additional Settings\" menu.')
1197
+
1198
+ POST_PROCESS_THREASHOLD_HELP = ('Allows the user to control the intensity of the Post-Process option.\n\nNotes:\n\n'
1199
+ '• Higher values potentially remove more artifacts. However, bleed might increase.\n'
1200
+ '• Lower values limit artifact removal.')
1201
+
1202
+ BATCH_SIZE_HELP = ('Specify the number of batches to be processed at a time.\n\nNotes:\n\n'
1203
+ '• Higher values mean more RAM usage but slightly faster processing times.\n'
1204
+ '• Lower values mean less RAM usage but slightly longer processing times.\n'
1205
+ '• Batch size value has no effect on output quality.')
1206
+
1207
+ VR_MODEL_NOUT_HELP = ""
1208
+ VR_MODEL_NOUT_LSTM_HELP = ""
1209
+
1210
+ IS_PHASE_HELP = 'Select the phase for the secondary audio.\n• Note: Using the "Automatic" option is strongly recommended.'
1211
+ IS_ALIGN_TRACK_HELP = 'Enable this to save the secondary track once aligned.'
1212
+ IS_MATCH_SILENCE_HELP = (
1213
+ 'Aligns the initial silence of the secondary audio with the primary audio.\n'
1214
+ '• Note: Avoid using this option if the primary audio begins solely with vocals.'
1215
+ )
1216
+ IS_MATCH_SPEC_HELP = 'Align the secondary audio based on the primary audio\'s spectrogram.\n• Note: This may enhance alignment in specific cases.'
1217
+
1218
+ TIME_WINDOW_ALIGN_HELP = (
1219
+ 'This setting determines the window size for alignment analysis, especially for pairs with minor timing variations:\n'
1220
+ '\n'
1221
+ '• None: Disables time window analysis.\n'
1222
+ '• 1: Analyzes pair by 0.0625-second windows.\n'
1223
+ '• 2: Analyzes pair by 0.125-second windows.\n'
1224
+ '• 3: Analyzes pair by 0.25-second windows.\n'
1225
+ '• 4: Analyzes pair by 0.50-second windows.\n'
1226
+ '• 5: Analyzes pair by 0.75-second windows.\n'
1227
+ '• 6: Analyzes pair by 1-second windows.\n'
1228
+ '• 7: Analyzes pair by 2-second windows.\n'
1229
+ '\n'
1230
+ 'Shifts Options:\n'
1231
+ '• Low: Cycles through 0.0625 and 0.5-second windows to find an optimal match.\n'
1232
+ '• Medium: Cycles through 0.0625, 0.125, and 0.5-second windows to find an optimal match.\n'
1233
+ '• High: Cycles through 0.0625, 0.125, 0.25, and 0.5-second windows to find an optimal match.\n'
1234
+ '\n'
1235
+ 'Important Points to Consider:\n'
1236
+ ' - Using the "Shifts" option may require more processing time and might not guarantee better results.\n'
1237
+ ' - Opting for smaller analysis windows can increase processing times.\n'
1238
+ ' - The best settings are likely to vary based on the specific tracks being processed.'
1239
+ )
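# --- Illustrative sketch (not part of the commit): estimating the offset between
# --- two tracks over a short analysis window, in the spirit of the time-window
# --- options above. Uses scipy for the example; the real align tool is more involved.
import numpy as np
from scipy.signal import correlate

def estimate_lag(primary, secondary, sr=44100, window_seconds=0.25):
    """Return the lag (in samples) that best lines `secondary` up with `primary`."""
    n = int(sr * window_seconds)
    a, b = primary[:n], secondary[:n]
    corr = correlate(a, b, mode='full')
    return int(np.argmax(corr)) - (len(b) - 1)   # positive -> secondary is early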
1240
+ INTRO_ANALYSIS_ALIGN_HELP = (
1241
+ 'This setting determines the portion of the audio input to be analyzed for initial alignment.\n'
1242
+ '\n'
1243
+ '• Default: Analyzes 10% (or 1/10th) of the audio\'s total length.\n'
1244
+ '• 1: Analyzes 12.5% (or 1/8th) of the audio\'s total length.\n'
1245
+ '• 2: Analyzes 16.67% (or 1/6th) of the audio\'s total length.\n'
1246
+ '• 3: Analyzes 25% (or 1/4th) of the audio\'s total length.\n'
1247
+ '• 4: Analyzes 50% (or half) of the audio\'s total length.\n'
1248
+ '\n'
1249
+ 'Shifts Options:\n'
1250
+ '• Low: Cycles through 2 intro analysis values.\n'
1251
+ '• Medium: Cycles through 3 intro analysis values.\n'
1252
+ '• High: Cycles through 5 intro analysis values.\n'
1253
+ '\n'
1254
+ 'Important Points to Consider:\n'
1255
+ ' - Using the "Shifts" option will require more processing time and might not guarantee better results.\n'
1256
+ ' - Optimal settings may vary depending on the specific tracks being processed.'
1257
+ )
1258
+
1259
+ VOLUME_ANALYSIS_ALIGN_HELP = (
1260
+ 'This setting specifies the volume adjustments to be made on the secondary input:\n'
1261
+ '\n'
1262
+ '• None: No volume adjustments are made.\n'
1263
+ '• Low: Analyzes the audio within a 4dB range, adjusting in 1dB increments.\n'
1264
+ '• Medium: Analyzes the audio within a 6dB range, adjusting in 1dB increments.\n'
1265
+ '• High: Analyzes the audio within a 6dB range, adjusting in 0.5dB increments.\n'
1266
+ '• Very High: Analyzes the audio within a 10dB range, adjusting in 0.5dB increments.\n'
1267
+ '\n'
1268
+ 'Important Points to Consider:\n'
1269
+ ' - Selecting more extensive analysis options (e.g., High, Very High) will lead to longer processing times.\n'
1270
+ ' - Optimal settings might vary based on the specific tracks being processed.'
1271
+ )
1272
+
1273
+ PHASE_SHIFTS_ALIGN_HELP = (
1274
+ 'This setting specifies the phase adjustments to be made on the secondary input:\n'
1275
+ '\n'
1276
+ 'Shifts Options:\n'
1277
+ '• None: No phase adjustments are made.\n'
1278
+ '• Very Low: Analyzes the audio within range of 2 different phase positions.\n'
1279
+ '• Low: Analyzes the audio within range of 4 different phase positions.\n'
1280
+ '• Medium: Analyzes the audio within range of 8 different phase positions.\n'
1281
+ '• High: Analyzes the audio within range of 18 different phase positions.\n'
1282
+ '• Very High: Analyzes the audio within range of 36 different phase positions.\n'
1283
+ '• Maximum: Analyzes the audio in all 360 phase positions.\n'
1284
+ '\n'
1285
+ 'Important Points to Consider:\n'
1286
+ ' - This option only works with time correction.\n'
1287
+ ' - This option can be helpful if one of the inputs were from an analog source.\n'
1288
+ ' - Selecting more extensive analysis options (e.g., High, Very High) will lead to longer processing times.\n'
1289
+ ' - Selecting "Maximum" can take hours to process.\n'
1290
+ ' - Optimal settings might vary based on the specific tracks being processed.'
1291
+ )
1292
+
1293
+ # Warning Messages
1294
+ STORAGE_ERROR = 'Insufficient Storage', 'There is not enough storage on the main drive to continue. Your main drive must have at least 3 GB of storage in order for this application to function properly. \n\nPlease ensure your main drive has at least 3 GB of storage and try again.\n\n'
1295
+ STORAGE_WARNING = 'Available Storage Low', 'Your main drive is running low on storage. Your main drive must have at least 3 GB of storage in order for this application to function properly.\n\n'
1296
+ CONFIRM_WARNING = '\nAre you sure you wish to continue?'
1297
+ PROCESS_FAILED = 'Process failed, please see error log\n'
1298
+ EXIT_PROCESS_ERROR = 'Active Process', 'Please stop the active process or wait for it to complete before you exit.'
1299
+ EXIT_HALTED_PROCESS_ERROR = 'Halting Process', 'Please wait for the application to finish halting the process before exiting.'
1300
+ EXIT_DOWNLOAD_ERROR = 'Active Download', 'Please stop the download or wait for it to complete before you exit.'
1301
+ SET_TO_DEFAULT_PROCESS_ERROR = 'Active Process', 'You cannot reset all of the application settings during an active process.'
1302
+ SET_TO_ANY_PROCESS_ERROR = 'Active Process', 'You cannot reset the application settings during an active process.'
1303
+ RESET_ALL_TO_DEFAULT_WARNING = 'Reset Settings Confirmation', 'All application settings will be set to factory default.\n\nAre you sure you wish to continue?'
1304
+ AUDIO_VERIFICATION_CHECK = lambda i, e:f'++++++++++++++++++++++++++++++++++++++++++++++++++++\n\nBroken File Removed: \n\n{i}\n\nError Details:\n\n{e}\n++++++++++++++++++++++++++++++++++++++++++++++++++++'
1305
+ INVALID_ONNX_MODEL_ERROR = 'Invalid Model', 'The file selected is not a valid MDX-Net model. Please see the error log for more information.'
1306
+ INVALID_PARAM_MODEL_ERROR = 'Select Model Param', 'Please choose a model param or click \'Cancel\'.'
1307
+ UNRECOGNIZED_MODEL = 'Unrecognized Model Detected', ' is an unrecognized model.\n\n' + \
1308
+ 'Would you like to select the correct parameters before continuing?'
1309
+ STOP_PROCESS_CONFIRM = 'Confirmation', 'You are about to stop all active processes.\n\nAre you sure you wish to continue?'
1310
+ NO_ENSEMBLE_SELECTED = 'No Models Selected', 'Please select an ensemble and try again.'
1311
+ PICKLE_CORRU = 'File Corrupted', 'Unable to load this ensemble.\n\n' + \
1312
+ 'Would you like to remove this ensemble from your list?'
1313
+ DELETE_ENS_ENTRY = 'Confirm Removal', 'Are you sure you want to remove this entry?'
1314
+
1315
+ # Separation Text
1316
+ LOADING_MODEL = 'Loading model...'
1317
+ INFERENCE_STEP_1 = 'Running inference...'
1318
+ INFERENCE_STEP_1_SEC = 'Running inference (secondary model)...'
1319
+ INFERENCE_STEP_1_4_STEM = lambda stem:f'Running inference (secondary model for {stem})...'
1320
+ INFERENCE_STEP_1_PRE = 'Running inference (pre-process model)...'
1321
+ INFERENCE_STEP_1_VOC_S = 'Splitting vocals...'
1322
+ INFERENCE_STEP_2_PRE = lambda pm, m:f'Loading pre-process model ({pm}: {m})...'
1323
+ INFERENCE_STEP_2_SEC = lambda pm, m:f'Loading secondary model ({pm}: {m})...'
1324
+ INFERENCE_STEP_2_VOC_S = lambda pm, m:f'Loading vocal splitter model ({pm}: {m})...'
1325
+ INFERENCE_STEP_2_SEC_CACHED_MODOEL = lambda pm, m:f'Secondary model ({pm}: {m}) cache loaded.\n'
1326
+ INFERENCE_STEP_2_PRE_CACHED_MODOEL = lambda pm, m:f'Pre-process model ({pm}: {m}) cache loaded.\n'
1327
+ INFERENCE_STEP_2_SEC_CACHED = 'Loading cached secondary model source(s)... Done!\n'
1328
+ INFERENCE_STEP_2_PRIMARY_CACHED = ' Model cache loaded.\n'
1329
+ INFERENCE_STEP_2 = 'Inference complete.'
1330
+ INFERENCE_STEP_DEVERBING = ' Deverbing...'
1331
+ SAVING_STEM = 'Saving ', ' stem...'
1332
+ SAVING_ALL_STEMS = 'Saving all stems...'
1333
+ ENSEMBLING_OUTPUTS = 'Ensembling outputs...'
1334
+ DONE = ' Done!\n'
1335
+ ENSEMBLES_SAVED = 'Ensembled outputs saved!\n\n'
1336
+
1337
+ #Additional Text
1338
+ CHOOSE_PROC_METHOD_MAIN_LABEL = 'CHOOSE PROCESS METHOD'
1339
+ SELECT_SAVED_SETTINGS_MAIN_LABEL = 'SELECT SAVED SETTINGS'
1340
+ CHOOSE_MDX_MODEL_MAIN_LABEL = 'CHOOSE MDX-NET MODEL'
1341
+ BATCHES_MDX_MAIN_LABEL = 'BATCH SIZE'
1342
+ VOL_COMP_MDX_MAIN_LABEL = 'VOLUME COMPENSATION'
1343
+ SEGMENT_MDX_MAIN_LABEL = 'SEGMENT SIZE'
1344
+ SELECT_VR_MODEL_MAIN_LABEL = 'CHOOSE VR MODEL'
1345
+ AGGRESSION_SETTING_MAIN_LABEL = 'AGGRESSION SETTING'
1346
+ WINDOW_SIZE_MAIN_LABEL = 'WINDOW SIZE'
1347
+ CHOOSE_DEMUCS_MODEL_MAIN_LABEL = 'CHOOSE DEMUCS MODEL'
1348
+ CHOOSE_STEMS_MAIN_LABEL = 'CHOOSE STEM(S)'
1349
+ CHOOSE_SEGMENT_MAIN_LABEL = 'SEGMENT'
1350
+ ENSEMBLE_OPTIONS_MAIN_LABEL = 'ENSEMBLE OPTIONS'
1351
+ CHOOSE_MAIN_PAIR_MAIN_LABEL = 'MAIN STEM PAIR'
1352
+ CHOOSE_ENSEMBLE_ALGORITHM_MAIN_LABEL = 'ENSEMBLE ALGORITHM'
1353
+ AVAILABLE_MODELS_MAIN_LABEL = 'AVAILABLE MODELS'
1354
+ CHOOSE_AUDIO_TOOLS_MAIN_LABEL = 'CHOOSE AUDIO TOOL'
1355
+ CHOOSE_MANUAL_ALGORITHM_MAIN_LABEL = 'CHOOSE ALGORITHM'
1356
+ CHOOSE_RATE_MAIN_LABEL = 'RATE'
1357
+ CHOOSE_SEMITONES_MAIN_LABEL = 'SEMITONES'
1358
+ GPU_CONVERSION_MAIN_LABEL = 'GPU Conversion'
1359
+ CHANGE_LOG_HEADER = lambda patch:f"Patch Version:\n\n{patch}"
1360
+ INVALID_INPUT_E = ' Invalid input! '
1361
+ LB_UP = "Move Selection Up"
1362
+ LB_DOWN = "Move Selection Down"
1363
+ LB_CLEAR = "Clear Box"
1364
+ LB_MOVE_OVER_P = "Move Selection to Secondary List"
1365
+ LB_MOVE_OVER_S = "Move Selection to Primary List"
1366
+ FILE_ONE_MAIN_LABEL = "PRIMARY AUDIO"
1367
+ FILE_TWO_MAIN_LABEL = "SECONDARY AUDIO"
1368
+ FILE_ONE_MATCH_MAIN_LABEL = "TARGET AUDIO"
1369
+ FILE_TWO_MATCH_MAIN_LABEL = "REFERENCE AUDIO"
1370
+ TIME_WINDOW_MAIN_LABEL = "TIME ADJUSTMENT"
1371
+ INTRO_ANALYSIS_MAIN_LABEL = "INTRO ANALYSIS"
1372
+ VOLUME_ADJUSTMENT_MAIN_LABEL = "VOLUME ADJUSTMENT"
1373
+ SELECT_INPUTS = "Select Input(s)"
1374
+ SELECTED_INPUTS = 'Selected Inputs'
1375
+ WIDEN_BOX = 'Widen Box'
1376
+ CONFIRM_ENTRIES = 'Confirm Entries'
1377
+ CLOSE_WINDOW = 'Close Window'
1378
+ DUAL_AUDIO_PROCESSING = 'Dual Audio Batch Processing'
1379
+ CANCEL_TEXT = "Cancel"
1380
+ CONFIRM_TEXT = "Confirm"
1381
+ SELECT_MODEL_TEXT = 'Select Model'
1382
+ NONE_SELECTED = 'None Selected'
1383
+ SAVE_TEXT = 'Save'
1384
+ OVERLAP_TEXT = 'Overlap'
1385
+ ACCEPT_ANY_INPUT_TEXT = 'Accept Any Input'
1386
+ ACTIVATE_PRE_PROCESS_MODEL_TEXT = 'Activate Pre-process Model'
1387
+ ACTIVATE_SECONDARY_MODEL_TEXT = 'Activate Secondary Model'
1388
+ ADDITIONAL_MENUS_INFORMATION_TEXT = 'Additional Menus & Information'
1389
+ ADDITIONAL_SETTINGS_TEXT = 'Additional Settings'
1390
+ ADVANCED_ALIGN_TOOL_OPTIONS_TEXT = 'Advanced Align Tool Options'
1391
+ ADVANCED_DEMUCS_OPTIONS_TEXT = 'Advanced Demucs Options'
1392
+ ADVANCED_ENSEMBLE_OPTIONS_TEXT = 'Advanced Ensemble Options'
1393
+ ADVANCED_MDXNET23_OPTIONS_TEXT = 'Advanced MDX-NET23 Options'
1394
+ ADVANCED_MDXNET_OPTIONS_TEXT = 'Advanced MDX-Net Options'
1395
+ ADVANCED_OPTION_MENU_TEXT = 'Advanced Option Menu'
1396
+ ADVANCED_VR_OPTIONS_TEXT = 'Advanced VR Options'
1397
+ AGGRESSION_SETTING_TEXT = 'Aggression Setting'
1398
+ APPEND_ENSEMBLE_NAME_TEXT = 'Append Ensemble Name'
1399
+ APPLICATION_DOWNLOAD_CENTER_TEXT = 'Application Download Center'
1400
+ APPLICATION_UPDATES_TEXT = 'Application Updates'
1401
+ AUDIO_FORMAT_SETTINGS_TEXT = 'Audio Format Settings'
1402
+ BALANCE_VALUE_TEXT = 'Balance Value'
1403
+ BATCH_SIZE_TEXT = 'Batch Size'
1404
+ BV_MODEL_TEXT = 'BV Model'
1405
+ CHANGE_MODEL_DEFAULT_TEXT = 'Change Model Default'
1406
+ CHANGE_MODEL_DEFAULTS_TEXT = 'Change Model Defaults'
1407
+ CHANGE_PARAMETERS_TEXT = 'Change Parameters'
1408
+ CHOOSE_ADVANCED_MENU_TEXT = 'Choose Advanced Menu'
1409
+ CHOOSE_MODEL_PARAM_TEXT = 'Choose Model Param'
1410
+ CLEAR_AUTOSET_CACHE_TEXT = 'Clear Auto-Set Cache'
1411
+ COMBINE_STEMS_TEXT = 'Combine Stems'
1412
+ CONFIRM_UPDATE_TEXT = 'Confirm Update'
1413
+ COPIED_TEXT = 'Copied!'
1414
+ COPY_ALL_TEXT_TEXT = 'Copy All Text'
1415
+ DEFINED_PARAMETERS_DELETED_TEXT = 'Defined Parameters Deleted'
1416
+ DELETE_PARAMETERS_TEXT = 'Delete Parameters'
1417
+ DELETE_USER_SAVED_SETTING_TEXT = 'Delete User Saved Setting'
1418
+ DEMUCS_TEXT = 'Demucs'
1419
+ DENOISE_OUTPUT_TEXT = 'Denoise Output'
1420
+ DEVERB_VOCALS_TEXT = 'Deverb Vocals'
1421
+ DONE_TEXT = 'Done'
1422
+ DOWNLOAD_CENTER_TEXT = 'Download Center'
1423
+ DOWNLOAD_CODE_TEXT = 'Download Code'
1424
+ DOWNLOAD_LINKS_TEXT = 'Download Link(s)'
1425
+ DOWNLOAD_UPDATE_IN_APPLICATION_TEXT = 'Download Update in Application'
1426
+ ENABLE_HELP_HINTS_TEXT = 'Enable Help Hints'
1427
+ ENABLE_TTA_TEXT = 'Enable TTA'
1428
+ ENABLE_VOCAL_SPLIT_MODE_TEXT = 'Enable Vocal Split Mode'
1429
+ ENSEMBLE_NAME_TEXT = 'Ensemble Name'
1430
+ ENSEMBLE_WAVFORMS_TEXT = 'Ensemble Waveforms'
1431
+ ERROR_CONSOLE_TEXT = 'Error Console'
1432
+ GENERAL_MENU_TEXT = 'General Menu'
1433
+ GENERAL_PROCESS_SETTINGS_TEXT = 'General Process Settings'
1434
+ GENERATE_MODEL_FOLDER_TEXT = 'Generate Model Folder'
1435
+ HIGHEND_PROCESS_TEXT = 'High-End Process'
1436
+ INPUT_CODE_TEXT = 'Input Code'
1437
+ INPUT_STEM_NAME_TEXT = 'Input Stem Name'
1438
+ INPUT_UNIQUE_STEM_NAME_TEXT = 'Input Unique Stem Name'
1439
+ IS_INVERSE_STEM_TEXT = 'Is Inverse Stem'
1440
+ KARAOKE_MODEL_TEXT = 'Karaoke Model'
1441
+ MANUAL_DOWNLOADS_TEXT = 'Manual Downloads'
1442
+ MATCH_FREQ_CUTOFF_TEXT = 'Match Freq Cut-off'
1443
+ MDXNET_C_MODEL_PARAMETERS_TEXT = 'MDX-Net C Model Parameters'
1444
+ MDXNET_MODEL_SETTINGS_TEXT = 'MDX-Net Model Settings'
1445
+ MDXNET_TEXT = 'MDX-Net'
1446
+ MODEL_PARAMETERS_CHANGED_TEXT = 'Model Parameters Changed'
1447
+ MODEL_SAMPLE_MODE_SETTINGS_TEXT = 'Model Sample Mode Settings'
1448
+ MODEL_TEST_MODE_TEXT = 'Model Test Mode'
1449
+ MP3_BITRATE_TEXT = 'Mp3 Bitrate'
1450
+ NAME_SETTINGS_TEXT = 'Name Settings'
1451
+ NO_DEFINED_PARAMETERS_FOUND_TEXT = 'No Defined Parameters Found'
1452
+ NO_TEXT = 'No'
1453
+ NORMALIZE_OUTPUT_TEXT = 'Normalize Output'
1454
+ USE_OPENCL_TEXT = 'Use OpenCL'
1455
+ NOT_ENOUGH_MODELS_TEXT = 'Not Enough Models'
1456
+ NOTIFICATION_CHIMES_TEXT = 'Notification Chimes'
1457
+ OPEN_APPLICATION_DIRECTORY_TEXT = 'Open Application Directory'
1458
+ OPEN_LINK_TO_MODEL_TEXT = 'Open Link to Model'
1459
+ OPEN_MODEL_DIRECTORY_TEXT = 'Open Model Directory'
1460
+ OPEN_MODEL_FOLDER_TEXT = 'Open Model Folder'
1461
+ OPEN_MODELS_FOLDER_TEXT = 'Open Models Folder'
1462
+ PHASE_SHIFTS_TEXT = 'Phase Shifts'
1463
+ POST_PROCESS_TEXT = 'Post-Process'
1464
+ POST_PROCESS_THRESHOLD_TEXT = 'Post-process Threshold'
1465
+ PREPROCESS_MODEL_CHOOSE_TEXT = 'Pre-process Model'
1466
+ PRIMARY_STEM_TEXT = 'Primary Stem'
1467
+ REFRESH_LIST_TEXT = 'Refresh List'
1468
+ REMOVE_SAVED_ENSEMBLE_TEXT = 'Remove Saved Ensemble'
1469
+ REPORT_ISSUE_TEXT = 'Report Issue'
1470
+ RESET_ALL_SETTINGS_TO_DEFAULT_TEXT = 'Reset All Settings to Default'
1471
+ RESTART_APPLICATION_TEXT = 'Restart Application'
1472
+ SAMPLE_CLIP_DURATION_TEXT = 'Sample Clip Duration'
1473
+ SAVE_ALIGNED_TRACK_TEXT = 'Save Aligned Track'
1474
+ SAVE_ALL_OUTPUTS_TEXT = 'Save All Outputs'
1475
+ SAVE_CURRENT_ENSEMBLE_TEXT = 'Save Current Ensemble'
1476
+ SAVE_CURRENT_SETTINGS_TEXT = 'Save Current Settings'
1477
+ SAVE_INSTRUMENTAL_MIXTURE_TEXT = 'Save Instrumental Mixture'
1478
+ SAVE_SPLIT_VOCAL_INSTRUMENTALS_TEXT = 'Save Split Vocal Instrumentals'
1479
+ SECONDARY_MODEL_TEXT = 'Secondary Model'
1480
+ SECONDARY_PHASE_TEXT = 'Secondary Phase'
1481
+ SECONDS_TEXT = 'Seconds'
1482
+ SEGMENT_DEFAULT_TEXT = 'Segment Default'
1483
+ SEGMENT_SIZE_TEXT = 'Segment Size'
1484
+ SEGMENTS_TEXT = 'Segments'
1485
+ SELECT_DOWNLOAD_TEXT = 'Select Download'
1486
+ SELECT_MODEL_PARAM_TEXT = 'Select Model Param'
1487
+ SELECT_VOCAL_TYPE_TO_DEVERB_TEXT = 'Select Vocal Type to Deverb'
1488
+ SELECTED_MODEL_PLACEMENT_PATH_TEXT = 'Selected Model Placement Path'
1489
+ SETTINGS_GUIDE_TEXT = 'Settings Guide'
1490
+ SETTINGS_TEST_MODE_TEXT = 'Settings Test Mode'
1491
+ SHIFT_CONVERSION_PITCH_TEXT = 'Shift Conversion Pitch'
1492
+ SHIFTS_TEXT = 'Shifts'
1493
+ SILENCE_MATCHING_TEXT = 'Silence Matching'
1494
+ SPECIFY_MDX_NET_MODEL_PARAMETERS_TEXT = 'Specify MDX-Net Model Parameters'
1495
+ SPECIFY_PARAMETERS_TEXT = 'Specify Parameters'
1496
+ SPECIFY_VR_MODEL_PARAMETERS_TEXT = 'Specify VR Model Parameters'
1497
+ SPECTRAL_INVERSION_TEXT = 'Spectral Inversion'
1498
+ SPECTRAL_MATCHING_TEXT = 'Spectral Matching'
1499
+ SPLIT_MODE_TEXT = 'Split Mode'
1500
+ STEM_NAME_TEXT = 'Stem Name'
1501
+ STOP_DOWNLOAD_TEXT = 'Stop Download'
1502
+ SUPPORT_UVR_TEXT = 'Support UVR'
1503
+ TRY_MANUAL_DOWNLOAD_TEXT = 'Try Manual Download'
1504
+ UPDATE_FOUND_TEXT = 'Update Found'
1505
+ USER_DOWNLOAD_CODES_TEXT = 'User Download Codes'
1506
+ UVR_BUY_ME_A_COFFEE_LINK_TEXT = 'UVR \'Buy Me a Coffee\' Link'
1507
+ UVR_ERROR_LOG_TEXT = 'UVR Error Log'
1508
+ UVR_PATREON_LINK_TEXT = 'UVR Patreon Link'
1509
+ VOCAL_DEVERB_OPTIONS_TEXT = 'Vocal Deverb Options'
1510
+ VOCAL_SPLIT_MODE_OPTIONS_TEXT = 'Vocal Split Mode Options'
1511
+ VOCAL_SPLIT_OPTIONS_TEXT = 'Vocal Split Options'
1512
+ VOLUME_COMPENSATION_TEXT = 'Volume Compensation'
1513
+ VR_51_MODEL_TEXT = 'VR 5.1 Model'
1514
+ VR_ARCH_TEXT = 'VR Arch'
1515
+ WAV_TYPE_TEXT = 'Wav Type'
1516
+ CUDA_NUM_TEXT = 'GPU Device'
1517
+ WINDOW_SIZE_TEXT = 'Window Size'
1518
+ YES_TEXT = 'Yes'
1519
+ VERIFY_INPUTS_TEXT = 'Verify Inputs'
1520
+ AUDIO_INPUT_TOTAL_TEXT = 'Audio Input Total'
1521
+ MDX23C_ONLY_OPTIONS_TEXT = 'MDXNET23 Only Options'
1522
+ PROCESS_STARTING_TEXT = 'Process starting... '
1523
+ MISSING_MESS_TEXT = 'is missing or corrupted.'
1524
+ SIMILAR_TEXT = "are the same."
1525
+ LOADING_VERSION_INFO_TEXT = 'Loading version information...'
1526
+ CHECK_FOR_UPDATES_TEXT = 'Check for Updates'
1527
+ INFO_UNAVAILABLE_TEXT = "Information unavailable."
1528
+ UPDATE_CONFIRMATION_TEXT = 'Are you sure you want to continue?\n\nThe application will need to be restarted.\n'
1529
+ BROKEN_OR_INCOM_TEXT = 'Broken or Incompatible File(s) Removed. Check Error Log for details.'
1530
+ BMAC_UVR_TEXT = 'UVR \"Buy Me a Coffee\" Link'
1531
+ MDX_MENU_WAR_TEXT = '(Leave this setting as is if you are unsure.)'
1532
+ NO_FILES_TEXT = 'No Files'
1533
+ CHOOSE_INPUT_TEXT = 'Choose Input'
1534
+ OPEN_INPUT_DIR_TEXT = 'Open Input Directory'
1535
+ BATCH_PROCESS_MENU_TEXT = 'Batch Process Menu'
1536
+ TEMP_FILE_DELETION_TEXT = 'Temp File Deletion'
1537
+ VOCAL_SPLITTER_OPTIONS_TEXT = 'Vocal Splitter Options'
1538
+ WAVEFORM_ENSEMBLE_TEXT = 'Waveform Ensemble'
1539
+ SELECT_INPUT_TEXT = 'Select Input'
1540
+ SELECT_OUTPUT_TEXT = 'Select Output'
1541
+ TIME_CORRECTION_TEXT = 'Time Correction'
1542
+ UVR_LIS_INFO_TEXT = 'UVR License Information'
1543
+ ADDITIONAL_RES_CREDITS_TEXT = 'Additional Resources & Credits'
1544
+ SAVE_INST_MIXTURE_TEXT = 'Save Instrumental Mixture'
1545
+ DOWNLOAD_UPDATE_IN_APP_TEXT = 'Download Update in Application'
1546
+ WAVE_TYPE_TEXT = 'WAVE TYPE'
1547
+ OPEN_LINK_TO_MODEL_TEXT = "Open Link to Model"
1548
+ OPEN_MODEL_DIRECTORY = "Open Model Directory"
1549
+ SELECTED_MODEL_PLACE_PATH_TEXT = 'Selected Model Placement Path'
1550
+ IS_INVERSE_STEM_TEXT = "Is Inverse Stem"
1551
+ INPUT_STEM_NAME_TEXT = "Input Stem Name"
1552
+ INPUT_UNIQUE_STEM_NAME_TEXT = "Input Unique Stem Name"
1553
+ DONE_MENU_TEXT = "Done"
1554
+ OK_TEXT = "Ok"
1555
+ ENSEMBLE_WARNING_NOT_ENOUGH_SHORT_TEXT = "Not Enough Models"
1556
+ ENSEMBLE_WARNING_NOT_ENOUGH_TEXT = "You must select 2 or more models to save an ensemble."
1557
+ NOT_ENOUGH_ERROR_TEXT = "Not enough files to process.\n"
1558
+ INVALID_FOLDER_ERROR_TEXT = 'Invalid Folder', 'Your given export path is not a valid folder!'
1559
+
1560
+ GET_DL_VIP_CODE_TEXT = ("Obtain codes by visiting one of the following links below."
1561
+ "\nFrom there you can donate, pledge, "
1562
+ "or just obatain the code!\n (Donations are not required to obtain VIP code)")
1563
+ CONFIRM_RESTART_TEXT = 'Restart Confirmation', 'This will restart the application and halt any running processes. Your current settings will be saved. \n\n Are you sure you wish to continue?'
1564
+ ERROR_LOADING_FILE_TEXT = 'Error Loading the Following File', 'Raw Error Details'
1565
+ LOADING_MODEL_TEXT = 'Loading model'
1566
+ FULL_APP_SET_TEXT = 'Full Application Settings'
1567
+ PROCESS_STARTING_TEXT = 'Process starting... '
1568
+ PROCESS_STOPPED_BY_USER = '\n\nProcess stopped by user.'
1569
+ NEW_UPDATE_FOUND_TEXT = lambda version:f"\n\nNew Update Found: {version}\n\nClick the update button in the \"Settings\" menu to download and install!"
1570
+ ROLL_BACK_TEXT = 'Click Here to Roll Back'
1571
+
1572
+ def secondary_stem(stem:str):
1573
+ """Determines secondary stem"""
1574
+
1575
+ stem = stem if stem else NO_STEM
1576
+
1577
+ if stem in STEM_PAIR_MAPPER.keys():
1578
+ for key, value in STEM_PAIR_MAPPER.items():
1579
+ if stem in key:
1580
+ secondary_stem = value
1581
+ else:
1582
+ secondary_stem = stem.replace(NO_STEM, "") if NO_STEM in stem else f"{NO_STEM}{stem}"
1583
+
1584
+ return secondary_stem
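# --- Illustrative usage (not part of the commit), assuming the usual mappings
# --- where STEM_PAIR_MAPPER pairs 'Vocals' with 'Instrumental' and NO_STEM is
# --- the "No " prefix:
#
#     secondary_stem('Vocals')   # -> 'Instrumental' (looked up in STEM_PAIR_MAPPER)
#     secondary_stem('Guitar')   # -> 'No Guitar'    (fallback "No <stem>" naming)
#     secondary_stem('No Bass')  # -> 'Bass'         (the "No " prefix is stripped)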
gui_data/cr_text.txt ADDED
@@ -0,0 +1,104 @@
1
+ Most Recent Changes:
2
+
3
+ Patch UVR_Patch_10_6_23_4_27:
4
+ ~ MPS for MacOS is now compatible with all MDX-Net & VR Arch models!
5
+ ~ Fixed memory issue with MDX23C models.
6
+ ~ Added the ability to choose the GPU device used to process tracks.
7
+ ~ Fixed a few graphical bugs.
8
+
9
+ Other Changes:
10
+ ~ Added "Vocal Split Mode," a chain ensemble that utilizes karaoke & BVE
11
+ (backing vocal extractor) models to split vocals into lead vocal and backing vocal
12
+ stems.
13
+ ~ Updated "Audio Tools" to include "Matchering".
14
+ ~ The "Align Tool" has been improved to line up inputs even when there are differences
15
+ in timing, similar to what Utagoe does.
16
+ ~ Integrated a right-click menu for entries in the batch process window specifically
17
+ for Matchering & Align tool.
18
+ ~ Introduced a right-click option for Matchering & Align tool on the main window to
19
+ directly access input folders.
20
+ ~ Addressed the anomaly where models would self-select in dropdowns.
21
+ ~ This seemed related to having a multitude of models.
22
+ ~ Revamped the dropdown menus entirely.
23
+ ~ Fixed splash screen closing issue.
24
+ ~ Demucs Pre-process Model settings now properly reset to default.
25
+ ~ The MP3 encoder now uses LAME.
26
+ ~ Resolved the problem where using the "MDX23C_D1581" model as a pre-process model for
27
+ Demucs conversions was leading to a "Key Error: 'Primary Stem'" error.
28
+ ~ Rectified the "missing file" error that was popping up when trying to convert using
29
+ the Demucs v2 Demucs model.
30
+ ~ Solved the problem in Demucs using the htdemucs_6s multi-stem model: when certain
31
+ stems were selected and pitch-shift conversion was used, a KeyError prevented
32
+ the stem from being saved. The stem separation now works correctly, and files are
33
+ saved to the intended destination.
34
+ ~ Adjusted the "Reset All Settings To Default" function to ensure it also resets the
35
+ text of the Sample Mode slider.
36
+ ~ Made corrections so that if you activate Model Test Mode and then restart the app,
37
+ the "Accept Any Input" option will not be inadvertently enabled.
38
+ ~ Addressed the ensemble issue with MDX23C models.
39
+ ~ Updated the "Select Stems" menu for 2 stem models under MDX23C.
40
+ ~ Revamped the MDXC23 Overlap Menu.
41
+ ~ Provided an option to decide on using the VR denoise model in MDX-NET.
42
+ ~ Introduced a selection menu with choices: "None," "Standard," and "Denoise Model."
43
+ ~ Added an option to automatically deverb vocals for Demucs & MDX models. (does not
44
+ work in ensemble mode at this time)
45
+ ~ Updated the Download Center to feature MDX23C models.
46
+ ~ Fixed the glitch where ensemble final outputs weren't saved.
47
+ ~ Enhanced align/matchering outputs to support MP3 and FLAC.
48
+ ~ Refined tooltips for "Overlap" and "Segment Default."
49
+ ~ Ensured the download lists refresh properly even when there's no internet.
50
+ ~ Many more fixes.
51
+
52
+ Resources & Credits:
53
+
54
+ ----------------------------------------------------------------------------------------------
55
+ Name:
56
+ ZFTurbo
57
+
58
+ Contribution:
59
+ ~ Created the weights for the new MDX23C models.
60
+ ~ Trained the new MDX23C models.
61
+
62
+ Resources:
63
+ ~ These models are available to use online. See link details below:
64
+ ~ https://mvsep.com
65
+ ~ Separation type -> MDX23C (vocals, instrumental)
66
+ ~ Vocal model type -> 8K FFT, Full Band (SDR vocals: 10.17, SDR instrum: 16.48)
67
+ ----------------------------------------------------------------------------------------------
68
+ Name:
69
+ Bas Curtiz
70
+
71
+ Contribution:
72
+ ~ Conducted thorough testing and produced detailed tutorials on SDR evaluation.
73
+ ~ Implemented systematic SDR quality assessments.
74
+ ~ Authored Google Document detailing the optimal models for each genre.
75
+ ~ Authored Google Document capturing UVR frequency cutoffs and time-elapsed comparisons between GPU and CPU performances.
76
+ ~ Authored a Google Document consolidating best practices, tips, and tricks.
77
+
78
+ Resources:
79
+ ~ SDR Per Genre Breakdown
80
+ ~ UVR Model Frequency Cutoff Chart
81
+ ~ Tips & Tricks
82
+ ~ Link: (See "Discord Resources Link" below)
83
+ ----------------------------------------------------------------------------------------------
84
+ Name:
85
+ deton24
86
+
87
+ Contribution:
88
+ ~ SDR quality checks.
89
+ ~ Authored a Google Document that serves as an all-in-one guide to source separation.
90
+
91
+ Resources:
92
+ ~ All-in-One Source Separation Guide
93
+ ~ Link: (See "Discord Resources Link" below)
94
+ ----------------------------------------------------------------------------------------------
95
+ Name:
96
+ dca100fb8
97
+
98
+ Contribution:
99
+ ~ Testing
100
+ ~ Bug reporting
101
+ ~ Suggestions
102
+ ----------------------------------------------------------------------------------------------
103
+
104
+ Discord Resources Link: https://discord.com/channels/708579735583588363/1153421553694953482
gui_data/error_handling.py ADDED
@@ -0,0 +1,110 @@
+ from datetime import datetime
+ import traceback
+
+ CUDA_MEMORY_ERROR = "CUDA out of memory"
+ CUDA_RUNTIME_ERROR = "CUDNN error executing cudnnSetTensorNdDescriptor"
+ DEMUCS_MODEL_MISSING_ERROR = "is neither a single pre-trained model or a bag of models."
+ ENSEMBLE_MISSING_MODEL_ERROR = "local variable \'enseExport\' referenced before assignment"
+ FFMPEG_MISSING_ERROR = """audioread\__init__.py", line 116, in audio_open"""
+ FILE_MISSING_ERROR = "FileNotFoundError"
+ MDX_MEMORY_ERROR = "onnxruntime::CudaCall CUDA failure 2: out of memory"
+ MDX_MODEL_MISSING = "[ONNXRuntimeError] : 3 : NO_SUCHFILE"
+ MDX_MODEL_SETTINGS_ERROR = "Got invalid dimensions for input"
+ MDX_RUNTIME_ERROR = "onnxruntime::BFCArena::AllocateRawInternal"
+ MODULE_ERROR = "ModuleNotFoundError"
+ WINDOW_SIZE_ERROR = "h1_shape[3] must be greater than h2_shape[3]"
+ SF_WRITE_ERROR = "sf.write"
+ SYSTEM_MEMORY_ERROR = "DefaultCPUAllocator: not enough memory"
+ MISSING_MODEL_ERROR = "'NoneType\' object has no attribute \'model_basename\'"
+ ARRAY_SIZE_ERROR = "ValueError: \"array is too big; `arr.size * arr.dtype.itemsize` is larger than the maximum possible size.\""
+ GPU_INCOMPATIBLE_ERROR = "no kernel image is available for execution on the device"
+ SELECT_CORRECT_GPU = "CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect."
+
+ CONTACT_DEV = 'If this error persists, please contact the developers with the error details.'
+
+ ERROR_MAPPER = {
+     CUDA_MEMORY_ERROR:
+         ('The application was unable to allocate enough GPU memory to use this model. ' +
+          'Please close any GPU intensive applications and try again.\n' +
+          'If the error persists, your GPU might not be supported.'),
+     CUDA_RUNTIME_ERROR:
+         (f'Your PC cannot process this audio file with the segment size selected. Please lower the segment size and try again.\n\n{CONTACT_DEV}'),
+     DEMUCS_MODEL_MISSING_ERROR:
+         ('The selected Demucs model is missing. ' +
+          'Please download the model or make sure it is in the correct directory.'),
+     ENSEMBLE_MISSING_MODEL_ERROR:
+         ('The application was unable to locate a model you selected for this ensemble.\n\n' +
+          'Please do the following to use all compatible models:\n\n1. Navigate to the \"Updates\" tab in the Help Guide.\n2. Download and install the model expansion pack.\n3. Then try again.\n\n' +
+          'If the error persists, please verify all models are present.'),
+     FFMPEG_MISSING_ERROR:
+         ('The input file type is not supported or FFmpeg is missing. Please select a file type supported by FFmpeg and try again. ' +
+          'If FFmpeg is missing or not installed, you will only be able to process \".wav\" files until it is available on this system. ' +
+          f'See the \"More Info\" tab in the Help Guide.\n\n{CONTACT_DEV}'),
+     FILE_MISSING_ERROR:
+         (f'Missing file error raised. Please address the error and try again.\n\n{CONTACT_DEV}'),
+     MDX_MEMORY_ERROR:
+         ('The application was unable to allocate enough GPU memory to use this model.\n\n' +
+          'Please do the following:\n\n1. Close any GPU intensive applications.\n2. Lower the set segment size.\n3. Then try again.\n\n' +
+          'If the error persists, your GPU might not be supported.'),
+     MDX_MODEL_MISSING:
+         ('The application could not detect this MDX-Net model on your system. ' +
+          'Please make sure all the models are present in the correct directory.\n\n' +
+          'If the error persists, please reinstall the application or contact the developers.'),
+     MDX_RUNTIME_ERROR:
+         ('The application was unable to allocate enough GPU memory to use this model.\n\n' +
+          'Please do the following:\n\n1. Close any GPU intensive applications.\n2. Lower the set segment size.\n3. Then try again.\n\n' +
+          'If the error persists, your GPU might not be supported.'),
+     WINDOW_SIZE_ERROR:
+         ('Invalid window size.\n\n' +
+          'The chosen window size is likely not compatible with this model. Please select a different size and try again.'),
+     SF_WRITE_ERROR:
+         ('Could not write audio file.\n\n' +
+          'This could be due to one of the following:\n\n1. Low storage on target device.\n2. The export directory no longer exists.\n3. A system permissions issue.'),
+     SYSTEM_MEMORY_ERROR:
+         ('The application was unable to allocate enough system memory to use this model.\n\n' +
+          'Please do the following:\n\n1. Restart this application.\n2. Ensure any CPU intensive applications are closed.\n3. Then try again.\n\n' +
+          'Please Note: Intel Pentium and Intel Celeron processors do not work well with this application.\n\n' +
+          'If the error persists, the system may not have enough RAM, or your CPU might not be supported.'),
+     MISSING_MODEL_ERROR:
+         ('Model Missing: The application was unable to locate the chosen model.\n\n' +
+          'If the error persists, please verify any selected models are present.'),
+     GPU_INCOMPATIBLE_ERROR:
+         ('This process is not compatible with your GPU.\n\n' +
+          'Please uncheck \"GPU Conversion\" and try again.'),
+     SELECT_CORRECT_GPU:
+         ('Make sure you\'ve chosen the correct GPU.\n\n'
+          'Go to the "Settings Guide", click the "Additional Settings" tab and select the correct GPU device.'),
+     ARRAY_SIZE_ERROR:
+         ('The application was not able to process the given audio file. Please convert the audio file to another format and try again.'),
+ }
+
+ def error_text(process_method, exception):
+
+     traceback_text = ''.join(traceback.format_tb(exception.__traceback__))
+     message = f'{type(exception).__name__}: "{exception}"\nTraceback Error: "\n{traceback_text}"\n'
+     error_message = f'\n\nRaw Error Details:\n\n{message}\nError Time Stamp [{datetime.now().strftime("%Y-%m-%d %H:%M:%S")}]\n'
+     process = f'Last Error Received:\n\nProcess: {process_method}\n\n'
+
+     for error_type, full_text in ERROR_MAPPER.items():
+         if error_type in message:
+             final_message = full_text
+             break
+     else:
+         final_message = (CONTACT_DEV)
+
+     return f"{process}{final_message}{error_message}"
+
+ def error_dialouge(exception):
+
+     error_name = f'{type(exception).__name__}'
+     traceback_text = ''.join(traceback.format_tb(exception.__traceback__))
+     message = f'{error_name}: "{exception}"\n{traceback_text}"'
+
+     for error_type, full_text in ERROR_MAPPER.items():
+         if error_type in message:
+             final_message = full_text
+             break
+     else:
+         final_message = (f'An Error Occurred: {error_name}\n\n{CONTACT_DEV}')
+
+     return final_message
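
As a point of reference, the sketch below shows how `error_text` and `error_dialouge` might be called from a processing routine. The `run_separation` function and the surrounding try/except are hypothetical and exist only to illustrate how `ERROR_MAPPER` matches known error substrings; only the two imported helpers come from `gui_data/error_handling.py`.

```python
# Hypothetical caller - only error_text and error_dialouge come from the module above.
from gui_data.error_handling import error_text, error_dialouge

def run_separation(audio_path):
    # Stand-in for a real separation step that fails because the input is missing.
    raise FileNotFoundError(audio_path)

try:
    run_separation("song.wav")
except Exception as e:
    # Short user-facing message: "FileNotFoundError" appears in the formatted error,
    # so FILE_MISSING_ERROR matches and its friendly text is returned.
    print(error_dialouge(e))
    # Long form for the error log: process name, matched message, raw traceback, timestamp.
    print(error_text("MDX-Net", e))
```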
gui_data/fail_chime.wav ADDED
Binary file (385 kB).
 
gui_data/fonts/Montserrat/Montserrat.ttf ADDED
Binary file (263 kB).
 
gui_data/fonts/centurygothic/GOTHIC.ttf ADDED
Binary file (138 kB).
 
gui_data/fonts/other/own_font_goes_here.txt ADDED
@@ -0,0 +1 @@
+ 0
gui_data/img/File.png ADDED
gui_data/img/GUI-Icon.ico ADDED