StemSplitter / tests /test_separator.py
ymcnabb's picture
Upload folder using huggingface_hub
1824ea0 verified
"""Tests for the core StemSplitter class."""
import pytest
from stemsplitter.separator import (
STEM_LABELS,
OutputFormat,
SeparationResult,
StemMode,
StemSplitter,
)
class TestStemMode:
    """Checks the string values attached to the ``StemMode`` members."""

    def test_two_stem_value(self):
        """``StemMode.TWO_STEM`` carries the value ``"2stem"``."""
        two_stem = StemMode.TWO_STEM
        assert two_stem.value == "2stem"

    def test_four_stem_value(self):
        """``StemMode.FOUR_STEM`` carries the value ``"4stem"``."""
        four_stem = StemMode.FOUR_STEM
        assert four_stem.value == "4stem"

    def test_from_string(self):
        """Calling ``StemMode`` with a value string yields the matching member."""
        parsed = StemMode("2stem")
        assert parsed == StemMode.TWO_STEM

        parsed = StemMode("4stem")
        assert parsed == StemMode.FOUR_STEM
class TestOutputFormat:
    """Checks the string values attached to the ``OutputFormat`` members."""

    def test_format_values(self):
        """WAV, MP3 and FLAC each expose their upper-case name as ``.value``."""
        wav_value = OutputFormat.WAV.value
        assert wav_value == "WAV"

        mp3_value = OutputFormat.MP3.value
        assert mp3_value == "MP3"

        flac_value = OutputFormat.FLAC.value
        assert flac_value == "FLAC"
class TestStemLabels:
    """Checks the label lists stored in ``STEM_LABELS`` for each stem mode."""

    def test_two_stem_labels(self):
        """The two-stem entry lists vocals first, then the instrumental."""
        two_stem_labels = STEM_LABELS[StemMode.TWO_STEM]
        assert two_stem_labels == ["Vocals", "Instrumental"]

    def test_four_stem_labels(self):
        """The four-stem entry lists vocals, drums, bass and other, in order."""
        four_stem_labels = STEM_LABELS[StemMode.FOUR_STEM]
        expected_labels = ["Vocals", "Drums", "Bass", "Other"]
        assert four_stem_labels == expected_labels
class TestStemSplitter:
    """Exercises ``StemSplitter.separate`` against a mocked separator backend.

    The ``mock_separator``, ``mock_separator_4stem``, ``test_audio_path`` and
    ``env_settings`` arguments are pytest fixtures defined outside this file
    (presumably in a ``conftest.py`` -- confirm).
    """

    def test_separate_2stem(self, mock_separator, test_audio_path, env_settings):
        """A two-stem run yields a ``SeparationResult`` holding two files."""
        stem_splitter = StemSplitter()
        outcome = stem_splitter.separate(
            input_path=test_audio_path,
            mode=StemMode.TWO_STEM,
        )

        assert isinstance(outcome, SeparationResult)
        produced_files = outcome.output_files
        assert len(produced_files) == 2
        assert outcome.mode == StemMode.TWO_STEM
        # A single separation must load the model exactly once.
        mock_separator.load_model.assert_called_once()

    def test_separate_4stem(
        self, mock_separator_4stem, test_audio_path, env_settings
    ):
        """A four-stem run yields a result holding four files."""
        stem_splitter = StemSplitter()
        outcome = stem_splitter.separate(
            input_path=test_audio_path,
            mode=StemMode.FOUR_STEM,
        )

        produced_files = outcome.output_files
        assert len(produced_files) == 4
        assert outcome.mode == StemMode.FOUR_STEM

    def test_format_override(self, mock_separator, test_audio_path, env_settings):
        """An explicit ``output_format`` is echoed back on the result."""
        requested_format = OutputFormat.FLAC
        stem_splitter = StemSplitter()
        outcome = stem_splitter.separate(
            input_path=test_audio_path,
            mode=StemMode.TWO_STEM,
            output_format=requested_format,
        )

        assert outcome.output_format == OutputFormat.FLAC

    def test_model_caching(self, mock_separator, test_audio_path, env_settings):
        """Two runs in the same mode load the model only once."""
        stem_splitter = StemSplitter()
        for _ in range(2):
            stem_splitter.separate(test_audio_path, mode=StemMode.TWO_STEM)

        load_count = mock_separator.load_model.call_count
        assert load_count == 1

    def test_model_switch(self, mock_separator, test_audio_path, env_settings):
        """Changing the mode between runs loads the model a second time."""
        stem_splitter = StemSplitter()
        for stem_mode in (StemMode.TWO_STEM, StemMode.FOUR_STEM):
            stem_splitter.separate(test_audio_path, mode=stem_mode)

        load_count = mock_separator.load_model.call_count
        assert load_count == 2

    def test_file_not_found(self, env_settings):
        """A path that does not exist makes ``separate`` raise ``FileNotFoundError``."""
        missing_path = "/nonexistent/file.wav"
        stem_splitter = StemSplitter()

        with pytest.raises(FileNotFoundError):
            stem_splitter.separate(missing_path)

    def test_model_override(self, mock_separator, test_audio_path, env_settings):
        """A ``model_override`` filename is forwarded to ``load_model``."""
        custom_model = "UVR_MDXNET_KARA_2.onnx"
        stem_splitter = StemSplitter()
        stem_splitter.separate(
            test_audio_path,
            mode=StemMode.TWO_STEM,
            model_override=custom_model,
        )

        # The most recent load_model call must carry the override filename.
        mock_separator.load_model.assert_called_with(model_filename=custom_model)

    def test_result_contains_input_file(
        self, mock_separator, test_audio_path, env_settings
    ):
        """The result's ``input_file`` equals the stringified input path."""
        stem_splitter = StemSplitter()
        outcome = stem_splitter.separate(test_audio_path, mode=StemMode.TWO_STEM)

        expected_input = str(test_audio_path)
        assert outcome.input_file == expected_input

    def test_result_contains_model_used(
        self, mock_separator, test_audio_path, env_settings
    ):
        """The result's ``model_used`` contains ``"mel_band_roformer"``."""
        stem_splitter = StemSplitter()
        outcome = stem_splitter.separate(test_audio_path, mode=StemMode.TWO_STEM)

        model_name = outcome.model_used
        assert "mel_band_roformer" in model_name

    def test_separation_runtime_error(
        self, mock_separator, test_audio_path, env_settings
    ):
        """A backend exception surfaces as ``RuntimeError("Separation failed...")``."""
        # Make the mocked backend blow up as soon as separation is attempted.
        backend_failure = Exception("Model crashed")
        mock_separator.separate.side_effect = backend_failure

        stem_splitter = StemSplitter()
        with pytest.raises(RuntimeError, match="Separation failed"):
            stem_splitter.separate(test_audio_path, mode=StemMode.TWO_STEM)