Spaces:
Running
Running
Update audio_separator/separator/separator.py
Browse files- audio_separator/separator/separator.py +290 -736
audio_separator/separator/separator.py
CHANGED
|
@@ -22,68 +22,34 @@ import onnxruntime as ort
|
|
| 22 |
from tqdm import tqdm
|
| 23 |
|
| 24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
class Separator:
|
| 26 |
"""
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
It also handles logging, normalization, and output formatting of the separated audio stems.
|
| 31 |
-
|
| 32 |
-
The actual separation task is handled by one of the architecture-specific classes in the `architectures` module;
|
| 33 |
-
this class is responsible for initialising logging, configuring hardware acceleration, loading the model,
|
| 34 |
-
initiating the separation process and passing outputs back to the caller.
|
| 35 |
-
|
| 36 |
-
Common Attributes:
|
| 37 |
-
log_level (int): The logging level.
|
| 38 |
-
log_formatter (logging.Formatter): The logging formatter.
|
| 39 |
-
model_file_dir (str): The directory where model files are stored.
|
| 40 |
-
output_dir (str): The directory where output files will be saved.
|
| 41 |
-
output_format (str): The format of the output audio file.
|
| 42 |
-
output_bitrate (str): The bitrate of the output audio file.
|
| 43 |
-
amplification_threshold (float): The threshold for audio amplification.
|
| 44 |
-
normalization_threshold (float): The threshold for audio normalization.
|
| 45 |
-
output_single_stem (str): Option to output a single stem.
|
| 46 |
-
invert_using_spec (bool): Flag to invert using spectrogram.
|
| 47 |
-
sample_rate (int): The sample rate of the audio.
|
| 48 |
-
use_soundfile (bool): Use soundfile for audio writing, can solve OOM issues.
|
| 49 |
-
use_autocast (bool): Flag to use PyTorch autocast for faster inference.
|
| 50 |
-
|
| 51 |
-
MDX Architecture Specific Attributes:
|
| 52 |
-
hop_length (int): The hop length for STFT.
|
| 53 |
-
segment_size (int): The segment size for processing.
|
| 54 |
-
overlap (float): The overlap between segments.
|
| 55 |
-
batch_size (int): The batch size for processing.
|
| 56 |
-
enable_denoise (bool): Flag to enable or disable denoising.
|
| 57 |
-
|
| 58 |
-
VR Architecture Specific Attributes & Defaults:
|
| 59 |
-
batch_size: 16
|
| 60 |
-
window_size: 512
|
| 61 |
-
aggression: 5
|
| 62 |
-
enable_tta: False
|
| 63 |
-
enable_post_process: False
|
| 64 |
-
post_process_threshold: 0.2
|
| 65 |
-
high_end_process: False
|
| 66 |
-
|
| 67 |
-
Demucs Architecture Specific Attributes & Defaults:
|
| 68 |
-
segment_size: "Default"
|
| 69 |
-
shifts: 2
|
| 70 |
-
overlap: 0.25
|
| 71 |
-
segments_enabled: True
|
| 72 |
-
|
| 73 |
-
MDXC Architecture Specific Attributes & Defaults:
|
| 74 |
-
segment_size: 256
|
| 75 |
-
override_model_segment_size: False
|
| 76 |
-
batch_size: 1
|
| 77 |
-
overlap: 8
|
| 78 |
-
pitch_shift: 0
|
| 79 |
"""
|
| 80 |
-
|
| 81 |
def __init__(
|
| 82 |
self,
|
| 83 |
log_level=logging.INFO,
|
| 84 |
-
log_formatter=None,
|
| 85 |
model_file_dir="/tmp/audio-separator-models/",
|
| 86 |
-
output_dir=
|
| 87 |
output_format="WAV",
|
| 88 |
output_bitrate=None,
|
| 89 |
normalization_threshold=0.9,
|
|
@@ -91,8 +57,8 @@ class Separator:
|
|
| 91 |
output_single_stem=None,
|
| 92 |
invert_using_spec=False,
|
| 93 |
sample_rate=44100,
|
| 94 |
-
use_soundfile=
|
| 95 |
-
use_autocast=
|
| 96 |
use_directml=False,
|
| 97 |
mdx_params={"hop_length": 1024, "segment_size": 256, "overlap": 0.25, "batch_size": 1, "enable_denoise": False},
|
| 98 |
vr_params={"batch_size": 1, "window_size": 512, "aggression": 5, "enable_tta": False, "enable_post_process": False, "post_process_threshold": 0.2, "high_end_process": False},
|
|
@@ -100,639 +66,271 @@ class Separator:
|
|
| 100 |
mdxc_params={"segment_size": 256, "override_model_segment_size": False, "batch_size": 1, "overlap": 8, "pitch_shift": 0},
|
| 101 |
info_only=False,
|
| 102 |
):
|
| 103 |
-
"""Initialize the separator."""
|
| 104 |
self.logger = logging.getLogger(__name__)
|
| 105 |
self.logger.setLevel(log_level)
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
self.log_handler = logging.StreamHandler()
|
| 110 |
-
|
| 111 |
-
if self.log_formatter is None:
|
| 112 |
-
self.log_formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(module)s - %(message)s")
|
| 113 |
-
|
| 114 |
-
self.log_handler.setFormatter(self.log_formatter)
|
| 115 |
-
|
| 116 |
if not self.logger.hasHandlers():
|
| 117 |
-
self.logger.addHandler(
|
| 118 |
-
|
| 119 |
-
# Filter out noisy warnings from PyTorch for users who don't care about them
|
| 120 |
-
if log_level > logging.DEBUG:
|
| 121 |
-
warnings.filterwarnings("ignore")
|
| 122 |
|
| 123 |
-
#
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
if output_dir is None:
|
| 129 |
-
output_dir = os.getcwd()
|
| 130 |
-
if not info_only:
|
| 131 |
-
self.logger.info("Output directory not specified. Using current working directory.")
|
| 132 |
-
|
| 133 |
-
self.output_dir = output_dir
|
| 134 |
-
|
| 135 |
-
# Check for environment variable to override model_file_dir
|
| 136 |
-
env_model_dir = os.environ.get("AUDIO_SEPARATOR_MODEL_DIR")
|
| 137 |
-
if env_model_dir:
|
| 138 |
-
self.model_file_dir = env_model_dir
|
| 139 |
-
self.logger.info(f"Using model directory from AUDIO_SEPARATOR_MODEL_DIR env var: {self.model_file_dir}")
|
| 140 |
-
if not os.path.exists(self.model_file_dir):
|
| 141 |
-
raise FileNotFoundError(f"The specified model directory does not exist: {self.model_file_dir}")
|
| 142 |
-
else:
|
| 143 |
-
self.logger.info(f"Using model directory from model_file_dir parameter: {model_file_dir}")
|
| 144 |
-
self.model_file_dir = model_file_dir
|
| 145 |
-
|
| 146 |
-
# Create the model directory if it does not exist
|
| 147 |
-
os.makedirs(self.model_file_dir, exist_ok=True)
|
| 148 |
-
os.makedirs(self.output_dir, exist_ok=True)
|
| 149 |
-
|
| 150 |
-
self.output_format = output_format
|
| 151 |
self.output_bitrate = output_bitrate
|
| 152 |
-
|
| 153 |
-
if self.output_format is None:
|
| 154 |
-
self.output_format = "WAV"
|
| 155 |
-
|
| 156 |
self.normalization_threshold = normalization_threshold
|
| 157 |
-
if normalization_threshold <= 0 or normalization_threshold > 1:
|
| 158 |
-
raise ValueError("The normalization_threshold must be greater than 0 and less than or equal to 1.")
|
| 159 |
-
|
| 160 |
self.amplification_threshold = amplification_threshold
|
| 161 |
-
if amplification_threshold < 0 or amplification_threshold > 1:
|
| 162 |
-
raise ValueError("The amplification_threshold must be greater than or equal to 0 and less than or equal to 1.")
|
| 163 |
-
|
| 164 |
self.output_single_stem = output_single_stem
|
| 165 |
-
if output_single_stem is not None:
|
| 166 |
-
self.logger.debug(f"Single stem output requested, so only one output file ({output_single_stem}) will be written")
|
| 167 |
-
|
| 168 |
self.invert_using_spec = invert_using_spec
|
| 169 |
-
|
| 170 |
-
self.logger.debug(f"Secondary step will be inverted using spectogram rather than waveform. This may improve quality but is slightly slower.")
|
| 171 |
-
|
| 172 |
-
try:
|
| 173 |
-
self.sample_rate = int(sample_rate)
|
| 174 |
-
if self.sample_rate <= 0:
|
| 175 |
-
raise ValueError(f"The sample rate setting is {self.sample_rate} but it must be a non-zero whole number.")
|
| 176 |
-
if self.sample_rate > 12800000:
|
| 177 |
-
raise ValueError(f"The sample rate setting is {self.sample_rate}. Enter something less ambitious.")
|
| 178 |
-
except ValueError:
|
| 179 |
-
raise ValueError("The sample rate must be a non-zero whole number. Please provide a valid integer.")
|
| 180 |
-
|
| 181 |
self.use_soundfile = use_soundfile
|
| 182 |
self.use_autocast = use_autocast
|
| 183 |
self.use_directml = use_directml
|
| 184 |
-
|
| 185 |
-
# These are parameters which users may want to configure so we expose them to the top-level Separator class,
|
| 186 |
-
# even though they are specific to a single model architecture
|
| 187 |
self.arch_specific_params = {"MDX": mdx_params, "VR": vr_params, "Demucs": demucs_params, "MDXC": mdxc_params}
|
| 188 |
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 192 |
|
| 193 |
-
|
| 194 |
-
self.
|
|
|
|
| 195 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 196 |
self.model_is_uvr_vip = False
|
| 197 |
self.model_friendly_name = None
|
| 198 |
|
| 199 |
if not info_only:
|
| 200 |
-
self.
|
| 201 |
|
| 202 |
def setup_accelerated_inferencing_device(self):
|
| 203 |
-
"""
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
self.setup_torch_device(system_info)
|
| 210 |
-
|
| 211 |
-
def get_system_info(self):
|
| 212 |
-
"""
|
| 213 |
-
This method logs the system information, including the operating system, CPU archutecture and Python version
|
| 214 |
-
"""
|
| 215 |
-
os_name = platform.system()
|
| 216 |
-
os_version = platform.version()
|
| 217 |
-
self.logger.info(f"Operating System: {os_name} {os_version}")
|
| 218 |
-
|
| 219 |
-
system_info = platform.uname()
|
| 220 |
-
self.logger.info(f"System: {system_info.system} Node: {system_info.node} Release: {system_info.release} Machine: {system_info.machine} Proc: {system_info.processor}")
|
| 221 |
-
|
| 222 |
-
python_version = platform.python_version()
|
| 223 |
-
self.logger.info(f"Python Version: {python_version}")
|
| 224 |
-
|
| 225 |
-
pytorch_version = torch.__version__
|
| 226 |
-
self.logger.info(f"PyTorch Version: {pytorch_version}")
|
| 227 |
-
return system_info
|
| 228 |
-
|
| 229 |
-
def check_ffmpeg_installed(self):
|
| 230 |
-
"""
|
| 231 |
-
This method checks if ffmpeg is installed and logs its version.
|
| 232 |
-
"""
|
| 233 |
-
try:
|
| 234 |
-
ffmpeg_version_output = subprocess.check_output(["ffmpeg", "-version"], text=True)
|
| 235 |
-
first_line = ffmpeg_version_output.splitlines()[0]
|
| 236 |
-
self.logger.info(f"FFmpeg installed: {first_line}")
|
| 237 |
-
except FileNotFoundError:
|
| 238 |
-
self.logger.error("FFmpeg is not installed. Please install FFmpeg to use this package.")
|
| 239 |
-
# Raise an exception if this is being run by a user, as ffmpeg is required for pydub to write audio
|
| 240 |
-
# but if we're just running unit tests in CI, no reason to throw
|
| 241 |
-
if "PYTEST_CURRENT_TEST" not in os.environ:
|
| 242 |
-
raise
|
| 243 |
-
|
| 244 |
-
def log_onnxruntime_packages(self):
|
| 245 |
-
"""
|
| 246 |
-
This method logs the ONNX Runtime package versions, including the GPU and Silicon packages if available.
|
| 247 |
-
"""
|
| 248 |
-
onnxruntime_gpu_package = self.get_package_distribution("onnxruntime-gpu")
|
| 249 |
-
onnxruntime_silicon_package = self.get_package_distribution("onnxruntime-silicon")
|
| 250 |
-
onnxruntime_cpu_package = self.get_package_distribution("onnxruntime")
|
| 251 |
-
onnxruntime_dml_package = self.get_package_distribution("onnxruntime-directml")
|
| 252 |
-
|
| 253 |
-
if onnxruntime_gpu_package is not None:
|
| 254 |
-
self.logger.info(f"ONNX Runtime GPU package installed with version: {onnxruntime_gpu_package.version}")
|
| 255 |
-
if onnxruntime_silicon_package is not None:
|
| 256 |
-
self.logger.info(f"ONNX Runtime Silicon package installed with version: {onnxruntime_silicon_package.version}")
|
| 257 |
-
if onnxruntime_cpu_package is not None:
|
| 258 |
-
self.logger.info(f"ONNX Runtime CPU package installed with version: {onnxruntime_cpu_package.version}")
|
| 259 |
-
if onnxruntime_dml_package is not None:
|
| 260 |
-
self.logger.info(f"ONNX Runtime DirectML package installed with version: {onnxruntime_dml_package.version}")
|
| 261 |
-
|
| 262 |
-
def setup_torch_device(self, system_info):
|
| 263 |
-
"""
|
| 264 |
-
This method sets up the PyTorch and/or ONNX Runtime inferencing device, using GPU hardware acceleration if available.
|
| 265 |
-
"""
|
| 266 |
-
hardware_acceleration_enabled = False
|
| 267 |
-
ort_providers = ort.get_available_providers()
|
| 268 |
-
has_torch_dml_installed = self.get_package_distribution("torch_directml")
|
| 269 |
-
|
| 270 |
-
self.torch_device_cpu = torch.device("cpu")
|
| 271 |
-
|
| 272 |
-
if torch.cuda.is_available():
|
| 273 |
-
self.configure_cuda(ort_providers)
|
| 274 |
-
hardware_acceleration_enabled = True
|
| 275 |
-
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available() and system_info.processor == "arm":
|
| 276 |
-
self.configure_mps(ort_providers)
|
| 277 |
-
hardware_acceleration_enabled = True
|
| 278 |
-
elif self.use_directml and has_torch_dml_installed:
|
| 279 |
-
import torch_directml
|
| 280 |
-
if torch_directml.is_available():
|
| 281 |
-
self.configure_dml(ort_providers)
|
| 282 |
-
hardware_acceleration_enabled = True
|
| 283 |
-
|
| 284 |
-
if not hardware_acceleration_enabled:
|
| 285 |
-
self.logger.info("No hardware acceleration could be configured, running in CPU mode")
|
| 286 |
-
self.torch_device = self.torch_device_cpu
|
| 287 |
-
self.onnx_execution_provider = ["CPUExecutionProvider"]
|
| 288 |
-
|
| 289 |
-
def configure_cuda(self, ort_providers):
|
| 290 |
-
"""
|
| 291 |
-
This method configures the CUDA device for PyTorch and ONNX Runtime, if available.
|
| 292 |
-
"""
|
| 293 |
-
self.logger.info("CUDA is available in Torch, setting Torch device to CUDA")
|
| 294 |
-
self.torch_device = torch.device("cuda")
|
| 295 |
-
if "CUDAExecutionProvider" in ort_providers:
|
| 296 |
-
self.logger.info("ONNXruntime has CUDAExecutionProvider available, enabling acceleration")
|
| 297 |
-
self.onnx_execution_provider = ["CUDAExecutionProvider"]
|
| 298 |
-
else:
|
| 299 |
-
self.logger.warning("CUDAExecutionProvider not available in ONNXruntime, so acceleration will NOT be enabled")
|
| 300 |
-
|
| 301 |
-
def configure_mps(self, ort_providers):
|
| 302 |
-
"""
|
| 303 |
-
This method configures the Apple Silicon MPS/CoreML device for PyTorch and ONNX Runtime, if available.
|
| 304 |
-
"""
|
| 305 |
-
self.logger.info("Apple Silicon MPS/CoreML is available in Torch and processor is ARM, setting Torch device to MPS")
|
| 306 |
-
self.torch_device_mps = torch.device("mps")
|
| 307 |
-
|
| 308 |
-
self.torch_device = self.torch_device_mps
|
| 309 |
-
|
| 310 |
-
if "CoreMLExecutionProvider" in ort_providers:
|
| 311 |
-
self.logger.info("ONNXruntime has CoreMLExecutionProvider available, enabling acceleration")
|
| 312 |
-
self.onnx_execution_provider = ["CoreMLExecutionProvider"]
|
| 313 |
else:
|
| 314 |
-
self.logger.
|
| 315 |
-
|
| 316 |
-
def configure_dml(self, ort_providers):
|
| 317 |
-
"""
|
| 318 |
-
This method configures the DirectML device for PyTorch and ONNX Runtime, if available.
|
| 319 |
-
"""
|
| 320 |
-
import torch_directml
|
| 321 |
-
self.logger.info("DirectML is available in Torch, setting Torch device to DirectML")
|
| 322 |
-
self.torch_device_dml = torch_directml.device()
|
| 323 |
-
self.torch_device = self.torch_device_dml
|
| 324 |
-
|
| 325 |
-
if "DmlExecutionProvider" in ort_providers:
|
| 326 |
-
self.logger.info("ONNXruntime has DmlExecutionProvider available, enabling acceleration")
|
| 327 |
-
self.onnx_execution_provider = ["DmlExecutionProvider"]
|
| 328 |
-
else:
|
| 329 |
-
self.logger.warning("DmlExecutionProvider not available in ONNXruntime, so acceleration will NOT be enabled")
|
| 330 |
-
|
| 331 |
-
def get_package_distribution(self, package_name):
|
| 332 |
-
"""
|
| 333 |
-
This method returns the package distribution for a given package name if installed, or None otherwise.
|
| 334 |
-
"""
|
| 335 |
-
try:
|
| 336 |
-
return metadata.distribution(package_name)
|
| 337 |
-
except metadata.PackageNotFoundError:
|
| 338 |
-
self.logger.debug(f"Python package: {package_name} not installed")
|
| 339 |
-
return None
|
| 340 |
|
| 341 |
def get_model_hash(self, model_path):
|
| 342 |
-
"""
|
| 343 |
-
|
| 344 |
-
"""
|
| 345 |
-
self.logger.debug(f"Calculating hash of model file {model_path}")
|
| 346 |
-
# Use the specific byte count from the original logic
|
| 347 |
-
BYTES_TO_HASH = 10000 * 1024 # 10,240,000 bytes
|
| 348 |
-
|
| 349 |
try:
|
| 350 |
-
file_size = os.path.getsize(model_path)
|
| 351 |
-
|
| 352 |
with open(model_path, "rb") as f:
|
|
|
|
| 353 |
if file_size < BYTES_TO_HASH:
|
| 354 |
-
|
| 355 |
-
|
| 356 |
-
|
| 357 |
-
else:
|
| 358 |
-
# Seek to the specific position before the end (from the beginning) and hash
|
| 359 |
-
seek_pos = file_size - BYTES_TO_HASH
|
| 360 |
-
self.logger.debug(f"File size {file_size} >= {BYTES_TO_HASH}, seeking to {seek_pos} and hashing remaining bytes.")
|
| 361 |
-
f.seek(seek_pos, io.SEEK_SET)
|
| 362 |
-
hash_value = hashlib.md5(f.read()).hexdigest()
|
| 363 |
-
|
| 364 |
-
# Log the calculated hash
|
| 365 |
-
self.logger.info(f"Hash of model file {model_path} is {hash_value}")
|
| 366 |
-
return hash_value
|
| 367 |
-
|
| 368 |
-
except FileNotFoundError:
|
| 369 |
-
self.logger.error(f"Model file not found at {model_path}")
|
| 370 |
-
raise # Re-raise the specific error
|
| 371 |
except Exception as e:
|
| 372 |
-
# Catch other potential errors (e.g., permissions, other IOErrors)
|
| 373 |
self.logger.error(f"Error calculating hash for {model_path}: {e}")
|
| 374 |
-
raise
|
| 375 |
|
| 376 |
def download_file_if_not_exists(self, url, output_path):
|
| 377 |
-
"""
|
| 378 |
-
|
| 379 |
-
|
| 380 |
-
|
| 381 |
-
if os.path.isfile(output_path):
|
| 382 |
-
self.logger.debug(f"File already exists at {output_path}, skipping download")
|
| 383 |
return
|
| 384 |
-
|
| 385 |
-
|
| 386 |
-
response = requests.get(url, stream=True, timeout=300)
|
| 387 |
-
|
| 388 |
if response.status_code == 200:
|
| 389 |
-
|
| 390 |
-
progress_bar = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
|
| 391 |
-
|
| 392 |
-
with open(output_path, "wb") as f:
|
| 393 |
for chunk in response.iter_content(chunk_size=8192):
|
| 394 |
-
progress_bar.update(len(chunk))
|
| 395 |
f.write(chunk)
|
| 396 |
-
|
| 397 |
else:
|
| 398 |
-
raise RuntimeError(f"Failed to download
|
| 399 |
|
| 400 |
def list_supported_model_files(self):
|
| 401 |
-
"""
|
| 402 |
-
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
|
| 406 |
|
| 407 |
-
|
| 408 |
-
|
| 409 |
-
|
| 410 |
-
"
|
| 411 |
-
|
| 412 |
-
"
|
| 413 |
-
"SDR": 10.6497,
|
| 414 |
-
"SIR": 20.3786,
|
| 415 |
-
"SAR": 10.692,
|
| 416 |
-
"ISR": 14.848
|
| 417 |
-
},
|
| 418 |
-
"instrumental": {
|
| 419 |
-
"SDR": 15.2149,
|
| 420 |
-
"SIR": 25.6075,
|
| 421 |
-
"SAR": 17.1363,
|
| 422 |
-
"ISR": 17.7893
|
| 423 |
-
}
|
| 424 |
},
|
| 425 |
-
"
|
| 426 |
-
|
| 427 |
-
]
|
| 428 |
-
}
|
| 429 |
},
|
| 430 |
-
"
|
| 431 |
-
"
|
| 432 |
-
|
| 433 |
-
|
| 434 |
-
"
|
| 435 |
-
"SDR": 11.2685,
|
| 436 |
-
"SIR": 21.257,
|
| 437 |
-
"SAR": 11.0359,
|
| 438 |
-
"ISR": 19.3753
|
| 439 |
-
},
|
| 440 |
-
"drums": {
|
| 441 |
-
"SDR": 13.235,
|
| 442 |
-
"SIR": 23.3053,
|
| 443 |
-
"SAR": 13.0313,
|
| 444 |
-
"ISR": 17.2889
|
| 445 |
-
},
|
| 446 |
-
"bass": {
|
| 447 |
-
"SDR": 9.72743,
|
| 448 |
-
"SIR": 19.5435,
|
| 449 |
-
"SAR": 9.20801,
|
| 450 |
-
"ISR": 13.5037
|
| 451 |
-
}
|
| 452 |
},
|
| 453 |
-
"
|
| 454 |
-
|
| 455 |
-
"https://dl.fbaipublicfiles.com/demucs/hybrid_transformer/d12395a8-e57c48e6.th",
|
| 456 |
-
"https://dl.fbaipublicfiles.com/demucs/hybrid_transformer/92cfc3b6-ef3bcb9c.th",
|
| 457 |
-
"https://dl.fbaipublicfiles.com/demucs/hybrid_transformer/04573f0d-f3cf25b2.th",
|
| 458 |
-
"https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/htdemucs_ft.yaml"
|
| 459 |
-
]
|
| 460 |
-
}
|
| 461 |
},
|
| 462 |
-
"
|
| 463 |
-
"
|
| 464 |
-
|
| 465 |
-
|
| 466 |
-
"vocals": {
|
| 467 |
-
"SDR": 11.9504,
|
| 468 |
-
"SIR": 23.1166,
|
| 469 |
-
"SAR": 12.093,
|
| 470 |
-
"ISR": 15.4782
|
| 471 |
-
},
|
| 472 |
-
"instrumental": {
|
| 473 |
-
"SDR": 16.3035,
|
| 474 |
-
"SIR": 26.6161,
|
| 475 |
-
"SAR": 18.5167,
|
| 476 |
-
"ISR": 18.3939
|
| 477 |
-
}
|
| 478 |
},
|
| 479 |
-
"
|
| 480 |
-
|
| 481 |
-
"model_2_stem_full_band_8k.yaml"
|
| 482 |
-
]
|
| 483 |
-
}
|
| 484 |
}
|
| 485 |
}
|
| 486 |
-
"""
|
| 487 |
-
download_checks_path = os.path.join(self.model_file_dir, "download_checks.json")
|
| 488 |
-
|
| 489 |
-
self.download_file_if_not_exists("https://raw.githubusercontent.com/TRvlvr/application_data/main/filelists/download_checks.json", download_checks_path)
|
| 490 |
-
|
| 491 |
-
model_downloads_list = json.load(open(download_checks_path, encoding="utf-8"))
|
| 492 |
-
self.logger.debug(f"UVR model download list loaded")
|
| 493 |
|
| 494 |
-
|
| 495 |
-
|
| 496 |
-
try:
|
| 497 |
-
with resources.open_text("audio_separator", "models-scores.json") as f:
|
| 498 |
-
model_scores = json.load(f)
|
| 499 |
-
self.logger.debug(f"Model scores loaded")
|
| 500 |
-
except json.JSONDecodeError as e:
|
| 501 |
-
self.logger.warning(f"Failed to load model scores: {str(e)}")
|
| 502 |
-
self.logger.warning("Continuing without model scores")
|
| 503 |
-
|
| 504 |
-
# Only show Demucs v4 models as we've only implemented support for v4
|
| 505 |
-
filtered_demucs_v4 = {key: value for key, value in model_downloads_list["demucs_download_list"].items() if key.startswith("Demucs v4")}
|
| 506 |
-
|
| 507 |
-
# Modified Demucs handling to use YAML files as identifiers and include download files
|
| 508 |
-
demucs_models = {}
|
| 509 |
-
for name, files in filtered_demucs_v4.items():
|
| 510 |
-
# Find the YAML file in the model files
|
| 511 |
-
yaml_file = next((filename for filename in files.keys() if filename.endswith(".yaml")), None)
|
| 512 |
-
if yaml_file:
|
| 513 |
-
model_score_data = model_scores.get(yaml_file, {})
|
| 514 |
-
demucs_models[name] = {
|
| 515 |
-
"filename": yaml_file,
|
| 516 |
-
"scores": model_score_data.get("median_scores", {}),
|
| 517 |
-
"stems": model_score_data.get("stems", []),
|
| 518 |
-
"target_stem": model_score_data.get("target_stem"),
|
| 519 |
-
"download_files": list(files.values()), # List of all download URLs/filenames
|
| 520 |
-
}
|
| 521 |
-
|
| 522 |
-
# Load the JSON file using importlib.resources
|
| 523 |
-
with resources.open_text("audio_separator", "models.json") as f:
|
| 524 |
-
audio_separator_models_list = json.load(f)
|
| 525 |
-
self.logger.debug(f"Audio-Separator model list loaded")
|
| 526 |
|
| 527 |
-
#
|
| 528 |
model_files_grouped_by_type = {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 529 |
"VR": {
|
| 530 |
-
|
| 531 |
-
"filename":
|
| 532 |
-
"scores":
|
| 533 |
-
"stems":
|
| 534 |
-
"target_stem":
|
| 535 |
-
"download_files": [
|
| 536 |
-
}
|
| 537 |
-
for name, filename in {**model_downloads_list["vr_download_list"], **audio_separator_models_list["vr_download_list"]}.items()
|
| 538 |
},
|
| 539 |
-
"
|
| 540 |
-
|
| 541 |
-
"filename":
|
| 542 |
-
"scores": model_scores.get(
|
| 543 |
-
"stems": model_scores.get(
|
| 544 |
-
"target_stem": model_scores.get(
|
| 545 |
-
"download_files": [
|
| 546 |
-
|
| 547 |
-
|
|
|
|
|
|
|
| 548 |
},
|
| 549 |
-
"Demucs": demucs_models,
|
| 550 |
"MDXC": {
|
| 551 |
-
|
| 552 |
-
"filename":
|
| 553 |
-
"scores": model_scores.get(
|
| 554 |
-
"stems": model_scores.get(
|
| 555 |
-
"target_stem": model_scores.get(
|
| 556 |
-
"download_files":
|
|
|
|
|
|
|
|
|
|
| 557 |
}
|
| 558 |
-
|
| 559 |
-
**model_downloads_list["mdx23c_download_list"],
|
| 560 |
-
**model_downloads_list["mdx23c_download_vip_list"],
|
| 561 |
-
**model_downloads_list["roformer_download_list"],
|
| 562 |
-
**audio_separator_models_list["mdx23c_download_list"],
|
| 563 |
-
**audio_separator_models_list["roformer_download_list"],
|
| 564 |
-
}.items()
|
| 565 |
-
},
|
| 566 |
}
|
| 567 |
-
|
| 568 |
return model_files_grouped_by_type
|
| 569 |
|
| 570 |
def print_uvr_vip_message(self):
|
| 571 |
-
"""
|
| 572 |
-
This method prints a message to the user if they have downloaded a VIP model, reminding them to support Anjok07 on Patreon.
|
| 573 |
-
"""
|
| 574 |
if self.model_is_uvr_vip:
|
| 575 |
-
self.logger.warning(f"
|
| 576 |
-
self.logger.warning("If you are not already subscribed, please consider supporting the developer of UVR, Anjok07 by subscribing here: https://patreon.com/uvr")
|
| 577 |
|
| 578 |
def download_model_files(self, model_filename):
|
| 579 |
-
"""
|
| 580 |
-
|
| 581 |
-
|
| 582 |
-
"""
|
| 583 |
-
model_path = os.path.join(self.model_file_dir, f"{model_filename}")
|
| 584 |
-
|
| 585 |
-
supported_model_files_grouped = self.list_supported_model_files()
|
| 586 |
public_model_repo_url_prefix = "https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models"
|
| 587 |
vip_model_repo_url_prefix = "https://github.com/Anjok0109/ai_magic/releases/download/v5"
|
| 588 |
audio_separator_models_repo_url_prefix = "https://github.com/nomadkaraoke/python-audio-separator/releases/download/model-configs"
|
| 589 |
|
| 590 |
yaml_config_filename = None
|
| 591 |
-
|
| 592 |
-
self.logger.debug(f"Searching for model_filename {model_filename} in supported_model_files_grouped")
|
| 593 |
-
|
| 594 |
-
# Iterate through model types (MDX, Demucs, MDXC)
|
| 595 |
-
for model_type, models in supported_model_files_grouped.items():
|
| 596 |
-
# Iterate through each model in this type
|
| 597 |
for model_friendly_name, model_info in models.items():
|
| 598 |
self.model_is_uvr_vip = "VIP" in model_friendly_name
|
| 599 |
model_repo_url_prefix = vip_model_repo_url_prefix if self.model_is_uvr_vip else public_model_repo_url_prefix
|
| 600 |
-
|
| 601 |
-
# Check if this model matches our target filename
|
| 602 |
if model_info["filename"] == model_filename or model_filename in model_info["download_files"]:
|
| 603 |
-
self.logger.debug(f"Found matching model: {model_friendly_name}")
|
| 604 |
self.model_friendly_name = model_friendly_name
|
| 605 |
self.print_uvr_vip_message()
|
| 606 |
-
|
| 607 |
-
# Download each required file for this model
|
| 608 |
for file_to_download in model_info["download_files"]:
|
| 609 |
-
# For URLs, extract just the filename portion
|
| 610 |
if file_to_download.startswith("http"):
|
| 611 |
filename = file_to_download.split("/")[-1]
|
| 612 |
download_path = os.path.join(self.model_file_dir, filename)
|
| 613 |
self.download_file_if_not_exists(file_to_download, download_path)
|
|
|
|
|
|
|
| 614 |
continue
|
| 615 |
-
|
| 616 |
download_path = os.path.join(self.model_file_dir, file_to_download)
|
| 617 |
-
|
| 618 |
-
# For MDXC models, handle YAML config files specially
|
| 619 |
-
if model_type == "MDXC" and file_to_download.endswith(".yaml"):
|
| 620 |
-
yaml_config_filename = file_to_download
|
| 621 |
-
try:
|
| 622 |
-
yaml_url = f"{model_repo_url_prefix}/mdx_model_data/mdx_c_configs/{file_to_download}"
|
| 623 |
-
self.download_file_if_not_exists(yaml_url, download_path)
|
| 624 |
-
except RuntimeError:
|
| 625 |
-
self.logger.debug("YAML config not found in UVR repo, trying audio-separator models repo...")
|
| 626 |
-
yaml_url = f"{audio_separator_models_repo_url_prefix}/{file_to_download}"
|
| 627 |
-
self.download_file_if_not_exists(yaml_url, download_path)
|
| 628 |
-
continue
|
| 629 |
-
|
| 630 |
-
# For regular model files, try UVR repo first, then audio-separator repo
|
| 631 |
try:
|
| 632 |
-
|
| 633 |
-
self.download_file_if_not_exists(download_url, download_path)
|
| 634 |
except RuntimeError:
|
| 635 |
-
self.
|
| 636 |
-
|
| 637 |
-
|
| 638 |
-
|
| 639 |
return model_filename, model_type, model_friendly_name, model_path, yaml_config_filename
|
| 640 |
-
|
| 641 |
-
raise ValueError(f"Model file {model_filename} not found in supported model files")
|
| 642 |
|
| 643 |
def load_model_data_from_yaml(self, yaml_config_filename):
|
| 644 |
-
"""
|
| 645 |
-
|
| 646 |
-
|
| 647 |
-
|
| 648 |
-
|
| 649 |
-
|
| 650 |
-
|
| 651 |
-
|
| 652 |
-
|
| 653 |
-
|
| 654 |
-
|
| 655 |
-
|
| 656 |
-
model_data = yaml.load(open(model_data_yaml_filepath, encoding="utf-8"), Loader=yaml.FullLoader)
|
| 657 |
-
self.logger.debug(f"Model data loaded from YAML file: {model_data}")
|
| 658 |
-
|
| 659 |
-
if "roformer" in model_data_yaml_filepath:
|
| 660 |
-
model_data["is_roformer"] = True
|
| 661 |
-
|
| 662 |
-
return model_data
|
| 663 |
|
| 664 |
def load_model_data_using_hash(self, model_path):
|
| 665 |
-
"""
|
| 666 |
-
|
| 667 |
-
|
| 668 |
-
|
| 669 |
-
|
| 670 |
-
|
| 671 |
-
model_data_url_prefix = "https://raw.githubusercontent.com/TRvlvr/application_data/main"
|
| 672 |
-
|
| 673 |
-
vr_model_data_url = f"{model_data_url_prefix}/vr_model_data/model_data_new.json"
|
| 674 |
-
mdx_model_data_url = f"{model_data_url_prefix}/mdx_model_data/model_data_new.json"
|
| 675 |
-
|
| 676 |
-
# Calculate hash for the downloaded model
|
| 677 |
-
self.logger.debug("Calculating MD5 hash for model file to identify model parameters from UVR data...")
|
| 678 |
model_hash = self.get_model_hash(model_path)
|
| 679 |
-
|
| 680 |
-
|
| 681 |
-
|
| 682 |
-
|
| 683 |
-
|
| 684 |
-
|
| 685 |
-
|
| 686 |
-
|
| 687 |
-
|
| 688 |
-
|
| 689 |
-
|
| 690 |
-
|
| 691 |
-
self.logger.
|
| 692 |
-
|
| 693 |
-
mdx_model_data_object = json.load(open(mdx_model_data_path, encoding="utf-8"))
|
| 694 |
-
|
| 695 |
-
# Load additional model data from audio-separator
|
| 696 |
-
self.logger.debug("Loading additional model parameters from audio-separator model data file...")
|
| 697 |
-
with resources.open_text("audio_separator", "model-data.json") as f:
|
| 698 |
-
audio_separator_model_data = json.load(f)
|
| 699 |
-
|
| 700 |
-
# Merge the model data objects, with audio-separator data taking precedence
|
| 701 |
-
vr_model_data_object = {**vr_model_data_object, **audio_separator_model_data.get("vr_model_data", {})}
|
| 702 |
-
mdx_model_data_object = {**mdx_model_data_object, **audio_separator_model_data.get("mdx_model_data", {})}
|
| 703 |
-
|
| 704 |
-
if model_hash in mdx_model_data_object:
|
| 705 |
-
model_data = mdx_model_data_object[model_hash]
|
| 706 |
-
elif model_hash in vr_model_data_object:
|
| 707 |
-
model_data = vr_model_data_object[model_hash]
|
| 708 |
-
else:
|
| 709 |
-
raise ValueError(f"Unsupported Model File: parameters for MD5 hash {model_hash} could not be found in UVR model data file for MDX or VR arch.")
|
| 710 |
-
|
| 711 |
-
self.logger.debug(f"Model data loaded using hash {model_hash}: {model_data}")
|
| 712 |
|
| 713 |
-
return model_data
|
| 714 |
-
|
| 715 |
-
def load_model(self, model_filename="model_mel_band_roformer_ep_3005_sdr_11.4360.ckpt"):
|
| 716 |
-
"""
|
| 717 |
-
This method instantiates the architecture-specific separation class,
|
| 718 |
-
loading the separation model into memory, downloading it first if necessary.
|
| 719 |
-
"""
|
| 720 |
-
self.logger.info(f"Loading model {model_filename}...")
|
| 721 |
-
|
| 722 |
-
load_model_start_time = time.perf_counter()
|
| 723 |
-
|
| 724 |
-
# Setting up the model path
|
| 725 |
model_filename, model_type, model_friendly_name, model_path, yaml_config_filename = self.download_model_files(model_filename)
|
| 726 |
model_name = model_filename.split(".")[0]
|
| 727 |
-
self.logger.debug(f"Model downloaded, friendly name: {model_friendly_name}, model_path: {model_path}")
|
| 728 |
|
| 729 |
-
|
| 730 |
-
yaml_config_filename = model_path
|
| 731 |
-
|
| 732 |
-
if yaml_config_filename is not None:
|
| 733 |
-
model_data = self.load_model_data_from_yaml(yaml_config_filename)
|
| 734 |
-
else:
|
| 735 |
-
model_data = self.load_model_data_using_hash(model_path)
|
| 736 |
|
| 737 |
common_params = {
|
| 738 |
"logger": self.logger,
|
|
@@ -755,205 +353,161 @@ class Separator:
|
|
| 755 |
"use_soundfile": self.use_soundfile,
|
| 756 |
}
|
| 757 |
|
| 758 |
-
|
| 759 |
-
|
| 760 |
-
|
| 761 |
-
|
| 762 |
-
|
|
|
|
| 763 |
|
|
|
|
|
|
|
| 764 |
if model_type == "Demucs" and sys.version_info < (3, 10):
|
| 765 |
-
raise Exception("Demucs
|
| 766 |
-
|
| 767 |
-
self.logger.debug(f"Importing module for model type {model_type}: {separator_classes[model_type]}")
|
| 768 |
|
| 769 |
module_name, class_name = separator_classes[model_type].split(".")
|
| 770 |
-
|
| 771 |
-
|
| 772 |
-
|
| 773 |
-
|
| 774 |
-
|
| 775 |
-
|
| 776 |
-
|
| 777 |
-
|
| 778 |
-
self.logger.info(f
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 779 |
|
| 780 |
def separate(self, audio_file_path, custom_output_names=None):
|
| 781 |
-
"""
|
| 782 |
-
|
| 783 |
-
|
| 784 |
-
This method takes the path to an audio file or a directory containing audio files, processes them through
|
| 785 |
-
the loaded separation model, and returns the paths to the output files containing the separated audio stems.
|
| 786 |
-
It handles the entire flow from loading the audio, running the separation, clearing up resources, and logging the process.
|
| 787 |
|
| 788 |
-
|
| 789 |
-
|
| 790 |
-
- custom_output_names (dict, optional): Custom names for the output files. Defaults to None.
|
| 791 |
|
| 792 |
-
Returns:
|
| 793 |
-
- output_files (list of str): A list containing the paths to the separated audio stem files.
|
| 794 |
-
"""
|
| 795 |
-
# Check if the model and device are properly initialized
|
| 796 |
-
if not (self.torch_device and self.model_instance):
|
| 797 |
-
raise ValueError("Initialization failed or model not loaded. Please load a model before attempting to separate.")
|
| 798 |
-
|
| 799 |
-
# If audio_file_path is a string, convert it to a list for uniform processing
|
| 800 |
if isinstance(audio_file_path, str):
|
| 801 |
audio_file_path = [audio_file_path]
|
| 802 |
|
| 803 |
-
# Initialize a list to store paths of all output files
|
| 804 |
output_files = []
|
| 805 |
-
|
| 806 |
-
# Process each path in the list
|
| 807 |
for path in audio_file_path:
|
| 808 |
if os.path.isdir(path):
|
| 809 |
-
|
| 810 |
-
for root, dirs, files in os.walk(path):
|
| 811 |
for file in files:
|
| 812 |
-
|
| 813 |
-
if file.endswith((".wav", ".flac", ".mp3", ".ogg", ".opus", ".m4a", ".aiff", ".ac3")): # Add other formats if needed
|
| 814 |
full_path = os.path.join(root, file)
|
| 815 |
-
self.
|
| 816 |
-
try:
|
| 817 |
-
# Perform separation for each file
|
| 818 |
-
files_output = self._separate_file(full_path, custom_output_names)
|
| 819 |
-
output_files.extend(files_output)
|
| 820 |
-
except Exception as e:
|
| 821 |
-
self.logger.error(f"Failed to process file {full_path}: {e}")
|
| 822 |
else:
|
| 823 |
-
|
| 824 |
-
self.logger.info(f"Processing file: {path}")
|
| 825 |
-
try:
|
| 826 |
-
files_output = self._separate_file(path, custom_output_names)
|
| 827 |
-
output_files.extend(files_output)
|
| 828 |
-
except Exception as e:
|
| 829 |
-
self.logger.error(f"Failed to process file {path}: {e}")
|
| 830 |
|
|
|
|
|
|
|
| 831 |
return output_files
|
| 832 |
|
| 833 |
def _separate_file(self, audio_file_path, custom_output_names=None):
|
| 834 |
-
"""
|
| 835 |
-
|
| 836 |
-
|
| 837 |
-
|
| 838 |
-
|
| 839 |
-
|
| 840 |
-
|
| 841 |
-
|
| 842 |
-
|
| 843 |
-
"""
|
| 844 |
-
# Log the start of the separation process
|
| 845 |
-
self.logger.info(f"Starting separation process for audio_file_path: {audio_file_path}")
|
| 846 |
-
separate_start_time = time.perf_counter()
|
| 847 |
-
|
| 848 |
-
# Log normalization and amplification thresholds
|
| 849 |
-
self.logger.debug(f"Normalization threshold set to {self.normalization_threshold}, waveform will be lowered to this max amplitude to avoid clipping.")
|
| 850 |
-
self.logger.debug(f"Amplification threshold set to {self.amplification_threshold}, waveform will be scaled up to this max amplitude if below it.")
|
| 851 |
-
|
| 852 |
-
# Run separation method for the loaded model with autocast enabled if supported by the device
|
| 853 |
-
output_files = None
|
| 854 |
-
if self.use_autocast and autocast_mode.is_autocast_available(self.torch_device.type):
|
| 855 |
-
self.logger.debug("Autocast available.")
|
| 856 |
-
with autocast_mode.autocast(self.torch_device.type):
|
| 857 |
-
output_files = self.model_instance.separate(audio_file_path, custom_output_names)
|
| 858 |
-
else:
|
| 859 |
-
self.logger.debug("Autocast unavailable.")
|
| 860 |
-
output_files = self.model_instance.separate(audio_file_path, custom_output_names)
|
| 861 |
-
|
| 862 |
-
# Clear GPU cache to free up memory
|
| 863 |
-
self.model_instance.clear_gpu_cache()
|
| 864 |
-
|
| 865 |
-
# Unset separation parameters to prevent accidentally re-using the wrong source files or output paths
|
| 866 |
-
self.model_instance.clear_file_specific_paths()
|
| 867 |
|
| 868 |
-
|
| 869 |
-
self.print_uvr_vip_message()
|
| 870 |
|
| 871 |
-
|
| 872 |
-
|
| 873 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 874 |
|
|
|
|
|
|
|
| 875 |
return output_files
|
| 876 |
|
| 877 |
def download_model_and_data(self, model_filename):
|
| 878 |
-
"""
|
| 879 |
-
|
| 880 |
-
"""
|
| 881 |
-
self.logger.info(f"Downloading model {model_filename}...")
|
| 882 |
-
|
| 883 |
model_filename, model_type, model_friendly_name, model_path, yaml_config_filename = self.download_model_files(model_filename)
|
| 884 |
-
|
| 885 |
-
|
| 886 |
-
yaml_config_filename = model_path
|
| 887 |
-
|
| 888 |
-
if yaml_config_filename is not None:
|
| 889 |
-
model_data = self.load_model_data_from_yaml(yaml_config_filename)
|
| 890 |
-
else:
|
| 891 |
-
model_data = self.load_model_data_using_hash(model_path)
|
| 892 |
-
|
| 893 |
-
model_data_dict_size = len(model_data)
|
| 894 |
-
|
| 895 |
-
self.logger.info(f"Model downloaded, type: {model_type}, friendly name: {model_friendly_name}, model_path: {model_path}, model_data: {model_data_dict_size} items")
|
| 896 |
|
| 897 |
def get_simplified_model_list(self, filter_sort_by: Optional[str] = None):
|
| 898 |
-
"""
|
| 899 |
-
Returns a simplified, user-friendly list of models with their key metrics.
|
| 900 |
-
Optionally sorts the list based on the specified criteria.
|
| 901 |
-
|
| 902 |
-
:param sort_by: Criteria to sort by. Can be "name", "filename", or any stem name
|
| 903 |
-
"""
|
| 904 |
model_files = self.list_supported_model_files()
|
| 905 |
simplified_list = {}
|
| 906 |
|
| 907 |
for model_type, models in model_files.items():
|
| 908 |
for name, data in models.items():
|
| 909 |
filename = data["filename"]
|
| 910 |
-
scores = data.get("scores"
|
| 911 |
-
stems = data.get("stems"
|
| 912 |
target_stem = data.get("target_stem")
|
| 913 |
-
|
| 914 |
-
# Format stems with their SDR scores where available
|
| 915 |
stems_with_scores = []
|
| 916 |
stem_sdr_dict = {}
|
| 917 |
-
|
| 918 |
-
# Process each stem from the model's stem list
|
| 919 |
for stem in stems:
|
| 920 |
-
stem_scores = scores.get(stem, {})
|
| 921 |
-
# Add asterisk if this is the target stem
|
| 922 |
stem_display = f"{stem}*" if stem == target_stem else stem
|
| 923 |
-
|
| 924 |
-
if
|
| 925 |
-
sdr = round(
|
| 926 |
stems_with_scores.append(f"{stem_display} ({sdr})")
|
| 927 |
stem_sdr_dict[stem.lower()] = sdr
|
| 928 |
else:
|
| 929 |
-
# Include stem without SDR score
|
| 930 |
stems_with_scores.append(stem_display)
|
| 931 |
stem_sdr_dict[stem.lower()] = None
|
| 932 |
|
| 933 |
-
# If no stems listed, mark as Unknown
|
| 934 |
if not stems_with_scores:
|
| 935 |
stems_with_scores = ["Unknown"]
|
| 936 |
stem_sdr_dict["unknown"] = None
|
| 937 |
|
| 938 |
-
simplified_list[filename] = {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 939 |
|
| 940 |
-
# Sort and filter the list if a sort_by parameter is provided
|
| 941 |
if filter_sort_by:
|
| 942 |
if filter_sort_by == "name":
|
| 943 |
return dict(sorted(simplified_list.items(), key=lambda x: x[1]["Name"]))
|
| 944 |
elif filter_sort_by == "filename":
|
| 945 |
return dict(sorted(simplified_list.items()))
|
| 946 |
else:
|
| 947 |
-
# Convert sort_by to lowercase for case-insensitive comparison
|
| 948 |
sort_by_lower = filter_sort_by.lower()
|
| 949 |
-
# Filter out models that don't have the specified stem
|
| 950 |
filtered_list = {k: v for k, v in simplified_list.items() if sort_by_lower in v["SDR"]}
|
| 951 |
-
|
| 952 |
-
# Sort by SDR score if available, putting None values last
|
| 953 |
def sort_key(item):
|
| 954 |
-
sdr = item[1]["SDR"]
|
| 955 |
return (0 if sdr is None else 1, sdr if sdr is not None else float("-inf"))
|
| 956 |
-
|
| 957 |
return dict(sorted(filtered_list.items(), key=sort_key, reverse=True))
|
| 958 |
|
| 959 |
-
return simplified_list
|
|
|
|
| 22 |
from tqdm import tqdm
|
| 23 |
|
| 24 |
|
| 25 |
# Standard library imports
import hashlib
import importlib
import json
import logging
import os
import platform
import sys
import time
from io import BytesIO
from typing import Optional

# Third-party imports
import numpy as np
import onnxruntime as ort
import requests
import soundfile as sf
import torch
import torch.amp.autocast_mode as autocast_mode
import yaml
from tqdm import tqdm
|
| 41 |
+
|
| 42 |
class Separator:
|
| 43 |
"""
|
| 44 |
+
Optimized Separator class for audio source separation on Hugging Face Zero GPU.
|
| 45 |
+
Supports MDX, VR, Demucs, and MDXC architectures with ONNX Runtime and PyTorch.
|
| 46 |
+
Optimized for memory efficiency, fast inference, and serverless environments.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
"""
|
|
|
|
| 48 |
def __init__(
|
| 49 |
self,
|
| 50 |
log_level=logging.INFO,
|
|
|
|
| 51 |
model_file_dir="/tmp/audio-separator-models/",
|
| 52 |
+
output_dir="/tmp/audio_output/",
|
| 53 |
output_format="WAV",
|
| 54 |
output_bitrate=None,
|
| 55 |
normalization_threshold=0.9,
|
|
|
|
| 57 |
output_single_stem=None,
|
| 58 |
invert_using_spec=False,
|
| 59 |
sample_rate=44100,
|
| 60 |
+
use_soundfile=True,
|
| 61 |
+
use_autocast=True,
|
| 62 |
use_directml=False,
|
| 63 |
mdx_params={"hop_length": 1024, "segment_size": 256, "overlap": 0.25, "batch_size": 1, "enable_denoise": False},
|
| 64 |
vr_params={"batch_size": 1, "window_size": 512, "aggression": 5, "enable_tta": False, "enable_post_process": False, "post_process_threshold": 0.2, "high_end_process": False},
|
|
|
|
| 66 |
mdxc_params={"segment_size": 256, "override_model_segment_size": False, "batch_size": 1, "overlap": 8, "pitch_shift": 0},
|
| 67 |
info_only=False,
|
| 68 |
):
|
| 69 |
+
"""Initialize the separator for Zero GPU."""
|
| 70 |
self.logger = logging.getLogger(__name__)
|
| 71 |
self.logger.setLevel(log_level)
|
| 72 |
+
handler = logging.StreamHandler()
|
| 73 |
+
handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(message)s"))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
if not self.logger.hasHandlers():
|
| 75 |
+
self.logger.addHandler(handler)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 76 |
|
| 77 |
+
# Configuration
|
| 78 |
+
self.model_file_dir = os.environ.get("AUDIO_SEPARATOR_MODEL_DIR", model_file_dir)
|
| 79 |
+
self.output_dir = output_dir or os.getcwd()
|
| 80 |
+
self.output_format = output_format.upper()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
self.output_bitrate = output_bitrate
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
self.normalization_threshold = normalization_threshold
|
|
|
|
|
|
|
|
|
|
| 83 |
self.amplification_threshold = amplification_threshold
|
|
|
|
|
|
|
|
|
|
| 84 |
self.output_single_stem = output_single_stem
|
|
|
|
|
|
|
|
|
|
| 85 |
self.invert_using_spec = invert_using_spec
|
| 86 |
+
self.sample_rate = int(sample_rate)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
self.use_soundfile = use_soundfile
|
| 88 |
self.use_autocast = use_autocast
|
| 89 |
self.use_directml = use_directml
|
|
|
|
|
|
|
|
|
|
| 90 |
self.arch_specific_params = {"MDX": mdx_params, "VR": vr_params, "Demucs": demucs_params, "MDXC": mdxc_params}
|
| 91 |
|
| 92 |
+
# Validation
|
| 93 |
+
if not (0 < normalization_threshold <= 1):
|
| 94 |
+
raise ValueError("normalization_threshold must be in (0, 1]")
|
| 95 |
+
if not (0 <= amplification_threshold <= 1):
|
| 96 |
+
raise ValueError("amplification_threshold must be in [0, 1]")
|
| 97 |
+
if self.sample_rate <= 0 or self.sample_rate > 12800000:
|
| 98 |
+
raise ValueError("sample_rate must be a positive integer <= 12800000")
|
| 99 |
|
| 100 |
+
# Create directories
|
| 101 |
+
os.makedirs(self.model_file_dir, exist_ok=True)
|
| 102 |
+
os.makedirs(self.output_dir, exist_ok=True)
|
| 103 |
|
| 104 |
+
# Setup device
|
| 105 |
+
self.torch_device_cpu = torch.device("cpu")
|
| 106 |
+
self.torch_device_mps = torch.device("mps") if hasattr(torch.backends, "mps") and torch.backends.mps.is_available() else None
|
| 107 |
+
self.torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 108 |
+
self.onnx_execution_provider = ["CUDAExecutionProvider"] if "CUDAExecutionProvider" in ort.get_available_providers() else ["CPUExecutionProvider"]
|
| 109 |
+
|
| 110 |
+
if self.use_directml:
|
| 111 |
+
try:
|
| 112 |
+
import torch_directml
|
| 113 |
+
if torch_directml.is_available():
|
| 114 |
+
self.torch_device = torch_directml.device()
|
| 115 |
+
self.onnx_execution_provider = ["DmlExecutionProvider"] if "DmlExecutionProvider" in ort.get_available_providers() else ["CPUExecutionProvider"]
|
| 116 |
+
except ImportError:
|
| 117 |
+
self.logger.warning("torch_directml not installed, falling back to CPU")
|
| 118 |
+
self.torch_device = self.torch_device_cpu
|
| 119 |
+
|
| 120 |
+
self.logger.info(f"Using device: {self.torch_device}, ONNX provider: {self.onnx_execution_provider}")
|
| 121 |
+
self.model_instance = None
|
| 122 |
self.model_is_uvr_vip = False
|
| 123 |
self.model_friendly_name = None
|
| 124 |
|
| 125 |
if not info_only:
|
| 126 |
+
self.logger.info(f"Initialized Separator with model_dir: {self.model_file_dir}, output_dir: {self.output_dir}")
|
| 127 |
|
| 128 |
def setup_accelerated_inferencing_device(self):
    """Configure hardware acceleration.

    Preference order: CUDA (when the device chosen at init is already a
    CUDA device), Apple MPS on ARM machines, and finally plain CPU.
    Updates self.torch_device in the MPS and CPU fallback cases.
    """
    if self.torch_device.type == "cuda":
        self.logger.info("CUDA available, using GPU acceleration")
        return

    running_on_arm = "arm" in platform.machine().lower()
    if self.torch_device_mps and running_on_arm:
        self.torch_device = self.torch_device_mps
        self.logger.info("MPS available, using Apple Silicon acceleration")
        return

    self.logger.info("No GPU acceleration available, using CPU")
    self.torch_device = self.torch_device_cpu
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
|
| 139 |
def get_model_hash(self, model_path):
    """Calculate the MD5 hash of a model file.

    Files smaller than ~10 MB are hashed in full; for larger files only the
    trailing ~10 MB is hashed. Errors are logged and re-raised.
    """
    tail_bytes = 10000 * 1024  # hash at most this many trailing bytes
    try:
        with open(model_path, "rb") as model_file:
            total_size = os.path.getsize(model_path)
            if total_size >= tail_bytes:
                # Large file: position at the start of the final tail_bytes window.
                model_file.seek(total_size - tail_bytes)
            return hashlib.md5(model_file.read()).hexdigest()
    except Exception as e:
        self.logger.error(f"Error calculating hash for {model_path}: {e}")
        raise
|
| 152 |
|
| 153 |
def download_file_if_not_exists(self, url, output_path):
    """Download `url` to `output_path` unless the file already exists.

    Streams the response to a temporary ".part" file and atomically renames
    it into place on success. This fixes a latent corruption bug: the
    original wrote directly to `output_path`, so an interrupted download
    left a truncated file that the existence check on the next call would
    mistake for a complete download.

    Args:
        url (str): source URL.
        output_path (str): destination path on disk.

    Raises:
        RuntimeError: if the server responds with a non-200 status code.
    """
    if os.path.exists(output_path):
        self.logger.debug(f"File exists: {output_path}")
        return
    self.logger.info(f"Downloading {url} to {output_path}")
    response = requests.get(url, stream=True, timeout=60)
    if response.status_code != 200:
        raise RuntimeError(f"Failed to download {url}: {response.status_code}")
    partial_path = output_path + ".part"
    try:
        with open(partial_path, "wb") as f, tqdm(total=int(response.headers.get("content-length", 0)), unit="B", unit_scale=True) as pbar:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)
                pbar.update(len(chunk))
        # Atomic on POSIX when source and destination share a filesystem.
        os.replace(partial_path, output_path)
    except BaseException:
        # Remove the partial file so a retry starts from scratch.
        if os.path.exists(partial_path):
            os.remove(partial_path)
        raise
|
| 167 |
|
| 168 |
def list_supported_model_files(self):
    """Fetch supported model files from predefined sources.

    Downloads the UVR download-checks registry, then returns a dict mapping
    architecture type ("MDX", "VR", "Demucs", "MDXC") to per-model metadata:
    filename, SDR score dict, stem list, target stem, and the files that
    must be downloaded for the model to run.
    """
    download_checks_url = "https://raw.githubusercontent.com/TRvlvr/application_data/main/filelists/download_checks.json"
    download_checks_path = os.path.join(self.model_file_dir, "download_checks.json")
    self.download_file_if_not_exists(download_checks_url, download_checks_path)
    # Use a context manager so the registry file handle is closed promptly;
    # the original json.load(open(...)) leaked the handle.
    with open(download_checks_path, encoding="utf-8") as f:
        model_downloads_list = json.load(f)  # parsed to validate the download; not otherwise used yet

    # Mock model scores for simplicity (replace with actual model-scores.json if available)
    model_scores = {
        "UVR-MDX-NET-Inst_full_292.onnx": {
            "median_scores": {
                "vocals": {"SDR": 10.6497, "SIR": 20.3786, "SAR": 10.692, "ISR": 14.848},
                "instrumental": {"SDR": 15.2149, "SIR": 25.6075, "SAR": 17.1363, "ISR": 17.7893}
            },
            "stems": ["vocals", "instrumental"],
            "target_stem": "vocals"
        },
        "htdemucs_ft.yaml": {
            "median_scores": {
                "vocals": {"SDR": 11.2685, "SIR": 21.257, "SAR": 11.0359, "ISR": 19.3753},
                "drums": {"SDR": 13.235, "SIR": 23.3053, "SAR": 13.0313, "ISR": 17.2889},
                "bass": {"SDR": 9.72743, "SIR": 19.5435, "SAR": 9.20801, "ISR": 13.5037}
            },
            "stems": ["vocals", "drums", "bass"],
            "target_stem": "vocals"
        },
        "MDX23C-8KFFT-InstVoc_HQ.ckpt": {
            "median_scores": {
                "vocals": {"SDR": 11.9504, "SIR": 23.1166, "SAR": 12.093, "ISR": 15.4782},
                "instrumental": {"SDR": 16.3035, "SIR": 26.6161, "SAR": 18.5167, "ISR": 18.3939}
            },
            "stems": ["vocals", "instrumental"],
            "target_stem": "vocals"
        }
    }

    public_model_repo_url_prefix = "https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models"
    audio_separator_models_repo_url_prefix = "https://github.com/nomadkaraoke/python-audio-separator/releases/download/model-configs"

    # Simplified model list for MDX, VR, Demucs, MDXC
    model_files_grouped_by_type = {
        "MDX": {
            "MDX-Net Model: UVR-MDX-NET-Inst_full_292": {
                "filename": "UVR-MDX-NET-Inst_full_292.onnx",
                "scores": model_scores.get("UVR-MDX-NET-Inst_full_292.onnx", {}).get("median_scores", {}),
                "stems": model_scores.get("UVR-MDX-NET-Inst_full_292.onnx", {}).get("stems", []),
                "target_stem": model_scores.get("UVR-MDX-NET-Inst_full_292.onnx", {}).get("target_stem"),
                "download_files": ["UVR-MDX-NET-Inst_full_292.onnx"]
            }
        },
        "VR": {
            "VR Model: UVR-VR-Model": {
                "filename": "UVR-VR-Model.onnx",
                "scores": {},
                "stems": ["vocals", "instrumental"],
                "target_stem": "vocals",
                "download_files": ["UVR-VR-Model.onnx"]
            }
        },
        "Demucs": {
            "Demucs v4: htdemucs_ft": {
                "filename": "htdemucs_ft.yaml",
                "scores": model_scores.get("htdemucs_ft.yaml", {}).get("median_scores", {}),
                "stems": model_scores.get("htdemucs_ft.yaml", {}).get("stems", []),
                "target_stem": model_scores.get("htdemucs_ft.yaml", {}).get("target_stem"),
                "download_files": [
                    f"{public_model_repo_url_prefix}/htdemucs_ft.yaml",
                    "https://dl.fbaipublicfiles.com/demucs/hybrid_transformer/f7e0c4bc-ba3fe64a.th"
                ]
            }
        },
        "MDXC": {
            "MDX23C Model: MDX23C-InstVoc HQ": {
                "filename": "MDX23C-8KFFT-InstVoc_HQ.ckpt",
                "scores": model_scores.get("MDX23C-8KFFT-InstVoc_HQ.ckpt", {}).get("median_scores", {}),
                "stems": model_scores.get("MDX23C-8KFFT-InstVoc_HQ.ckpt", {}).get("stems", []),
                "target_stem": model_scores.get("MDX23C-8KFFT-InstVoc_HQ.ckpt", {}).get("target_stem"),
                "download_files": [
                    "MDX23C-8KFFT-InstVoc_HQ.ckpt",
                    f"{audio_separator_models_repo_url_prefix}/model_2_stem_full_band_8k.yaml"
                ]
            }
        }
    }
    return model_files_grouped_by_type
|
| 253 |
|
| 254 |
def print_uvr_vip_message(self):
    """Log a support reminder when the currently-loaded model is a UVR VIP model."""
    if not self.model_is_uvr_vip:
        return
    self.logger.warning(f"Model '{self.model_friendly_name}' is a VIP model. Consider supporting UVR at https://patreon.com/uvr")
|
|
|
| 258 |
|
| 259 |
def download_model_files(self, model_filename):
    """Download model files and return metadata.

    Scans the supported-model registry for an entry whose filename matches
    `model_filename` (or that lists it among its download files), downloads
    every required file, and returns a tuple:
    (model_filename, model_type, model_friendly_name, model_path,
    yaml_config_filename) — yaml_config_filename is None when the model has
    no YAML config.

    Side effects: sets self.model_is_uvr_vip and self.model_friendly_name.

    Raises:
        ValueError: if the filename matches no supported model.
    """
    model_path = os.path.join(self.model_file_dir, model_filename)
    supported_models = self.list_supported_model_files()
    public_model_repo_url_prefix = "https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models"
    vip_model_repo_url_prefix = "https://github.com/Anjok0109/ai_magic/releases/download/v5"
    audio_separator_models_repo_url_prefix = "https://github.com/nomadkaraoke/python-audio-separator/releases/download/model-configs"

    yaml_config_filename = None
    for model_type, models in supported_models.items():
        for model_friendly_name, model_info in models.items():
            # NOTE: set for every candidate during the scan; they only describe
            # the matched model once the filename check below succeeds.
            self.model_is_uvr_vip = "VIP" in model_friendly_name
            model_repo_url_prefix = vip_model_repo_url_prefix if self.model_is_uvr_vip else public_model_repo_url_prefix
            if model_info["filename"] == model_filename or model_filename in model_info["download_files"]:
                self.model_friendly_name = model_friendly_name
                self.print_uvr_vip_message()
                for file_to_download in model_info["download_files"]:
                    if file_to_download.startswith("http"):
                        # Full URL: download directly, storing under the URL basename.
                        filename = file_to_download.split("/")[-1]
                        download_path = os.path.join(self.model_file_dir, filename)
                        self.download_file_if_not_exists(file_to_download, download_path)
                        if file_to_download.endswith(".yaml"):
                            yaml_config_filename = filename
                        continue
                    # Bare filename: try the (public or VIP) UVR repo first, then
                    # fall back to the audio-separator model-configs repo.
                    download_path = os.path.join(self.model_file_dir, file_to_download)
                    try:
                        self.download_file_if_not_exists(f"{model_repo_url_prefix}/{file_to_download}", download_path)
                    except RuntimeError:
                        self.download_file_if_not_exists(f"{audio_separator_models_repo_url_prefix}/{file_to_download}", download_path)
                    if file_to_download.endswith(".yaml"):
                        yaml_config_filename = file_to_download
                return model_filename, model_type, model_friendly_name, model_path, yaml_config_filename
    raise ValueError(f"Model file {model_filename} not found")
|
|
|
|
| 292 |
|
| 293 |
def load_model_data_from_yaml(self, yaml_config_filename):
    """Load model parameters from a YAML config in the model directory.

    Tags the loaded config with `is_roformer` when the filename indicates a
    Roformer-family model. Any failure is logged and re-raised.

    NOTE(review): this uses yaml.FullLoader rather than safe_load — confirm
    the upstream config files need FullLoader before tightening.
    """
    config_path = os.path.join(self.model_file_dir, yaml_config_filename)
    try:
        with open(config_path, encoding="utf-8") as config_file:
            model_data = yaml.load(config_file, Loader=yaml.FullLoader)
        self.logger.debug(f"Model data loaded from YAML: {model_data}")
        if "roformer" in yaml_config_filename.lower():
            model_data["is_roformer"] = True
        return model_data
    except Exception as e:
        self.logger.error(f"Failed to load YAML {yaml_config_filename}: {e}")
        raise
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 306 |
|
| 307 |
def load_model_data_using_hash(self, model_path):
    """Load model parameters by looking up the model file's MD5 hash.

    Downloads the UVR VR and MDX model-data registries, merges them, and
    returns the entry keyed by this model's hash.

    Bug fix: both registry URLs end in "model_data_new.json", so saving each
    under os.path.basename(url) made them collide — the second registry was
    never downloaded (the existence check saw the first file) and its data
    was never merged. Local filenames are now qualified with the URL's
    parent directory.

    Raises:
        ValueError: if no registry contains the model's hash.
    """
    model_data_urls = [
        "https://raw.githubusercontent.com/TRvlvr/application_data/main/vr_model_data/model_data_new.json",
        "https://raw.githubusercontent.com/TRvlvr/application_data/main/mdx_model_data/model_data_new.json"
    ]
    model_data = {}
    model_hash = self.get_model_hash(model_path)
    for url in model_data_urls:
        # e.g. "vr_model_data_model_data_new.json" — unique per registry.
        url_parts = url.rstrip("/").split("/")
        local_filename = f"{url_parts[-2]}_{url_parts[-1]}"
        model_data_path = os.path.join(self.model_file_dir, local_filename)
        self.download_file_if_not_exists(url, model_data_path)
        with open(model_data_path, encoding="utf-8") as f:
            model_data.update(json.load(f))
    if model_hash in model_data:
        self.logger.debug(f"Model data loaded for hash {model_hash}")
        return model_data[model_hash]
    raise ValueError(f"No model data for hash {model_hash}")
|
| 324 |
+
|
| 325 |
+
def load_model(self, model_filename="UVR-MDX-NET-Inst_full_292.onnx"):
|
| 326 |
+
"""Load model based on architecture."""
|
| 327 |
+
self.logger.info(f"Loading model {model_filename}")
|
| 328 |
+
start_time = time.perf_counter()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 329 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 330 |
model_filename, model_type, model_friendly_name, model_path, yaml_config_filename = self.download_model_files(model_filename)
|
| 331 |
model_name = model_filename.split(".")[0]
|
|
|
|
| 332 |
|
| 333 |
+
model_data = self.load_model_data_from_yaml(yaml_config_filename) if yaml_config_filename else self.load_model_data_using_hash(model_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 334 |
|
| 335 |
common_params = {
|
| 336 |
"logger": self.logger,
|
|
|
|
| 353 |
"use_soundfile": self.use_soundfile,
|
| 354 |
}
|
| 355 |
|
| 356 |
+
separator_classes = {
|
| 357 |
+
"MDX": "mdx_separator.MDXSeparator",
|
| 358 |
+
"VR": "vr_separator.VRSeparator",
|
| 359 |
+
"Demucs": "demucs_separator.DemucsSeparator",
|
| 360 |
+
"MDXC": "mdxc_separator.MDXCSeparator"
|
| 361 |
+
}
|
| 362 |
|
| 363 |
+
if model_type not in separator_classes:
|
| 364 |
+
raise ValueError(f"Unsupported model type: {model_type}")
|
| 365 |
if model_type == "Demucs" and sys.version_info < (3, 10):
|
| 366 |
+
raise Exception("Demucs requires Python 3.10 or newer")
|
|
|
|
|
|
|
| 367 |
|
| 368 |
module_name, class_name = separator_classes[model_type].split(".")
|
| 369 |
+
try:
|
| 370 |
+
module = importlib.import_module(f"audio_separator.separator.architectures.{module_name}")
|
| 371 |
+
separator_class = getattr(module, class_name)
|
| 372 |
+
self.model_instance = separator_class(common_config=common_params, arch_config=self.arch_specific_params[model_type])
|
| 373 |
+
except ImportError as e:
|
| 374 |
+
self.logger.error(f"Failed to load module {module_name}: {e}")
|
| 375 |
+
raise
|
| 376 |
+
|
| 377 |
+
self.logger.info(f"Model loaded in {time.perf_counter() - start_time:.2f} seconds")
|
| 378 |
+
|
| 379 |
+
def preprocess_audio(self, audio_data, sample_rate):
    """Preprocess audio: resample, normalize, and convert to a tensor.

    - Resamples by linear interpolation to `self.sample_rate` when needed.
    - Attenuates so the peak does not exceed `self.normalization_threshold`.
    - Amplifies quiet waveforms up to `self.amplification_threshold`.

    Bug fix: the original unconditionally rescaled every non-silent waveform
    to the normalization threshold (amplifying quiet audio), then compared
    the *stale* pre-scale peak against the amplification threshold and
    scaled a second time. Attenuation and amplification are now mutually
    exclusive and only applied when actually needed.

    Args:
        audio_data (np.ndarray): mono waveform samples.
        sample_rate (int): sample rate of `audio_data`.

    Returns:
        torch.Tensor: float32 waveform on `self.torch_device`.
    """
    if sample_rate != self.sample_rate:
        self.logger.debug(f"Resampling from {sample_rate} to {self.sample_rate} Hz")
        target_length = int(len(audio_data) * self.sample_rate / sample_rate)
        # Map new sample positions onto [0, len-1] so they line up with the
        # np.arange x-coordinates (the original used len as the upper bound,
        # landing past the last valid index).
        audio_data = np.interp(
            np.linspace(0, len(audio_data) - 1, target_length),
            np.arange(len(audio_data)),
            audio_data,
        )
    max_amplitude = np.max(np.abs(audio_data)) if len(audio_data) else 0.0
    if max_amplitude > self.normalization_threshold:
        # Lower the peak to the threshold to avoid clipping.
        audio_data = audio_data * (self.normalization_threshold / max_amplitude)
    elif 0 < max_amplitude < self.amplification_threshold:
        # Raise quiet audio up to the amplification threshold.
        audio_data = audio_data * (self.amplification_threshold / max_amplitude)
    return torch.tensor(audio_data, dtype=torch.float32, device=self.torch_device)
|
| 394 |
|
| 395 |
def separate(self, audio_file_path, custom_output_names=None):
    """Separate one or more audio files into stems.

    Args:
        audio_file_path (str | list[str]): a file path, a directory
            (processed recursively), or a list of either.
        custom_output_names (dict, optional): custom names for output files.

    Returns:
        list[str]: paths of all separated stem files.

    Raises:
        ValueError: if no model has been loaded.
    """
    if not self.model_instance:
        raise ValueError("Model not loaded")

    self.logger.info(f"Separating audio: {audio_file_path}")
    start_time = time.perf_counter()

    if isinstance(audio_file_path, str):
        audio_file_path = [audio_file_path]

    # ".ac3" restored — it was supported before the refactor.
    supported_extensions = (".wav", ".flac", ".mp3", ".ogg", ".opus", ".m4a", ".aiff", ".ac3")
    output_files = []
    for path in audio_file_path:
        if os.path.isdir(path):
            for root, _, files in os.walk(path):
                for file in files:
                    if file.endswith(supported_extensions):
                        full_path = os.path.join(root, file)
                        # In batch mode one bad file must not abort the whole
                        # batch: log the failure and continue with the rest.
                        try:
                            output_files.extend(self._separate_file(full_path, custom_output_names))
                        except Exception as e:
                            self.logger.error(f"Failed to process file {full_path}: {e}")
        else:
            # Single explicit file: let failures propagate to the caller.
            output_files.extend(self._separate_file(path, custom_output_names))

    self.print_uvr_vip_message()
    self.logger.info(f"Separation completed in {time.perf_counter() - start_time:.2f} seconds")
    return output_files
|
| 420 |
|
| 421 |
def _separate_file(self, audio_file_path, custom_output_names=None):
    """Internal method to separate a single audio file.

    Reads the file with soundfile, downmixes multi-channel audio to mono by
    averaging channels, preprocesses it into a tensor, and runs the loaded
    model's separate() under autocast when available. If the model returns
    nothing, writes a mock two-stem output instead.

    NOTE(review): the tensor is passed straight to
    self.model_instance.separate(); the architecture classes' expected input
    type (tensor vs file path) is not visible here — confirm the contract.

    Returns:
        list[str]: paths of the written stem files.

    Raises:
        Exception: re-raises any read or separation failure after logging.
    """
    self.logger.debug(f"Processing file: {audio_file_path}")
    try:
        audio_data, input_sample_rate = sf.read(audio_file_path)
        if len(audio_data.shape) > 1:
            # Downmix multi-channel audio to mono (mean across channels).
            audio_data = np.mean(audio_data, axis=1)
    except Exception as e:
        self.logger.error(f"Failed to read audio: {e}")
        raise

    audio_tensor = self.preprocess_audio(audio_data, input_sample_rate)

    output_files = []
    try:
        # Run under autocast for faster mixed-precision inference when the
        # current device supports it.
        if self.use_autocast and autocast_mode.is_autocast_available(self.torch_device.type):
            with autocast_mode.autocast(self.torch_device.type):
                output_files = self.model_instance.separate(audio_tensor, custom_output_names)
        else:
            output_files = self.model_instance.separate(audio_tensor, custom_output_names)
    except Exception as e:
        self.logger.error(f"Separation failed: {e}")
        raise

    # Mock output for architectures not implemented (replace with actual logic in separator classes)
    # NOTE(review): this writes the UNSEPARATED input under stem names — it is a
    # placeholder, not real separation.
    if not output_files:
        stem_names = ["vocals", "instrumental"]  # Adjust based on model
        output_files = []
        for stem in stem_names:
            output_path = os.path.join(self.output_dir, f"{os.path.splitext(os.path.basename(audio_file_path))[0]}_{stem}.{self.output_format.lower()}")
            sf.write(output_path, audio_data, self.sample_rate)
            output_files.append(output_path)

    # Free GPU memory and reset per-file state on the model instance.
    self.model_instance.clear_gpu_cache()
    self.model_instance.clear_file_specific_paths()
    return output_files
|
| 457 |
|
| 458 |
def download_model_and_data(self, model_filename):
    """Download model files without loading into memory."""
    self.logger.info(f"Downloading model {model_filename}")

    # download_model_files resolves the canonical filename, architecture
    # type, friendly name, local path, and (optionally) a YAML config name.
    downloaded = self.download_model_files(model_filename)
    model_filename, model_type, model_friendly_name, model_path, yaml_config_filename = downloaded

    # YAML config takes precedence; otherwise fall back to a hash lookup.
    if yaml_config_filename:
        model_data = self.load_model_data_from_yaml(yaml_config_filename)
    else:
        model_data = self.load_model_data_using_hash(model_path)

    self.logger.info(f"Model downloaded: {model_friendly_name}, type: {model_type}, path: {model_path}, data items: {len(model_data)}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 464 |
|
| 465 |
def get_simplified_model_list(self, filter_sort_by: Optional[str] = None):
    """Return a simplified list of models.

    Builds a dict keyed by model filename. Each entry holds the friendly
    name, the architecture type, a display list of stems (target stem marked
    with "*", rounded SDR appended when known), and a lowercase stem->SDR
    lookup table.

    Args:
        filter_sort_by: "name" or "filename" to sort the full list; any
            other string is treated as a stem name — the list is filtered to
            models exposing that stem and sorted by its SDR, best first
            (unscored entries last). None returns the unsorted mapping.
    """
    catalog = {}
    for arch_type, models in self.list_supported_model_files().items():
        for friendly_name, info in models.items():
            target = info.get("target_stem")
            labels = []
            sdr_by_stem = {}
            for stem in info.get("stems", []):
                label = stem + "*" if stem == target else stem
                score = info.get("scores", {}).get(stem, {}).get("SDR")
                if score is None:
                    labels.append(label)
                    sdr_by_stem[stem.lower()] = None
                else:
                    score = round(score, 1)
                    labels.append(f"{label} ({score})")
                    sdr_by_stem[stem.lower()] = score
            if not labels:
                # No stem metadata at all — keep a sentinel entry.
                labels = ["Unknown"]
                sdr_by_stem = {"unknown": None}
            catalog[info["filename"]] = {
                "Name": friendly_name,
                "Type": arch_type,
                "Stems": labels,
                "SDR": sdr_by_stem,
            }

    if not filter_sort_by:
        return catalog
    if filter_sort_by == "name":
        return dict(sorted(catalog.items(), key=lambda kv: kv[1]["Name"]))
    if filter_sort_by == "filename":
        return dict(sorted(catalog.items()))

    # Anything else is a stem filter: keep only models that expose the stem.
    stem_key = filter_sort_by.lower()
    matching = {fn: entry for fn, entry in catalog.items() if stem_key in entry["SDR"]}

    def rank(kv):
        # With reverse=True: scored models first (descending SDR), then
        # unscored ones — the (flag, value) tuple encodes both tiers.
        score = kv[1]["SDR"].get(stem_key)
        return (0 if score is None else 1, score if score is not None else float("-inf"))

    return dict(sorted(matching.items(), key=rank, reverse=True))
|