File size: 5,658 Bytes
8c0b652
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
"""
OpenAI API compatibility tests
"""
import pytest
import requests
import json
from typing import Dict, Any


class TestOpenAICompatibility:
    """Integration tests for the OpenAI-compatible HTTP endpoints.

    Every test sends a real HTTP request to the server whose root URL is
    returned by the ``base_url`` fixture, so that server must already be
    running when the suite starts.
    """

    # Model identifier sent with every request.
    MODEL = "llama3.1-8b"
    # Seconds to wait on the server before giving up. ``requests`` applies no
    # timeout by default, so without this an unresponsive server would hang
    # the whole test run.
    REQUEST_TIMEOUT = 120

    @pytest.fixture
    def base_url(self):
        """Return the root URL of the server under test."""
        return "http://localhost:8000"

    def _post(self, base_url: str, path: str, payload: Dict[str, Any], **kwargs: Any):
        """POST ``payload`` as JSON to ``base_url + path``.

        Args:
            base_url: Server root URL (value of the ``base_url`` fixture).
            path: Endpoint path beginning with ``/``.
            payload: Request body, serialized as JSON.
            **kwargs: Extra keyword arguments forwarded to ``requests.post``
                (for example ``stream=True``).

        Returns:
            The ``requests.Response`` object.
        """
        kwargs.setdefault("timeout", self.REQUEST_TIMEOUT)
        return requests.post(f"{base_url}{path}", json=payload, **kwargs)

    @staticmethod
    def _parse_json_content(data: Dict[str, Any]) -> Any:
        """Decode the first choice's message content as JSON.

        Fails the current test (via ``pytest.fail``) when the content is not
        valid JSON instead of surfacing a raw ``JSONDecodeError``.
        """
        content = data["choices"][0]["message"]["content"]
        try:
            return json.loads(content)
        except json.JSONDecodeError:
            pytest.fail("Response is not valid JSON")

    def test_chat_completions(self, base_url):
        """Test OpenAI chat completions endpoint"""
        payload = {
            "model": self.MODEL,
            "messages": [
                {"role": "user", "content": "What is EBITDA?"}
            ],
            "max_tokens": 100,
            "temperature": 0.7,
        }
        response = self._post(base_url, "/v1/chat/completions", payload)
        assert response.status_code == 200
        data = response.json()

        # Check OpenAI response format
        assert "choices" in data
        assert "usage" in data
        assert "model" in data
        assert len(data["choices"]) > 0
        assert "message" in data["choices"][0]
        assert "content" in data["choices"][0]["message"]

    def test_chat_completions_with_system_message(self, base_url):
        """Test chat completions with system message"""
        payload = {
            "model": self.MODEL,
            "messages": [
                {"role": "system", "content": "You are a financial expert."},
                {"role": "user", "content": "Explain the difference between revenue and profit."},
            ],
            "max_tokens": 150,
            "temperature": 0.6,
        }
        response = self._post(base_url, "/v1/chat/completions", payload)
        assert response.status_code == 200
        data = response.json()
        assert "choices" in data
        assert len(data["choices"][0]["message"]["content"]) > 0

    def test_text_completions(self, base_url):
        """Test OpenAI text completions endpoint"""
        payload = {
            "model": self.MODEL,
            "prompt": "The key financial ratios for a healthy company include:",
            "max_tokens": 100,
            "temperature": 0.5,
        }
        response = self._post(base_url, "/v1/completions", payload)
        assert response.status_code == 200
        data = response.json()

        # Check OpenAI response format
        assert "choices" in data
        assert "usage" in data
        assert "model" in data
        assert len(data["choices"]) > 0
        assert "text" in data["choices"][0]

    def test_json_response_format(self, base_url):
        """Test structured JSON output"""
        payload = {
            "model": self.MODEL,
            "messages": [
                {
                    "role": "user",
                    "content": "Return financial metrics in JSON format: revenue, profit, debt_ratio",
                }
            ],
            "response_format": {"type": "json_object"},
            "max_tokens": 150,
        }
        response = self._post(base_url, "/v1/chat/completions", payload)
        assert response.status_code == 200
        data = response.json()

        # Check that response is valid JSON and decodes to an object
        json_data = self._parse_json_content(data)
        assert isinstance(json_data, dict)

    def test_streaming_response(self, base_url):
        """Test streaming chat completions"""
        payload = {
            "model": self.MODEL,
            "messages": [
                {"role": "user", "content": "Explain financial risk management strategies."}
            ],
            "stream": True,
            "max_tokens": 100,
        }
        # The context manager closes the connection even though the body is
        # only partially consumed below.
        with self._post(base_url, "/v1/chat/completions", payload, stream=True) as response:
            assert response.status_code == 200
            # Compare only the media type: the header may legitimately carry
            # parameters such as "; charset=utf-8" after it.
            content_type = response.headers.get("content-type", "")
            media_type = content_type.split(";")[0].strip().lower()
            assert media_type == "text/event-stream"

            # Check that we get streaming data
            chunks = []
            for line in response.iter_lines():
                if line:
                    chunks.append(line.decode("utf-8"))
                    if len(chunks) >= 3:  # Get a few chunks
                        break

        assert len(chunks) > 0
        # Check SSE format
        assert any("data:" in chunk for chunk in chunks)

    def test_rag_example(self, base_url):
        """Test RAG-style document analysis"""
        payload = {
            "model": self.MODEL,
            "messages": [
                {
                    "role": "system",
                    "content": "You are a financial analyst. Extract key metrics from financial documents.",
                },
                {
                    "role": "user",
                    "content": (
                        "Analyze this financial statement and return the data "
                        "in JSON format with fields: revenue, expenses, net_income."
                        "\n\nDocument:\nQ3 2024 Results:\nRevenue: $2.5M\n"
                        "Expenses: $1.8M\nNet Income: $700K"
                    ),
                },
            ],
            "response_format": {"type": "json_object"},
            "max_tokens": 200,
        }
        response = self._post(base_url, "/v1/chat/completions", payload)
        assert response.status_code == 200
        data = response.json()

        json_data = self._parse_json_content(data)

        # Check that key financial metrics are extracted. The match is done on
        # the lower-cased text of the decoded JSON so that differently-cased,
        # space-separated, or nested keys are all accepted.
        flattened = str(json_data).lower()
        expected_fields = ["revenue", "expenses", "net_income"]
        for field in expected_fields:
            assert field in flattened or field.replace("_", " ") in flattened