Spaces:
Running
Running
#!/usr/bin/env python3 | |
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang) | |
import logging | |
import os | |
from functools import lru_cache | |
import numpy as np | |
import sherpa_onnx | |
import soundfile as sf | |
from huggingface_hub import hf_hub_download | |
import uuid | |
def convert_to_wav(in_filename: str) -> str:
    """Convert the input audio file to a wave file.

    Invokes ``ffmpeg`` on ``in_filename`` with ``-ar 44100 -ac 2`` and writes
    the result to ``f"{in_filename}.wav"``, overwriting an existing file
    (``-y``).

    Args:
      in_filename:
        Path of the audio file to convert.

    Returns:
      The path of the output wave file. The path is returned even if ffmpeg
      could not be started or exited with a non-zero status; such failures
      are logged instead of raised.
    """
    # Local import so this function does not depend on a file-level import
    # being added elsewhere.
    import subprocess

    out_filename = f"{in_filename}.wav"
    logging.info(f"Converting '{in_filename}' to '{out_filename}'")

    # Pass the arguments as a list (no shell involved) so that quotes or
    # other shell metacharacters in the filename cannot be interpreted as
    # commands.
    cmd = [
        "ffmpeg",
        "-hide_banner",
        "-loglevel",
        "error",
        "-y",
        "-i",
        in_filename,
        "-ar",
        "44100",
        "-ac",
        "2",
        out_filename,
    ]

    try:
        return_code = subprocess.run(cmd, check=False).returncode
    except OSError as ex:
        # E.g. ffmpeg is not installed. Keep the best-effort behavior of
        # returning the output path, but do not hide the failure.
        logging.error(f"Failed to run ffmpeg: {ex}")
    else:
        if return_code != 0:
            logging.error(
                f"ffmpeg exited with status {return_code} "
                f"while converting '{in_filename}'"
            )

    return out_filename
def load_audio(filename):
    """Read an audio file and return its samples and sample rate.

    The file is first passed through ``convert_to_wav`` and the resulting
    wave file is read as float32 with one column per channel. The array is
    then transposed.

    Args:
      filename:
        Path of the audio file to load.

    Returns:
      A tuple ``(samples, sample_rate)`` where ``samples`` is a float32
      array of shape (num_channels, num_samples).
    """
    wav_path = convert_to_wav(filename)
    frames, sample_rate = sf.read(wav_path, dtype="float32", always_2d=True)

    # sf.read(always_2d=True) yields (num_samples, num_channels);
    # flip it to (num_channels, num_samples).
    samples = np.transpose(frames)

    num_rows, num_cols = samples.shape[0], samples.shape[1]
    assert (
        num_cols > num_rows
    ), f"You should use (num_channels, num_samples). {samples.shape}"

    assert (
        samples.dtype == np.float32
    ), f"Expect np.float32 as dtype. Given: {samples.dtype}"

    return samples, sample_rate
def get_file(
    repo_id: str,
    filename: str,
    subfolder: str = ".",
) -> str:
    """Fetch a file via ``hf_hub_download`` and return what it returns.

    Args:
      repo_id:
        The Hugging Face repository ID passed to ``hf_hub_download``.
      filename:
        Name of the file inside the repository.
      subfolder:
        Subfolder of the repository that contains the file.

    Returns:
      The value returned by ``hf_hub_download`` for the requested file.
    """
    return hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        subfolder=subfolder,
    )
def load_model(name: str):
    """Create a source separation model from a model name.

    Args:
      name:
        A model name, optionally followed by ``|`` and extra text. Only the
        part before the first ``|`` is used.

    Returns:
      The object returned by ``load_spleeter_model`` if the name contains
      "spleeter", otherwise the one returned by ``load_uvr_model`` if the
      name contains "UVR".

    Raises:
      ValueError: If the name matches neither family.
    """
    # Drop everything after the first "|".
    name, _, _ = name.partition("|")

    if "spleeter" in name:
        return load_spleeter_model(name)

    if "UVR" in name:
        return load_uvr_model(name)

    raise ValueError(f"Unsupported model name {name}")
