# Stack-2-9-finetuned / samples / unit / test_together_client.py
# Author: walidsobhie-code
# Commit 65888d5: "refactor: Squeeze folders further - cleaner structure"
"""
Tests for Together AI model client.
"""
import os
import sys
from pathlib import Path
import pytest
from unittest.mock import MagicMock, patch
# Add stack-2.9-eval directory to path to import model_client directly
sys.path.insert(0, str(Path(__file__).parent.parent.parent / "stack-2.9-eval"))
from model_client import (
TogetherClient,
BaseModelClient,
ChatMessage,
GenerationResult,
create_model_client
)
def test_together_client_init_with_api_key():
    """An explicit API key and model are stored verbatim.

    Also checks that the client targets Together's OpenAI-compatible endpoint.
    """
    expected_model = "togethercomputer/Qwen2.5-Coder-32B-Instruct"
    client = TogetherClient(api_key="test-key", model=expected_model)

    assert client.api_key == "test-key"
    assert client.model == expected_model
    assert client.base_url == "https://api.together.xyz/v1"
def test_together_client_init_with_env_var():
    """Test initialization with environment variable.

    ``patch.dict`` restores ``os.environ`` when the block exits. A real
    TOGETHER_API_KEY already present in the developer's or CI environment is
    therefore put back afterwards instead of being deleted by the test.
    """
    with patch.dict(os.environ, {"TOGETHER_API_KEY": "env-key"}):
        client = TogetherClient()
        assert client.api_key == "env-key"
def test_together_client_init_without_api_key():
    """Test that initialization fails without API key.

    The variable is removed only inside a ``patch.dict`` context. The original
    environment, including any real key, is restored afterwards, so later
    tests and the rest of the session are unaffected.
    """
    with patch.dict(os.environ):
        # Ensure the env var is not set, for the duration of this test only.
        os.environ.pop("TOGETHER_API_KEY", None)
        with pytest.raises(ValueError, match="Together API key required"):
            TogetherClient()
def test_together_client_get_model_name():
    """get_model_name() reports the model the client was constructed with."""
    model_name = "test-model"
    client = TogetherClient(api_key="test-key", model=model_name)
    assert client.get_model_name() == model_name
def test_together_client_generate_success():
    """generate() maps a completions response onto a GenerationResult.

    Also verifies that the model name and prompt are forwarded to the
    (mocked) OpenAI SDK ``completions.create`` call.
    """
    # Fake SDK response: one choice, a usage block, and a model_dump() payload.
    choice = MagicMock(text="Hello, world!", finish_reason="stop")
    fake_response = MagicMock(
        choices=[choice],
        usage=MagicMock(completion_tokens=5),
        **{"model_dump.return_value": {"mock": "response"}},
    )
    fake_api = MagicMock()
    fake_api.completions.create.return_value = fake_response

    with patch('model_client.OpenAI', return_value=fake_api):
        client = TogetherClient(api_key="test-key")
        result = client.generate("Say hello")

        assert result.text == "Hello, world!"
        assert result.model == client.model
        assert result.tokens == 5
        assert result.finish_reason == "stop"
        assert result.raw_response == {"mock": "response"}

        create = fake_api.completions.create
        create.assert_called_once()
        sent = create.call_args.kwargs
        assert sent["model"] == client.model
        assert sent["prompt"] == "Say hello"
def test_together_client_generate_failure():
    """An exception from the completions endpoint propagates out of generate()."""
    failing_api = MagicMock()
    failing_api.completions.create.side_effect = Exception("API error")

    with patch('model_client.OpenAI', return_value=failing_api):
        client = TogetherClient(api_key="test-key")
        with pytest.raises(Exception, match="API error"):
            client.generate("test")
def test_together_client_chat_success():
    """chat() sends ChatMessage objects as role/content dicts.

    The first choice's message content, token count and finish reason must
    be reflected in the returned result.
    """
    reply = MagicMock(finish_reason="stop")
    reply.message.content = "Chat response"
    fake_response = MagicMock(
        choices=[reply],
        usage=MagicMock(completion_tokens=10),
        **{"model_dump.return_value": {"mock": "chat"}},
    )
    fake_api = MagicMock()
    fake_api.chat.completions.create.return_value = fake_response

    conversation = [
        ChatMessage(role="user", content="Hello"),
        ChatMessage(role="assistant", content="Hi there!"),
    ]
    # The wire format each ChatMessage is expected to be converted into.
    expected_payload = [
        {"role": "user", "content": "Hello"},
        {"role": "assistant", "content": "Hi there!"},
    ]

    with patch('model_client.OpenAI', return_value=fake_api):
        client = TogetherClient(api_key="test-key")
        result = client.chat(conversation)

        assert result.text == "Chat response"
        assert result.model == client.model
        assert result.tokens == 10
        assert result.finish_reason == "stop"

        create = fake_api.chat.completions.create
        create.assert_called_once()
        sent = create.call_args.kwargs
        assert sent["model"] == client.model
        assert len(sent["messages"]) == 2
        for index, expected in enumerate(expected_payload):
            assert sent["messages"][index] == expected
