Spaces:
Sleeping
Sleeping
File size: 3,672 Bytes
23cd97b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 |
import pytest
from unittest.mock import patch, Mock
from api.audio import TTSManager, APIError
class TestTTSManager:
    """Tests for ``TTSManager.read_text`` and ``TTSManager.read_last_message``.

    Every test patches ``requests.post`` and is parametrized over ``stream``,
    so the streaming and non-streaming modes of ``TTSManager`` are exercised
    against the same fake HTTP responses.
    """

    # Canned payloads served by the fake TTS endpoint.
    STREAM_CHUNKS = (b"audio-bytes-part1", b"audio-bytes-part2")
    FULL_AUDIO = b"audio-bytes"
    ERROR_BODY = {"error": "Internal Server Error"}

    def setup_method(self):
        """Build a fresh ``TTSManager`` from a mocked config before each test."""
        self.config = Mock()
        self.config.tts.key = "test-key"
        self.config.tts.url = "https://api.example.com"
        self.config.tts.name = "test-tts-model"
        self.config.tts.type = "OPENAI_API"
        self.tts_manager = TTSManager(self.config)

    @staticmethod
    def _install_response(mock_post, status_code, **attributes):
        """Make ``mock_post`` produce a fake response and return that response.

        The same response object is returned by a plain ``requests.post(...)``
        call AND bound by ``with requests.post(...) as r``. This means the
        configured status code and payload are seen whichever calling style
        the code under test uses in a given streaming mode. Configuring only
        one style would hand the other style an unconfigured ``Mock``.

        Args:
            mock_post: The patched ``requests.post`` mock.
            status_code: HTTP status code reported by the fake response.
            **attributes: Extra attributes set on the fake response
                (e.g. ``content``, ``iter_content``, ``json``).

        Returns:
            The fake response object.
        """
        response = Mock(status_code=status_code, **attributes)
        # Entering the response yields the response itself. __exit__ returns
        # None (falsy), so exceptions raised inside the ``with`` propagate.
        response.__enter__ = Mock(return_value=response)
        response.__exit__ = Mock(return_value=None)
        mock_post.return_value = response
        return response

    def _configure_success(self, mock_post, stream):
        """Install a 200 response and return what the reader should yield.

        In streaming mode the response exposes ``iter_content`` returning
        ``STREAM_CHUNKS``. Otherwise it exposes the whole body as ``content``.
        """
        if stream:
            chunks = list(self.STREAM_CHUNKS)
            self._install_response(
                mock_post, 200, iter_content=Mock(return_value=list(chunks))
            )
            return chunks
        self._install_response(mock_post, 200, content=self.FULL_AUDIO)
        return [self.FULL_AUDIO]

    def _configure_error(self, mock_post):
        """Install a 500 response whose ``json()`` returns ``ERROR_BODY``."""
        self._install_response(
            mock_post, 500, json=Mock(return_value=dict(self.ERROR_BODY))
        )

    @patch("requests.post")
    @pytest.mark.parametrize("stream", [False, True])
    def test_read_text(self, mock_post, stream):
        """``read_text`` yields the audio bytes served by the endpoint."""
        self.tts_manager.streaming = stream
        expected = self._configure_success(mock_post, stream)
        assert list(self.tts_manager.read_text("Hello, world!")) == expected

    @patch("requests.post")
    @pytest.mark.parametrize("stream", [False, True])
    def test_read_text_error(self, mock_post, stream):
        """``read_text`` raises ``APIError`` when the endpoint returns 500."""
        self.tts_manager.streaming = stream
        self._configure_error(mock_post)
        with pytest.raises(APIError):
            list(self.tts_manager.read_text("Hello, world!"))

    @patch("requests.post")
    @pytest.mark.parametrize("stream", [False, True])
    def test_read_last_message(self, mock_post, stream):
        """``read_last_message`` yields the audio bytes served by the endpoint."""
        self.tts_manager.streaming = stream
        chat_history = [["user", "Hello, world!"]]
        expected = self._configure_success(mock_post, stream)
        assert list(self.tts_manager.read_last_message(chat_history)) == expected

    @patch("requests.post")
    @pytest.mark.parametrize("stream", [False, True])
    def test_read_last_message_error(self, mock_post, stream):
        """``read_last_message`` raises ``APIError`` when the endpoint returns 500."""
        self.tts_manager.streaming = stream
        chat_history = [["user", "Hello, world!"]]
        self._configure_error(mock_post)
        with pytest.raises(APIError):
            list(self.tts_manager.read_last_message(chat_history))
|