Spaces:
Runtime error
Runtime error
| """ | |
| OpenAI API compatibility tests | |
| """ | |
| import pytest | |
| import requests | |
| import json | |
| from typing import Dict, Any | |
class TestOpenAICompatibility:
    """Test OpenAI-compatible endpoints.

    These are live integration tests: every test sends real HTTP requests to
    the server returned by the ``base_url`` fixture (nothing is mocked), so a
    server must be listening there.
    """

    # Per-request timeout in seconds. ``requests`` waits indefinitely by
    # default, so without this an unresponsive server would hang the whole
    # run. Kept generous because each request triggers model generation.
    REQUEST_TIMEOUT = 120

    @pytest.fixture
    def base_url(self):
        """Root URL of the server under test (no trailing slash)."""
        return "http://localhost:8000"

    def _post(self, base_url, path, payload, **kwargs):
        """POST *payload* as JSON to ``base_url + path`` with the suite timeout.

        Extra keyword arguments (e.g. ``stream=True``) are forwarded to
        ``requests.post``.
        """
        return requests.post(
            f"{base_url}{path}",
            json=payload,
            timeout=self.REQUEST_TIMEOUT,
            **kwargs,
        )

    @staticmethod
    def _parse_json_content(content):
        """Decode a model reply as JSON, failing the test (not erroring) if invalid."""
        try:
            return json.loads(content)
        except (TypeError, ValueError) as exc:
            # json.JSONDecodeError subclasses ValueError; TypeError covers a
            # null/missing ``content`` field.
            pytest.fail(f"Response is not valid JSON ({exc}): {content!r}")

    def test_chat_completions(self, base_url):
        """Test OpenAI chat completions endpoint"""
        payload = {
            "model": "llama3.1-8b",
            "messages": [
                {"role": "user", "content": "What is EBITDA?"}
            ],
            "max_tokens": 100,
            "temperature": 0.7,
        }
        response = self._post(base_url, "/v1/chat/completions", payload)
        # Show the body on failure so server-side errors are debuggable.
        assert response.status_code == 200, response.text
        data = response.json()

        # Check OpenAI response format
        assert "choices" in data
        assert "usage" in data
        assert "model" in data
        assert len(data["choices"]) > 0
        assert "message" in data["choices"][0]
        assert "content" in data["choices"][0]["message"]

    def test_chat_completions_with_system_message(self, base_url):
        """Test chat completions with system message"""
        payload = {
            "model": "llama3.1-8b",
            "messages": [
                {"role": "system", "content": "You are a financial expert."},
                {"role": "user", "content": "Explain the difference between revenue and profit."},
            ],
            "max_tokens": 150,
            "temperature": 0.6,
        }
        response = self._post(base_url, "/v1/chat/completions", payload)
        assert response.status_code == 200, response.text
        data = response.json()

        assert "choices" in data
        assert len(data["choices"][0]["message"]["content"]) > 0

    def test_text_completions(self, base_url):
        """Test OpenAI text completions endpoint"""
        payload = {
            "model": "llama3.1-8b",
            "prompt": "The key financial ratios for a healthy company include:",
            "max_tokens": 100,
            "temperature": 0.5,
        }
        response = self._post(base_url, "/v1/completions", payload)
        assert response.status_code == 200, response.text
        data = response.json()

        # Check OpenAI response format
        assert "choices" in data
        assert "usage" in data
        assert "model" in data
        assert len(data["choices"]) > 0
        assert "text" in data["choices"][0]

    def test_json_response_format(self, base_url):
        """Test structured JSON output"""
        payload = {
            "model": "llama3.1-8b",
            "messages": [
                {
                    "role": "user",
                    "content": "Return financial metrics in JSON format: revenue, profit, debt_ratio",
                }
            ],
            "response_format": {"type": "json_object"},
            "max_tokens": 150,
        }
        response = self._post(base_url, "/v1/chat/completions", payload)
        assert response.status_code == 200, response.text
        data = response.json()

        # Check that response is valid JSON
        content = data["choices"][0]["message"]["content"]
        json_data = self._parse_json_content(content)
        assert isinstance(json_data, dict), json_data

    def test_streaming_response(self, base_url):
        """Test streaming chat completions"""
        payload = {
            "model": "llama3.1-8b",
            "messages": [
                {"role": "user", "content": "Explain financial risk management strategies."}
            ],
            "stream": True,
            "max_tokens": 100,
        }
        # The context manager releases the connection even though we stop
        # reading the stream early below.
        with self._post(
            base_url, "/v1/chat/completions", payload, stream=True
        ) as response:
            assert response.status_code == 200, response.text
            # Servers commonly append parameters,
            # e.g. "text/event-stream; charset=utf-8".
            content_type = response.headers.get("content-type", "")
            assert content_type.startswith("text/event-stream"), content_type

            # Check that we get streaming data
            chunks = []
            for line in response.iter_lines():
                if line:
                    chunks.append(line.decode("utf-8"))
                    if len(chunks) >= 3:  # A few chunks are enough
                        break

        assert len(chunks) > 0
        # Check SSE format: data lines carry a "data:" field prefix.
        assert any(chunk.startswith("data:") for chunk in chunks), chunks

    def test_rag_example(self, base_url):
        """Test RAG-style document analysis"""
        payload = {
            "model": "llama3.1-8b",
            "messages": [
                {
                    "role": "system",
                    "content": "You are a financial analyst. Extract key metrics from financial documents.",
                },
                {
                    "role": "user",
                    "content": "Analyze this financial statement and return the data in JSON format with fields: revenue, expenses, net_income.\n\nDocument:\nQ3 2024 Results:\nRevenue: $2.5M\nExpenses: $1.8M\nNet Income: $700K",
                },
            ],
            "response_format": {"type": "json_object"},
            "max_tokens": 200,
        }
        response = self._post(base_url, "/v1/chat/completions", payload)
        assert response.status_code == 200, response.text
        data = response.json()

        content = data["choices"][0]["message"]["content"]
        json_data = self._parse_json_content(content)
        # "json_object" mode was requested, so the reply must be an object.
        assert isinstance(json_data, dict), json_data

        # Check that key financial metrics are extracted. Accept the exact key,
        # or the field name (underscores as spaces) anywhere in the payload.
        flattened = str(json_data).lower()
        expected_fields = ["revenue", "expenses", "net_income"]
        for field in expected_fields:
            assert field in json_data or field.replace("_", " ") in flattened, (
                f"{field!r} not found in {json_data!r}"
            )