|
import pytest |
|
from openai import OpenAI |
|
from utils import * |
|
|
|
# Module-level server handle used by every test in this file. The autouse
# `create_server` fixture rebinds it to a fresh preset before the tests run.
# NOTE(review): ServerPreset is presumably supplied by `from utils import *`.
server = ServerPreset.tinyllama2()

# API key that `create_server` assigns to the server; the tests send it,
# omit it, or replace it with a wrong value to exercise authentication.
TEST_API_KEY = "sk-this-is-the-secret-key"
|
|
|
@pytest.fixture(scope="module", autouse=True)
def create_server():
    """Rebind the module-level ``server`` to a fresh tinyllama2 preset.

    Module-scoped and autouse, so it applies to the tests in this file
    without being requested. The new preset has its ``api_key`` set to
    ``TEST_API_KEY``.
    """
    global server
    configured = ServerPreset.tinyllama2()
    configured.api_key = TEST_API_KEY
    server = configured
|
|
|
|
|
@pytest.mark.parametrize("endpoint", ["/health", "/models"])
def test_access_public_endpoint(endpoint: str):
    """A GET to ``endpoint`` without any headers returns 200 and no error.

    Args:
        endpoint: Path requested on the server.
    """
    # Only reads the module-level ``server``, so no ``global`` is needed.
    server.start()
    response = server.make_request("GET", endpoint)
    assert response.status_code == 200
    assert "error" not in response.body
|
|
|
|
|
@pytest.mark.parametrize("api_key", [None, "invalid-key"])
def test_incorrect_api_key(api_key: "str | None"):
    """A completion request with a missing or wrong key is rejected with 401.

    Args:
        api_key: Key to put in the ``Authorization`` header, or ``None`` to
            pass ``None`` as the header value instead of a bearer token.
    """
    # Only reads the module-level ``server``, so no ``global`` is needed.
    server.start()
    # NOTE(review): for api_key=None the header value is None; how
    # make_request handles a None header value is not visible here --
    # presumably the header is dropped. Confirm against utils.
    authorization = f"Bearer {api_key}" if api_key else None
    res = server.make_request(
        "POST",
        "/completions",
        data={"prompt": "I believe the meaning of life is"},
        headers={"Authorization": authorization},
    )
    assert res.status_code == 401
    assert "error" in res.body
    assert res.body["error"]["type"] == "authentication_error"
|
|
|
|
|
def test_correct_api_key():
    """A completion request carrying ``TEST_API_KEY`` as bearer token succeeds.

    Expects status 200, no ``error`` entry and a ``content`` entry in the body.
    """
    # Only reads the module-level ``server``, so no ``global`` is needed.
    server.start()
    payload = {"prompt": "I believe the meaning of life is"}
    auth_headers = {"Authorization": f"Bearer {TEST_API_KEY}"}
    response = server.make_request(
        "POST", "/completions", data=payload, headers=auth_headers
    )
    assert response.status_code == 200
    assert "error" not in response.body
    assert "content" in response.body
|
|
|
|
|
def test_openai_library_correct_api_key():
    """A chat completion through the ``OpenAI`` client yields one choice.

    The client is built with ``TEST_API_KEY`` and pointed at the test
    server's host and port.
    """
    # Only reads the module-level ``server``, so no ``global`` is needed.
    server.start()
    base_url = f"http://{server.server_host}:{server.server_port}"
    client = OpenAI(api_key=TEST_API_KEY, base_url=base_url)
    conversation = [
        {"role": "system", "content": "You are a chatbot."},
        {"role": "user", "content": "What is the meaning of life?"},
    ]
    completion = client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=conversation,
    )
    assert len(completion.choices) == 1
|
|
|
|
|
@pytest.mark.parametrize("origin,cors_header,cors_header_value", [
    ("localhost", "Access-Control-Allow-Origin", "localhost"),
    ("web.mydomain.fr", "Access-Control-Allow-Origin", "web.mydomain.fr"),
    ("origin", "Access-Control-Allow-Credentials", "true"),
    ("web.mydomain.fr", "Access-Control-Allow-Methods", "GET, POST"),
    ("web.mydomain.fr", "Access-Control-Allow-Headers", "*"),
])
def test_cors_options(origin: str, cors_header: str, cors_header_value: str):
    """An OPTIONS request to ``/completions`` returns the expected CORS header.

    Args:
        origin: Value sent in the ``Origin`` request header.
        cors_header: Response header that must be present.
        cors_header_value: Exact value expected for ``cors_header``.
    """
    # Only reads the module-level ``server``, so no ``global`` is needed.
    server.start()
    preflight_headers = {
        "Origin": origin,
        "Access-Control-Request-Method": "POST",
        "Access-Control-Request-Headers": "Authorization",
    }
    response = server.make_request(
        "OPTIONS", "/completions", headers=preflight_headers
    )
    assert response.status_code == 200
    assert cors_header in response.headers
    assert response.headers[cors_header] == cors_header_value
|
|