source-separation / separate.py
csukuangfj's picture
add model info
d46145c
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
import logging
import os
import subprocess
import uuid
from functools import lru_cache

import numpy as np
import sherpa_onnx
import soundfile as sf
from huggingface_hub import hf_hub_download
def convert_to_wav(in_filename: str) -> str:
    """Convert the input audio file to a 44.1 kHz stereo wave file with ffmpeg.

    The output is written next to the input as ``<in_filename>.wav``; an
    existing file at that path is overwritten.

    Args:
      in_filename:
        Path to the input audio file.
    Returns:
      Path to the generated wave file.
    Raises:
      RuntimeError:
        If the ffmpeg executable cannot be found or the conversion fails.
    """
    out_filename = f"{in_filename}.wav"
    logging.info(f"Converting '{in_filename}' to '{out_filename}'")

    # Pass the arguments as a list (no shell) so that file names containing
    # quotes, spaces or shell metacharacters are neither mis-parsed nor
    # executed. "-y" goes before the output so ffmpeg does not treat it as
    # an ignorable trailing option.
    cmd = [
        "ffmpeg",
        "-hide_banner",
        "-loglevel",
        "error",
        "-y",
        "-i",
        in_filename,
        "-ar",
        "44100",
        "-ac",
        "2",
        out_filename,
    ]
    try:
        subprocess.run(
            cmd,
            check=True,
            stdin=subprocess.DEVNULL,  # never block on an interactive prompt
            capture_output=True,
            text=True,
            errors="replace",
        )
    except FileNotFoundError as e:
        raise RuntimeError("ffmpeg not found; please install ffmpeg") from e
    except subprocess.CalledProcessError as e:
        detail = (e.stderr or "").strip()
        raise RuntimeError(
            f"ffmpeg failed to convert '{in_filename}' (exit code {e.returncode}): {detail}"
        ) from e
    return out_filename
def load_audio(filename):
    """Decode an audio file into a float32 sample array.

    The file is first converted to a wave file with ``convert_to_wav``; that
    intermediate file is deleted once it has been read.

    Args:
      filename:
        Path to the input audio file.
    Returns:
      A tuple ``(samples, sample_rate)`` where ``samples`` is a float32 array
      of shape ``(num_channels, num_samples)``.
    """
    wav_filename = convert_to_wav(filename)
    try:
        samples, sample_rate = sf.read(wav_filename, dtype="float32", always_2d=True)
    finally:
        # The converted wav is only needed for this read. Remove it so repeated
        # requests do not accumulate files on disk. Cleanup is best-effort:
        # a failure here must not mask the read result or its error.
        try:
            os.remove(wav_filename)
        except OSError as e:
            logging.warning(f"Failed to remove temporary file '{wav_filename}': {e}")

    # sf.read(always_2d=True) yields (num_samples, num_channels); transpose it.
    samples = np.transpose(samples)
    # now samples is of shape (num_channels, num_samples)
    assert (
        samples.shape[1] > samples.shape[0]
    ), f"You should use (num_channels, num_samples). {samples.shape}"
    assert (
        samples.dtype == np.float32
    ), f"Expect np.float32 as dtype. Given: {samples.dtype}"
    return samples, sample_rate
@lru_cache(maxsize=10)
def get_file(
    repo_id: str,
    filename: str,
    subfolder: str = ".",
) -> str:
    """Return a local path for ``filename`` from a Hugging Face Hub repository.

    Results are memoized per process by ``lru_cache``, so repeated requests
    for the same ``(repo_id, filename, subfolder)`` do not call
    ``hf_hub_download`` again.

    Args:
      repo_id:
        Hub repository identifier, e.g. ``"k2-fsa/sherpa-onnx-models"``.
      filename:
        Name of the file inside the repository.
      subfolder:
        Folder inside the repository that contains ``filename``.
    Returns:
      The path returned by ``hf_hub_download``.
    """
    return hf_hub_download(repo_id=repo_id, subfolder=subfolder, filename=filename)
@lru_cache(maxsize=30)
def load_model(name: str):
    """Create the source-separation engine for a ``model_list`` entry.

    Only the text before the first ``|`` identifies the model; anything after
    it is ignored here. The cache key is the full ``name`` string.

    Args:
      name:
        An entry such as ``"sherpa-onnx-spleeter-2stems|fastest"``.
    Returns:
      The object built by ``load_spleeter_model`` or ``load_uvr_model``.
    Raises:
      ValueError:
        If the identifier names neither a spleeter nor a UVR model.
    """
    model_name, _, _ = name.partition("|")

    # "spleeter" is checked first, so a name containing both keywords is
    # routed to the spleeter loader.
    if "spleeter" in model_name:
        loader = load_spleeter_model
    elif "UVR" in model_name:
        loader = load_uvr_model
    else:
        raise ValueError(f"Unsupported model name {model_name}")
    return loader(model_name)
