#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)

import logging
import os
import subprocess
import uuid
from functools import lru_cache

import numpy as np
import sherpa_onnx
import soundfile as sf
from huggingface_hub import hf_hub_download


def convert_to_wav(in_filename: str) -> str:
    """Convert the input audio file to a wave file.

    The input is decoded with ``ffmpeg`` and re-encoded as a 44100 Hz,
    2-channel wave file written next to the input file.

    Args:
      in_filename:
        Path of the audio file to convert. It is passed to ffmpeg as a
        separate argv entry and is never interpolated into a shell
        command, so arbitrary characters in the name are safe.
    Returns:
      Path of the newly written wave file. The caller owns this file and
      is responsible for deleting it.
    Raises:
      RuntimeError: If ffmpeg exits with a non-zero status.
    """
    # The random component keeps two concurrent conversions of the same
    # input from writing (and later deleting) the same output file.
    out_filename = f"{in_filename}.{uuid.uuid4()}.wav"

    logging.info(f"Converting '{in_filename}' to '{out_filename}'")
    cmd = [
        "ffmpeg",
        "-y",  # overwrite the output file without prompting
        "-hide_banner",
        "-loglevel",
        "error",
        "-i",
        in_filename,
        "-ar",
        "44100",
        "-ac",
        "2",
        out_filename,
    ]
    result = subprocess.run(cmd, capture_output=True, text=True, errors="replace")
    if result.returncode != 0:
        # Fail here with ffmpeg's own message instead of letting a later
        # read fail on a missing or stale output file.
        raise RuntimeError(
            f"ffmpeg failed to convert '{in_filename}' "
            f"(exit code {result.returncode}): {result.stderr.strip()}"
        )
    return out_filename


def load_audio(filename):
    """Decode an audio file into a float32 sample matrix.

    The file is first converted with convert_to_wav(). The intermediate
    wave file is removed after it has been read.

    Args:
      filename:
        Path to an audio file that ffmpeg can decode.
    Returns:
      A tuple ``(samples, sample_rate)``. ``samples`` is a float32 array of
      shape (num_channels, num_samples). Because convert_to_wav() asks
      ffmpeg for 44100 Hz stereo output, ``sample_rate`` is expected to be
      44100 and ``num_channels`` to be 2.
    """
    wav_filename = convert_to_wav(filename)
    try:
        samples, sample_rate = sf.read(
            wav_filename, dtype="float32", always_2d=True
        )
    finally:
        # The wave file is only an intermediate artifact; without this it
        # would be left on disk after every call.
        try:
            os.remove(wav_filename)
        except OSError as e:
            logging.warning(f"Failed to remove '{wav_filename}': {e}")

    samples = np.transpose(samples)
    # now samples is of shape (num_channels, num_samples)
    assert (
        samples.shape[1] > samples.shape[0]
    ), f"You should use (num_channels, num_samples). {samples.shape}"

    assert (
        samples.dtype == np.float32
    ), f"Expect np.float32 as dtype. Given: {samples.dtype}"

    return samples, sample_rate


@lru_cache(maxsize=10)
def get_file(
    repo_id: str,
    filename: str,
    subfolder: str = ".",
) -> str:
    """Download ``filename`` from a Hugging Face Hub repo and return its local path.

    Results are memoized in-process by ``lru_cache``, so repeated requests
    for the same (repo_id, filename, subfolder) skip the hub call.
    """
    nn_model_filename = hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        subfolder=subfolder,
    )
    return nn_model_filename


@lru_cache(maxsize=30)
def load_model(name: str):
    """Create a source-separation engine for an entry of ``model_list``.

    Args:
      name:
        A model name, optionally followed by ``|<label>`` (as in
        ``model_list``). Everything after the first ``|`` is ignored.
        Names containing "spleeter" are loaded as Spleeter models, and
        names containing "UVR" as UVR models.
    Raises:
      ValueError: If the name matches neither model family, or the
        resulting config fails validation.
    """
    name = name.split("|")[0]
    if "spleeter" in name:
        return load_spleeter_model(name)
    elif "UVR" in name:
        return load_uvr_model(name)

    raise ValueError(f"Unsupported model name {name}")


