"""
OpenAI API compatibility tests
"""
import pytest
import requests
import json
from typing import Dict, Any


class TestOpenAICompatibility:
    """Test OpenAI-compatible endpoints"""

    @pytest.fixture
    def base_url(self):
        return "http://localhost:8000"

    def test_chat_completions(self, base_url):
        """Test OpenAI chat completions endpoint"""
        payload = {
            "model": "llama3.1-8b",
            "messages": [
                {"role": "user", "content": "What is EBITDA?"}
            ],
            "max_tokens": 100,
            "temperature": 0.7
        }
        response = requests.post(f"{base_url}/v1/chat/completions", json=payload)
        assert response.status_code == 200
        data = response.json()

        # Check OpenAI response format
        assert "choices" in data
        assert "usage" in data
        assert "model" in data
        assert len(data["choices"]) > 0
        assert "message" in data["choices"][0]
        assert "content" in data["choices"][0]["message"]

    def test_chat_completions_with_system_message(self, base_url):
        """Test chat completions with system message"""
        payload = {
            "model": "llama3.1-8b",
            "messages": [
                {"role": "system", "content": "You are a financial expert."},
                {"role": "user", "content": "Explain the difference between revenue and profit."}
            ],
            "max_tokens": 150,
            "temperature": 0.6
        }
        response = requests.post(f"{base_url}/v1/chat/completions", json=payload)
        assert response.status_code == 200
        data = response.json()
        assert "choices" in data
        assert len(data["choices"][0]["message"]["content"]) > 0

    def test_text_completions(self, base_url):
        """Test OpenAI text completions endpoint"""
        payload = {
            "model": "llama3.1-8b",
            "prompt": "The key financial ratios for a healthy company include:",
            "max_tokens": 100,
            "temperature": 0.5
        }
        response = requests.post(f"{base_url}/v1/completions", json=payload)
        assert response.status_code == 200
        data = response.json()

        # Check OpenAI response format
        assert "choices" in data
        assert "usage" in data
        assert "model" in data
        assert len(data["choices"]) > 0
        assert "text" in data["choices"][0]

    def test_json_response_format(self, base_url):
        """Test structured JSON output"""
        payload = {
            "model": "llama3.1-8b",
            "messages": [
                {
                    "role": "user",
                    "content": "Return financial metrics in JSON format: revenue, profit, debt_ratio"
                }
            ],
            "response_format": {"type": "json_object"},
            "max_tokens": 150
        }
        response = requests.post(f"{base_url}/v1/chat/completions", json=payload)
        assert response.status_code == 200
        data = response.json()

        # Check that response is valid JSON
        content = data["choices"][0]["message"]["content"]
        try:
            json_data = json.loads(content)
            assert isinstance(json_data, dict)
        except json.JSONDecodeError:
            pytest.fail("Response is not valid JSON")

    def test_streaming_response(self, base_url):
        """Test streaming chat completions"""
        payload = {
            "model": "llama3.1-8b",
            "messages": [
                {"role": "user", "content": "Explain financial risk management strategies."}
            ],
            "stream": True,
            "max_tokens": 100
        }
        response = requests.post(f"{base_url}/v1/chat/completions", json=payload, stream=True)
        assert response.status_code == 200
        # Some servers append a charset suffix (e.g. "text/event-stream; charset=utf-8"),
        # so match on the media type prefix rather than the full header value
        assert response.headers["content-type"].startswith("text/event-stream")

        # Check that we get streaming data
        chunks = []
        for line in response.iter_lines():
            if line:
                chunks.append(line.decode('utf-8'))
                if len(chunks) >= 3:  # Get a few chunks
                    break
        assert len(chunks) > 0

        # Check SSE format
        assert any("data:" in chunk for chunk in chunks)

    def test_rag_example(self, base_url):
        """Test RAG-style document analysis"""
        payload = {
            "model": "llama3.1-8b",
            "messages": [
                {
                    "role": "system",
                    "content": "You are a financial analyst. Extract key metrics from financial documents."
                },
                {
                    "role": "user",
                    "content": "Analyze this financial statement and return the data in JSON format with fields: revenue, expenses, net_income.\n\nDocument:\nQ3 2024 Results:\nRevenue: $2.5M\nExpenses: $1.8M\nNet Income: $700K"
                }
            ],
            "response_format": {"type": "json_object"},
            "max_tokens": 200
        }
        response = requests.post(f"{base_url}/v1/chat/completions", json=payload)
        assert response.status_code == 200
        data = response.json()
        content = data["choices"][0]["message"]["content"]
        json_data = json.loads(content)

        # Check that key financial metrics are extracted
        expected_fields = ["revenue", "expenses", "net_income"]
        for field in expected_fields:
            assert field in json_data or field.replace("_", " ") in str(json_data).lower()
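
    # Sketch (not in the original suite): a chat completion request missing the
    # required "messages" field should be rejected. The exact status code is an
    # assumption and varies by server (OpenAI uses 400; FastAPI-based servers
    # often return 422), so both are accepted here.
    def test_invalid_request_rejected(self, base_url):
        """Check that a malformed chat completion request is rejected"""
        payload = {
            "model": "llama3.1-8b",
            "max_tokens": 10
            # "messages" intentionally omitted
        }
        response = requests.post(f"{base_url}/v1/chat/completions", json=payload)
        assert response.status_code in (400, 422)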