File size: 2,437 Bytes
4c96de6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 |
import os
import pytest
import requests
# Endpoint of the service under test; every test in this file POSTs to it.
URL = "http://localhost:4999/"

# Headers sent with each request: the body is declared as JSON.
HEADERS = {
    "Content-Type": "application/json",
}
def test_returns_200():
    """Check that a minimal single-input POST is answered with HTTP 200.

    Sends ``{"inputs": "try me"}`` to ``URL`` and inspects only the status
    code of the reply.
    """
    payload = {"inputs": "try me"}
    # requests waits forever by default; bound the wait so an unresponsive
    # server fails this test instead of hanging the whole run.
    response = requests.post(URL, json=payload, headers=HEADERS, timeout=60)
    assert response.status_code == 200
def test_query_returns_expected_result():
    """Check the response shape for a single query string.

    Posts one string as ``inputs`` and expects a one-element list whose only
    item is a dict that echoes the query under ``"input"`` and carries a
    ``"query_embedding"``: a list of 32 per-token embeddings, the first of
    which is a list of 128 floats.
    """
    query = "try me"
    payload = {"inputs": query}
    # requests waits forever by default; bound the wait so an unresponsive
    # server fails this test instead of hanging the whole run.
    response = requests.post(URL, json=payload, headers=HEADERS, timeout=60)
    # Surface an HTTP error as such, rather than as a confusing JSON or
    # structure failure further down.
    response.raise_for_status()
    response_data = response.json()

    # Check structure and input
    assert isinstance(response_data, list)
    assert len(response_data) == 1
    assert isinstance(response_data[0], dict)
    assert response_data[0].get("input") == query

    # Check query embedding (actually a list of embeddings, one per token
    # in the query)
    query_embedding = response_data[0].get("query_embedding")
    assert isinstance(query_embedding, list)
    assert len(query_embedding) == 32

    # Check first of the token embeddings
    first_token_embedding = query_embedding[0]
    assert isinstance(first_token_embedding, list)
    assert len(first_token_embedding) == 128
    assert all(isinstance(value, float) for value in first_token_embedding)
def test_batch_returns_expected_result():
    """Check the response shape for a batch of text chunks.

    Posts a list of strings as ``inputs`` and expects a list with one dict
    per chunk, in the same order. Each dict must echo its chunk under
    ``"input"`` and carry a ``"chunk_embedding"`` (a list of per-token
    embeddings) whose length matches that of ``"token_count"`` and the
    expected count; the first token embedding must be a list of 128 floats.
    """
    chunks = ["try me", "try me again and again and again"]
    # Presumably includes start and stop tokens -- not confirmed.
    expected_token_counts = [11, 11]
    payload = {"inputs": chunks}
    # requests waits forever by default; bound the wait so an unresponsive
    # server fails this test instead of hanging the whole run.
    response = requests.post(URL, json=payload, headers=HEADERS, timeout=60)
    # Surface an HTTP error as such, rather than as a confusing JSON or
    # structure failure further down.
    response.raise_for_status()
    response_data = response.json()

    # Check structure
    assert isinstance(response_data, list)
    assert len(response_data) == len(chunks)

    for response_chunk, chunk, expected_count in zip(
        response_data, chunks, expected_token_counts
    ):
        # Check input
        assert response_chunk.get("input") == chunk

        # Check chunk embedding (actually a list of embeddings, one per
        # token in the chunk)
        chunk_embedding = response_chunk.get("chunk_embedding")
        token_count = response_chunk.get("token_count")
        assert isinstance(chunk_embedding, list)
        # NOTE(review): "token_count" is measured with len(), i.e. it is
        # treated as a sized collection despite its name -- confirm against
        # the service's response schema.
        assert len(chunk_embedding) == len(token_count)
        assert len(token_count) == expected_count

        # Check first of the token embeddings
        first_token_embedding = chunk_embedding[0]
        assert isinstance(first_token_embedding, list)
        assert len(first_token_embedding) == 128
        assert all(isinstance(value, float) for value in first_token_embedding)
|