def _create_separator(**model_kwargs):
    """Build, validate and instantiate a CPU OfflineSourceSeparation.

    ``model_kwargs`` selects the model family, e.g. ``uvr=...`` or
    ``spleeter=...``. It is forwarded to OfflineSourceSeparationModelConfig
    together with the shared runtime settings (2 threads, CPU provider).
    """
    config = sherpa_onnx.OfflineSourceSeparationConfig(
        model=sherpa_onnx.OfflineSourceSeparationModelConfig(
            num_threads=2,
            debug=False,
            provider="cpu",
            **model_kwargs,
        )
    )
    if not config.validate():
        raise ValueError("Please check your config.")

    return sherpa_onnx.OfflineSourceSeparation(config)


def load_uvr_model(name: str):
    """Load a UVR model file ``name`` from k2-fsa/sherpa-onnx-models.

    The file is fetched from the ``source-separation-models`` subfolder.
    """
    model = get_file(
        repo_id="k2-fsa/sherpa-onnx-models",
        subfolder="source-separation-models",
        filename=name,
    )
    return _create_separator(
        uvr=sherpa_onnx.OfflineSourceSeparationUvrModelConfig(
            model=model,
        ),
    )


def load_spleeter_model(name: str):
    """Load the Spleeter vocals/accompaniment models from ``csukuangfj/{name}``.

    The ONNX variant is selected from the name: "fp16" picks
    ``*.fp16.onnx``, "int8" picks ``*.int8.onnx``, and anything else
    picks the plain ``*.onnx`` files.
    """
    if "fp16" in name:
        suffix = "fp16.onnx"
    elif "int8" in name:
        suffix = "int8.onnx"
    else:
        suffix = "onnx"

    vocals = get_file(repo_id=f"csukuangfj/{name}", filename=f"vocals.{suffix}")
    accompaniment = get_file(
        repo_id=f"csukuangfj/{name}", filename=f"accompaniment.{suffix}"
    )
    return _create_separator(
        spleeter=sherpa_onnx.OfflineSourceSeparationSpleeterModelConfig(
            vocals=vocals,
            accompaniment=accompaniment,
        ),
    )


# Each entry is "<model name>|<speed label>". load_model() strips the
# label before resolving the model.
model_list = [
    "sherpa-onnx-spleeter-2stems|fastest",
    "sherpa-onnx-spleeter-2stems-fp16|fastest",
    "sherpa-onnx-spleeter-2stems-int8|fastest",
    "UVR_MDXNET_1_9703.onnx|slow",
    "UVR_MDXNET_2_9682.onnx|slow",
    "UVR_MDXNET_3_9662.onnx|slow",
    "UVR_MDXNET_9482.onnx|slow",
    "UVR_MDXNET_KARA.onnx|slow",
    "UVR_MDXNET_KARA_2.onnx|slowest",
    "UVR_MDXNET_Main.onnx|slowest",
    "UVR-MDX-NET-Inst_1.onnx|slowest",
    "UVR-MDX-NET-Inst_2.onnx|slowest",
    "UVR-MDX-NET-Inst_3.onnx|slowest",
    "UVR-MDX-NET-Inst_HQ_1.onnx|slowest",
    "UVR-MDX-NET-Inst_HQ_2.onnx|slowest",
    "UVR-MDX-NET-Inst_HQ_3.onnx|slowest",
    "UVR-MDX-NET-Inst_HQ_4.onnx|slowest",
    "UVR-MDX-NET-Inst_HQ_5.onnx|slowest",
    "UVR-MDX-NET-Inst_Main.onnx|slowest",
    "UVR-MDX-NET-Voc_FT.onnx|slowest",
    "UVR-MDX-NET_Crowd_HQ_1.onnx|slowest",
]