def load_uvr_model(name: str):
    """Create a source separation object backed by a UVR model.

    Args:
      name:
        Filename of the model inside the ``source-separation-models``
        subfolder of the ``k2-fsa/sherpa-onnx-models`` repository.

    Returns:
      A ``sherpa_onnx.OfflineSourceSeparation`` built with 2 threads on the
      "cpu" provider.

    Raises:
      ValueError: If the assembled config fails ``validate()``.
    """
    model_path = get_file(
        repo_id="k2-fsa/sherpa-onnx-models",
        subfolder="source-separation-models",
        filename=name,
    )

    # Assemble the nested config from the inside out.
    uvr_config = sherpa_onnx.OfflineSourceSeparationUvrModelConfig(
        model=model_path,
    )
    model_config = sherpa_onnx.OfflineSourceSeparationModelConfig(
        uvr=uvr_config,
        num_threads=2,
        debug=False,
        provider="cpu",
    )
    config = sherpa_onnx.OfflineSourceSeparationConfig(model=model_config)

    if config.validate():
        return sherpa_onnx.OfflineSourceSeparation(config)

    raise ValueError("Please check your config.")
def load_spleeter_model(name: str):
    """Create a source separation object backed by a Spleeter model pair.

    The vocals and accompaniment model files are fetched from the
    ``csukuangfj/{name}`` repository. The file extension depends on the
    name: "fp16.onnx" if it contains "fp16", else "int8.onnx" if it
    contains "int8", else "onnx".

    Args:
      name:
        Name of the model repository under ``csukuangfj``.

    Returns:
      A ``sherpa_onnx.OfflineSourceSeparation`` built with 2 threads on the
      "cpu" provider.

    Raises:
      ValueError: If the assembled config fails ``validate()``.
    """
    # "fp16" takes precedence over "int8" when both appear in the name.
    suffix = "onnx"
    for precision in ("fp16", "int8"):
        if precision in name:
            suffix = f"{precision}.onnx"
            break

    repo_id = f"csukuangfj/{name}"
    vocals_path = get_file(repo_id=repo_id, filename=f"vocals.{suffix}")
    accompaniment_path = get_file(
        repo_id=repo_id, filename=f"accompaniment.{suffix}"
    )

    # Assemble the nested config from the inside out.
    spleeter_config = sherpa_onnx.OfflineSourceSeparationSpleeterModelConfig(
        vocals=vocals_path,
        accompaniment=accompaniment_path,
    )
    model_config = sherpa_onnx.OfflineSourceSeparationModelConfig(
        spleeter=spleeter_config,
        num_threads=2,
        debug=False,
        provider="cpu",
    )
    config = sherpa_onnx.OfflineSourceSeparationConfig(model=model_config)

    if config.validate():
        return sherpa_onnx.OfflineSourceSeparation(config)

    raise ValueError("Please check your config.")
# Available models, each written as "<model name>|<speed label>".
# load_model() uses only the part before "|".
model_list = [
    f"{model}|{speed}"
    for speed, models in (
        (
            "fastest",
            (
                "sherpa-onnx-spleeter-2stems",
                "sherpa-onnx-spleeter-2stems-fp16",
                "sherpa-onnx-spleeter-2stems-int8",
            ),
        ),
        (
            "slow",
            (
                "UVR_MDXNET_1_9703.onnx",
                "UVR_MDXNET_2_9682.onnx",
                "UVR_MDXNET_3_9662.onnx",
                "UVR_MDXNET_9482.onnx",
                "UVR_MDXNET_KARA.onnx",
            ),
        ),
        (
            "slowest",
            (
                "UVR_MDXNET_KARA_2.onnx",
                "UVR_MDXNET_Main.onnx",
                "UVR-MDX-NET-Inst_1.onnx",
                "UVR-MDX-NET-Inst_2.onnx",
                "UVR-MDX-NET-Inst_3.onnx",
                "UVR-MDX-NET-Inst_HQ_1.onnx",
                "UVR-MDX-NET-Inst_HQ_2.onnx",
                "UVR-MDX-NET-Inst_HQ_3.onnx",
                "UVR-MDX-NET-Inst_HQ_4.onnx",
                "UVR-MDX-NET-Inst_HQ_5.onnx",
                "UVR-MDX-NET-Inst_Main.onnx",
                "UVR-MDX-NET-Voc_FT.onnx",
                "UVR-MDX-NET_Crowd_HQ_1.onnx",
            ),
        ),
    )
    for model in models
]