| """ |
| Tests for Together AI model client. |
| """ |
|
|
| import os |
| import sys |
| from pathlib import Path |
| import pytest |
| from unittest.mock import MagicMock, patch |
|
|
| |
# Put the "stack-2.9-eval" directory (three levels above this file) at the
# front of sys.path before the `model_client` import below.
# NOTE(review): presumably that directory is where model_client.py lives -- confirm.
sys.path.insert(
    0,
    str(Path(__file__).parent.parent.parent / "stack-2.9-eval"),
)
|
|
| from model_client import ( |
| TogetherClient, |
| BaseModelClient, |
| ChatMessage, |
| GenerationResult, |
| create_model_client |
| ) |
|
|
|
|
def test_together_client_init_with_api_key():
    """Explicit ``api_key`` and ``model`` arguments are stored on the client.

    Also checks that ``base_url`` is Together's v1 endpoint.
    """
    together = TogetherClient(
        api_key="test-key",
        model="togethercomputer/Qwen2.5-Coder-32B-Instruct",
    )

    assert together.api_key == "test-key"
    assert together.model == "togethercomputer/Qwen2.5-Coder-32B-Instruct"
    assert together.base_url == "https://api.together.xyz/v1"
|
|
|
|
def test_together_client_init_with_env_var():
    """Test that the API key is read from ``TOGETHER_API_KEY`` when not passed.

    The variable is set through ``patch.dict`` so that the environment is
    restored to its previous state on exit. Setting and then deleting the
    variable by hand would discard any value that existed before the test
    and leak that change into every test that runs afterwards.
    """
    with patch.dict(os.environ, {"TOGETHER_API_KEY": "env-key"}):
        client = TogetherClient()
        assert client.api_key == "env-key"
|
|
|
|
def test_together_client_init_without_api_key():
    """Test that construction raises ``ValueError`` when no API key is available.

    ``TOGETHER_API_KEY`` is removed only inside a ``patch.dict`` scope, which
    snapshots ``os.environ`` on entry and restores it on exit, so a key that
    was set before the test is still present for later tests.
    """
    with patch.dict(os.environ):
        # pop() with a default is a no-op when the variable is already absent.
        os.environ.pop("TOGETHER_API_KEY", None)
        with pytest.raises(ValueError, match="Together API key required"):
            TogetherClient()
|
|
|
|
def test_together_client_get_model_name():
    """``get_model_name()`` returns the model passed to the constructor."""
    model_name = TogetherClient(api_key="test-key", model="test-model").get_model_name()
    assert model_name == "test-model"
|
|
|
|
def test_together_client_generate_success():
    """``generate()`` maps a completions response onto the result fields.

    ``model_client.OpenAI`` is patched, so no network request is made; the
    fake SDK client returns a canned completion object.
    """
    # Canned completion: one choice plus usage and a dumpable payload.
    choice = MagicMock()
    choice.text = "Hello, world!"
    choice.finish_reason = "stop"

    completion = MagicMock()
    completion.choices = [choice]
    completion.usage = MagicMock()
    completion.usage.completion_tokens = 5
    completion.model_dump.return_value = {"mock": "response"}

    with patch('model_client.OpenAI') as openai_cls:
        sdk = MagicMock()
        openai_cls.return_value = sdk
        sdk.completions.create.return_value = completion

        client = TogetherClient(api_key="test-key")
        result = client.generate("Say hello")

        # The result mirrors the canned completion.
        assert result.text == "Hello, world!"
        assert result.model == client.model
        assert result.tokens == 5
        assert result.finish_reason == "stop"
        assert result.raw_response == {"mock": "response"}

        # Exactly one completions call, carrying the model and the prompt.
        sdk.completions.create.assert_called_once()
        sent = sdk.completions.create.call_args.kwargs
        assert sent["model"] == client.model
        assert sent["prompt"] == "Say hello"
|
|
|
|
def test_together_client_generate_failure():
    """An exception raised by the completions call surfaces from ``generate()``."""
    with patch('model_client.OpenAI') as openai_cls:
        sdk = MagicMock()
        openai_cls.return_value = sdk
        # Every completions call on the fake SDK client raises.
        sdk.completions.create.side_effect = Exception("API error")

        failing_client = TogetherClient(api_key="test-key")
        with pytest.raises(Exception, match="API error"):
            failing_client.generate("test")
|
|
|
|
def test_together_client_chat_success():
    """``chat()`` returns the message content and sends messages as dicts.

    ``model_client.OpenAI`` is patched; the fake SDK client returns a canned
    chat completion with a single choice.
    """
    # Canned chat completion.
    choice = MagicMock()
    choice.message = MagicMock()
    choice.message.content = "Chat response"
    choice.finish_reason = "stop"

    completion = MagicMock()
    completion.choices = [choice]
    completion.usage = MagicMock()
    completion.usage.completion_tokens = 10
    completion.model_dump.return_value = {"mock": "chat"}

    with patch('model_client.OpenAI') as openai_cls:
        sdk = MagicMock()
        openai_cls.return_value = sdk
        sdk.chat.completions.create.return_value = completion

        client = TogetherClient(api_key="test-key")
        conversation = [
            ChatMessage(role="user", content="Hello"),
            ChatMessage(role="assistant", content="Hi there!"),
        ]
        result = client.chat(conversation)

        # The result mirrors the canned completion.
        assert result.text == "Chat response"
        assert result.model == client.model
        assert result.tokens == 10
        assert result.finish_reason == "stop"

        # One chat call; each ChatMessage arrives as a role/content dict.
        sdk.chat.completions.create.assert_called_once()
        sent = sdk.chat.completions.create.call_args.kwargs
        assert sent["model"] == client.model
        assert len(sent["messages"]) == 2
        assert sent["messages"][0] == {"role": "user", "content": "Hello"}
        assert sent["messages"][1] == {"role": "assistant", "content": "Hi there!"}
