Spaces:
No application file
No application file
| """Tests for the core StemSplitter class.""" | |
| import pytest | |
| from stemsplitter.separator import ( | |
| STEM_LABELS, | |
| OutputFormat, | |
| SeparationResult, | |
| StemMode, | |
| StemSplitter, | |
| ) | |
class TestStemMode:
    """StemMode members expose their string values and round-trip from them."""

    def test_two_stem_value(self):
        expected = "2stem"
        assert StemMode.TWO_STEM.value == expected

    def test_four_stem_value(self):
        expected = "4stem"
        assert StemMode.FOUR_STEM.value == expected

    def test_from_string(self):
        # Constructing the enum from its raw value must yield the matching member.
        pairs = (
            ("2stem", StemMode.TWO_STEM),
            ("4stem", StemMode.FOUR_STEM),
        )
        for raw, member in pairs:
            assert StemMode(raw) == member
class TestOutputFormat:
    """OutputFormat members carry their format names as values."""

    def test_format_values(self):
        # Checked in the same order as the members are listed.
        cases = [
            (OutputFormat.WAV, "WAV"),
            (OutputFormat.MP3, "MP3"),
            (OutputFormat.FLAC, "FLAC"),
        ]
        for member, expected in cases:
            assert member.value == expected
class TestStemLabels:
    """STEM_LABELS maps each StemMode to its ordered list of stem names."""

    def test_two_stem_labels(self):
        labels = STEM_LABELS[StemMode.TWO_STEM]
        assert labels == ["Vocals", "Instrumental"]

    def test_four_stem_labels(self):
        labels = STEM_LABELS[StemMode.FOUR_STEM]
        assert labels == ["Vocals", "Drums", "Bass", "Other"]
class TestStemSplitter:
    """Behavioural tests for ``StemSplitter.separate``.

    The ``mock_separator``, ``mock_separator_4stem``, ``test_audio_path`` and
    ``env_settings`` fixtures are defined outside this file (presumably in
    conftest.py). They are assumed to replace the underlying separation
    backend, provide an existing input file, and configure settings.
    """

    def test_separate_2stem(self, mock_separator, test_audio_path, env_settings):
        """2-stem separation should return 2 output files."""
        splitter = StemSplitter()
        result = splitter.separate(
            input_path=test_audio_path,
            mode=StemMode.TWO_STEM,
        )
        assert isinstance(result, SeparationResult)
        assert len(result.output_files) == 2
        assert result.mode == StemMode.TWO_STEM
        mock_separator.load_model.assert_called_once()

    def test_separate_4stem(
        self, mock_separator_4stem, test_audio_path, env_settings
    ):
        """4-stem separation should return 4 output files."""
        splitter = StemSplitter()
        result = splitter.separate(
            input_path=test_audio_path,
            mode=StemMode.FOUR_STEM,
        )
        assert len(result.output_files) == 4
        assert result.mode == StemMode.FOUR_STEM

    def test_format_override(self, mock_separator, test_audio_path, env_settings):
        """Output format override should be reflected in result."""
        splitter = StemSplitter()
        result = splitter.separate(
            input_path=test_audio_path,
            mode=StemMode.TWO_STEM,
            output_format=OutputFormat.FLAC,
        )
        assert result.output_format == OutputFormat.FLAC

    def test_model_caching(self, mock_separator, test_audio_path, env_settings):
        """Same mode twice should NOT reload the model."""
        splitter = StemSplitter()
        splitter.separate(test_audio_path, mode=StemMode.TWO_STEM)
        splitter.separate(test_audio_path, mode=StemMode.TWO_STEM)
        assert mock_separator.load_model.call_count == 1

    def test_model_switch(self, mock_separator, test_audio_path, env_settings):
        """Switching modes should reload the model."""
        splitter = StemSplitter()
        splitter.separate(test_audio_path, mode=StemMode.TWO_STEM)
        splitter.separate(test_audio_path, mode=StemMode.FOUR_STEM)
        assert mock_separator.load_model.call_count == 2

    def test_file_not_found(self, env_settings, tmp_path):
        """Should raise FileNotFoundError for missing input."""
        # A file inside pytest's fresh per-test directory is guaranteed not
        # to exist, unlike a hard-coded absolute path. Passed as str to match
        # the original call.
        missing = tmp_path / "missing.wav"
        splitter = StemSplitter()
        with pytest.raises(FileNotFoundError):
            splitter.separate(str(missing))

    def test_model_override(self, mock_separator, test_audio_path, env_settings):
        """Custom model_override should be passed through."""
        splitter = StemSplitter()
        splitter.separate(
            test_audio_path,
            mode=StemMode.TWO_STEM,
            model_override="UVR_MDXNET_KARA_2.onnx",
        )
        # assert_called_once_with (not assert_called_with, which inspects only
        # the last call) also rules out an extra load of the default model
        # before the override.
        mock_separator.load_model.assert_called_once_with(
            model_filename="UVR_MDXNET_KARA_2.onnx"
        )

    def test_result_contains_input_file(
        self, mock_separator, test_audio_path, env_settings
    ):
        """Result should reference the original input file."""
        splitter = StemSplitter()
        result = splitter.separate(test_audio_path, mode=StemMode.TWO_STEM)
        assert result.input_file == str(test_audio_path)

    def test_result_contains_model_used(
        self, mock_separator, test_audio_path, env_settings
    ):
        """Result should reference which model was used."""
        splitter = StemSplitter()
        result = splitter.separate(test_audio_path, mode=StemMode.TWO_STEM)
        assert "mel_band_roformer" in result.model_used

    def test_separation_runtime_error(
        self, mock_separator, test_audio_path, env_settings
    ):
        """RuntimeError should be raised if the underlying separator fails."""
        mock_separator.separate.side_effect = Exception("Model crashed")
        splitter = StemSplitter()
        with pytest.raises(RuntimeError, match="Separation failed"):
            splitter.separate(test_audio_path, mode=StemMode.TWO_STEM)