def test_together_client_chat_with_tools():
    """Tool definitions passed to chat() reach the chat completions call unchanged."""
    # Simulate a response where the model asked for a tool instead of replying.
    tool_call_choice = MagicMock(finish_reason="tool_calls")
    tool_call_choice.message.content = ""
    fake_response = MagicMock(
        choices=[tool_call_choice],
        usage=MagicMock(completion_tokens=0),
        **{"model_dump.return_value": {}},
    )
    fake_api = MagicMock()
    fake_api.chat.completions.create.return_value = fake_response

    weather_tool = {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Get weather for a location",
            "parameters": {},
        },
    }
    tools = [weather_tool]

    with patch('model_client.OpenAI', return_value=fake_api):
        client = TogetherClient(api_key="test-key")
        result = client.chat(
            [ChatMessage(role="user", content="What's the weather?")],
            tools=tools,
        )

        assert result.text == ""
        sent = fake_api.chat.completions.create.call_args.kwargs
        assert "tools" in sent
        assert sent["tools"] == tools
def test_together_client_base_url():
    """The OpenAI SDK client is built against Together's base URL.

    The TogetherClient is created *before* OpenAI is patched. The final
    assertion therefore only passes if the SDK client is constructed on (or
    after) the generate() call rather than at TogetherClient construction.
    """
    together_url = "https://api.together.xyz/v1"
    client = TogetherClient(api_key="test-key")
    assert client.base_url == together_url

    stub_response = MagicMock(
        choices=[MagicMock(text="ok", finish_reason="stop")],
        usage=MagicMock(completion_tokens=1),
        model_dump=lambda: {},
    )

    with patch('model_client.OpenAI') as openai_cls:
        openai_cls.return_value.completions.create.return_value = stub_response
        client.generate("test")
        openai_cls.assert_called_once_with(
            api_key="test-key",
            base_url=together_url,
            timeout=120,
        )
def test_together_client_default_model():
    """Test default model when none provided.

    Both halves run inside ``patch.dict``. A TOGETHER_MODEL exported in the
    developer's or CI environment can then neither break the "no env var"
    assertion nor be deleted by this test.
    """
    with patch.dict(os.environ):
        # Without env var: explicitly remove any inherited value first.
        os.environ.pop("TOGETHER_MODEL", None)
        client = TogetherClient(api_key="test-key")
        assert client.model == "togethercomputer/Qwen2.5-Coder-32B-Instruct"

        # With env var: the environment value replaces the built-in default.
        os.environ["TOGETHER_MODEL"] = "custom/model"
        client = TogetherClient(api_key="test-key")
        assert client.model == "custom/model"
def test_create_model_client_together():
    """Test factory function creates TogetherClient.

    TOGETHER_MODEL is cleared for the duration of the test (and restored
    afterwards by ``patch.dict``). This stops an inherited value from
    replacing the default model asserted below.
    """
    with patch.dict(os.environ):
        os.environ.pop("TOGETHER_MODEL", None)

        client = create_model_client("together", api_key="test-key")
        assert isinstance(client, TogetherClient)
        assert client.api_key == "test-key"
        assert client.model == "togethercomputer/Qwen2.5-Coder-32B-Instruct"

        # Test with custom model: an explicit model argument is used as-is.
        client = create_model_client("together", model="custom/model", api_key="key")
        assert client.model == "custom/model"
def test_create_model_client_together_from_env():
    """Test factory reads env vars.

    ``patch.dict`` sets both variables for this test only. It restores any
    pre-existing values afterwards instead of unconditionally deleting them.
    """
    env = {"TOGETHER_API_KEY": "env-key", "TOGETHER_MODEL": "env/model"}
    with patch.dict(os.environ, env):
        client = create_model_client("together")
        assert client.api_key == "env-key"
        assert client.model == "env/model"