|
|
|
|
def test_together_client_chat_with_tools():
    """A ``tools`` argument to ``chat()`` is forwarded unchanged to the SDK call."""
    # Canned response with empty content and a "tool_calls" finish reason.
    choice = MagicMock()
    choice.message = MagicMock()
    choice.message.content = ""
    choice.finish_reason = "tool_calls"

    completion = MagicMock()
    completion.choices = [choice]
    completion.usage = MagicMock()
    completion.usage.completion_tokens = 0
    completion.model_dump.return_value = {}

    weather_tool = {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Get weather for a location",
            "parameters": {}
        }
    }

    with patch('model_client.OpenAI') as openai_cls:
        sdk = MagicMock()
        openai_cls.return_value = sdk
        sdk.chat.completions.create.return_value = completion

        client = TogetherClient(api_key="test-key")
        question = [ChatMessage(role="user", content="What's the weather?")]
        tools = [weather_tool]
        result = client.chat(question, tools=tools)

        assert result.text == ""
        sent = sdk.chat.completions.create.call_args.kwargs
        assert "tools" in sent
        assert sent["tools"] == tools
|
|
|
|
def test_together_client_base_url():
    """The client stores Together's base URL and passes it to ``OpenAI``.

    The client is built before ``model_client.OpenAI`` is patched; after one
    ``generate()`` call the patched class must have been called exactly once
    with the API key, the Together base URL and a timeout of 120.
    """
    client = TogetherClient(api_key="test-key")
    assert client.base_url == "https://api.together.xyz/v1"

    # Minimal completion object so that generate() can run.
    canned = MagicMock(
        choices=[MagicMock(text="ok", finish_reason="stop")],
        usage=MagicMock(completion_tokens=1),
        model_dump=lambda: {}
    )

    with patch('model_client.OpenAI') as openai_cls:
        sdk = MagicMock()
        sdk.completions.create.return_value = canned
        openai_cls.return_value = sdk

        client.generate("test")

        openai_cls.assert_called_once_with(
            api_key="test-key",
            base_url="https://api.together.xyz/v1",
            timeout=120
        )
|
|
|
|
def test_together_client_default_model():
    """Test the built-in default model and the ``TOGETHER_MODEL`` override.

    Both halves run inside ``patch.dict`` scopes, which restore
    ``os.environ`` on exit:

    * the default is checked with ``TOGETHER_MODEL`` removed, so the
      assertion does not depend on the environment the suite runs in;
    * the override is checked with ``TOGETHER_MODEL`` set, without
      discarding a value that existed before the test.
    """
    # No model argument and no env override: the built-in default applies.
    with patch.dict(os.environ):
        os.environ.pop("TOGETHER_MODEL", None)
        client = TogetherClient(api_key="test-key")
        assert client.model == "togethercomputer/Qwen2.5-Coder-32B-Instruct"

    # No model argument, env override present: the env value applies.
    with patch.dict(os.environ, {"TOGETHER_MODEL": "custom/model"}):
        client = TogetherClient(api_key="test-key")
        assert client.model == "custom/model"
|
|
|
|
def test_create_model_client_together():
    """Test that ``create_model_client("together", ...)`` builds a TogetherClient.

    The default-model check runs with ``TOGETHER_MODEL`` removed inside a
    ``patch.dict`` scope, because the factory picks the model up from that
    variable when it is set; without the isolation the assertion would
    depend on the environment of whoever runs the suite. The environment is
    restored on exit.
    """
    with patch.dict(os.environ):
        os.environ.pop("TOGETHER_MODEL", None)
        client = create_model_client("together", api_key="test-key")
        assert isinstance(client, TogetherClient)
        assert client.api_key == "test-key"
        assert client.model == "togethercomputer/Qwen2.5-Coder-32B-Instruct"

    # An explicit model argument is used as given.
    client = create_model_client("together", model="custom/model", api_key="key")
    assert client.model == "custom/model"
|
|
|
|
def test_create_model_client_together_from_env():
    """Test that the factory reads the API key and model from the environment.

    Both variables are injected with ``patch.dict`` so that ``os.environ``
    is restored on exit. Setting and deleting them by hand would discard
    any values that existed before the test.
    """
    env_overrides = {
        "TOGETHER_API_KEY": "env-key",
        "TOGETHER_MODEL": "env/model",
    }
    with patch.dict(os.environ, env_overrides):
        client = create_model_client("together")
        assert client.api_key == "env-key"
        assert client.model == "env/model"
|
|