from transformers import PreTrainedModel, PretrainedConfig, AutoConfig, AutoModel
from spleeter.audio.adapter import AudioAdapter
from spleeter.separator import Separator


class SpleeterConfig(PretrainedConfig):
    """Configuration for :class:`SpleeterModel`.

    Args:
        stems: Number of stems to separate into. Selects spleeter's embedded
            configuration ``spleeter:<stems>stems``.
        sample_rate: Rate (Hz) the input audio is resampled to when it is
            decoded. It should match the ``sample_rate`` of the selected spleeter
            configuration (spleeter's bundled configurations use 44100).
        **kwargs: Forwarded to :class:`transformers.PretrainedConfig`.
    """

    # Must equal the name passed to ``AutoConfig.register`` at the bottom of
    # this module; a mismatch makes that call raise ValueError on import.
    model_type = "spleeter"

    def __init__(self, stems=2, sample_rate=44100, **kwargs):
        super().__init__(**kwargs)
        self.stems = stems
        self.sample_rate = sample_rate


class SpleeterModel(PreTrainedModel):
    """``PreTrainedModel`` wrapper that delegates separation to spleeter.

    The model defines no torch layers of its own. All separation work is done
    by a spleeter :class:`Separator` built from the config's stem count.
    """

    config_class = SpleeterConfig

    def __init__(self, config):
        super().__init__(config)
        # A bare "2stems" is interpreted by spleeter as a path to a JSON file;
        # the "spleeter:" prefix selects the configuration embedded in the package.
        self.separator = Separator(f"spleeter:{config.stems}stems")
        # Built once and reused to decode the input file on every forward call.
        self.audio_adapter = AudioAdapter.default()

    def forward(self, audio_path: str):
        """
        Separates the stems in the given audio file.

        Args:
            audio_path (str): Path to the input audio file.

        Returns:
            dict: Separated stems, as returned by ``Separator.separate``
            (stem name -> separated waveform).
        """
        # Separator.separate operates on a decoded waveform, not a file path,
        # so load (and resample) the file first; the returned rate is unused.
        waveform, _ = self.audio_adapter.load(
            audio_path, sample_rate=self.config.sample_rate
        )
        return self.separator.separate(waveform)


AutoConfig.register("spleeter", SpleeterConfig)
AutoModel.register(SpleeterConfig, SpleeterModel)