def load_uvr_model(name: str):
    """Create an offline source separator backed by one UVR ONNX model.

    Args:
      name:
        File name of the model inside the ``source-separation-models``
        folder of the ``k2-fsa/sherpa-onnx-models`` Hub repository.
    Returns:
      A ``sherpa_onnx.OfflineSourceSeparation`` running on CPU with 2 threads.
    Raises:
      ValueError:
        If ``config.validate()`` rejects the configuration.
    """
    uvr_config = sherpa_onnx.OfflineSourceSeparationUvrModelConfig(
        model=get_file(
            repo_id="k2-fsa/sherpa-onnx-models",
            subfolder="source-separation-models",
            filename=name,
        ),
    )
    model_config = sherpa_onnx.OfflineSourceSeparationModelConfig(
        uvr=uvr_config,
        num_threads=2,
        debug=False,
        provider="cpu",
    )
    config = sherpa_onnx.OfflineSourceSeparationConfig(model=model_config)

    if config.validate():
        return sherpa_onnx.OfflineSourceSeparation(config)
    raise ValueError("Please check your config.")
def load_spleeter_model(name: str):
    """Create an offline source separator backed by a 2-stem spleeter model.

    The repository ``csukuangfj/<name>`` holds one ONNX file per stem
    (``vocals`` and ``accompaniment``). The file variant is chosen from
    ``name``: ``fp16.onnx`` if it contains "fp16", otherwise ``int8.onnx``
    if it contains "int8", otherwise plain ``onnx``.

    Args:
      name:
        Hub repository name under ``csukuangfj``.
    Returns:
      A ``sherpa_onnx.OfflineSourceSeparation`` running on CPU with 2 threads.
    Raises:
      ValueError:
        If ``config.validate()`` rejects the configuration.
    """
    suffix = "onnx"
    if "fp16" in name:
        suffix = "fp16.onnx"
    elif "int8" in name:
        suffix = "int8.onnx"

    repo_id = f"csukuangfj/{name}"
    # Download order (vocals, then accompaniment) matches the stem keys below.
    stem_paths = {
        stem: get_file(repo_id=repo_id, filename=f"{stem}.{suffix}")
        for stem in ("vocals", "accompaniment")
    }

    model_config = sherpa_onnx.OfflineSourceSeparationModelConfig(
        spleeter=sherpa_onnx.OfflineSourceSeparationSpleeterModelConfig(**stem_paths),
        num_threads=2,
        debug=False,
        provider="cpu",
    )
    config = sherpa_onnx.OfflineSourceSeparationConfig(model=model_config)

    if config.validate():
        return sherpa_onnx.OfflineSourceSeparation(config)
    raise ValueError("Please check your config.")
# Selectable models, each written as "<model identifier>|<speed label>".
# load_model() keeps only the text before the first "|":
#   - identifiers containing "spleeter" are Hub repos under csukuangfj/
#     (see load_spleeter_model);
#   - identifiers containing "UVR" are ONNX file names in
#     k2-fsa/sherpa-onnx-models/source-separation-models (see load_uvr_model).
# The label after "|" is not used by load_model(); presumably it is a
# relative-speed hint shown to users — confirm against the UI code.
model_list = [
    "sherpa-onnx-spleeter-2stems|fastest",
    "sherpa-onnx-spleeter-2stems-fp16|fastest",
    "sherpa-onnx-spleeter-2stems-int8|fastest",
    "UVR_MDXNET_1_9703.onnx|slow",
    "UVR_MDXNET_2_9682.onnx|slow",
    "UVR_MDXNET_3_9662.onnx|slow",
    "UVR_MDXNET_9482.onnx|slow",
    "UVR_MDXNET_KARA.onnx|slow",
    "UVR_MDXNET_KARA_2.onnx|slowest",
    "UVR_MDXNET_Main.onnx|slowest",
    "UVR-MDX-NET-Inst_1.onnx|slowest",
    "UVR-MDX-NET-Inst_2.onnx|slowest",
    "UVR-MDX-NET-Inst_3.onnx|slowest",
    "UVR-MDX-NET-Inst_HQ_1.onnx|slowest",
    "UVR-MDX-NET-Inst_HQ_2.onnx|slowest",
    "UVR-MDX-NET-Inst_HQ_3.onnx|slowest",
    "UVR-MDX-NET-Inst_HQ_4.onnx|slowest",
    "UVR-MDX-NET-Inst_HQ_5.onnx|slowest",
    "UVR-MDX-NET-Inst_Main.onnx|slowest",
    "UVR-MDX-NET-Voc_FT.onnx|slowest",
    "UVR-MDX-NET_Crowd_HQ_1.onnx|slowest",
]