colbert-xm-for-inference-api / test_endpoint.py
fdurant's picture
Add handler.py, start_emulator.sh and test scripts
4c96de6
raw
history blame
2.44 kB
import os
import pytest
import requests
# Address of the endpoint under test: a server listening on localhost, port 4999.
# NOTE(review): presumably the emulator launched by start_emulator.sh — confirm.
URL = "http://localhost:4999/"
# Headers sent with every request; the request bodies in this module are JSON.
HEADERS = {"Content-Type": "application/json"}
def test_returns_200():
    """Smoke test: POSTing a single-string input must answer with HTTP 200."""
    payload = {"inputs": "try me"}
    # requests never times out by default; bound the wait so an unresponsive
    # endpoint fails this test instead of hanging the whole test run.
    response = requests.request(
        "POST", URL, json=payload, headers=HEADERS, timeout=60
    )
    assert response.status_code == 200
def test_query_returns_expected_result():
    """Check the response shape for a single query string.

    The endpoint is expected to answer with a one-element list holding a dict
    that echoes the query under "input" and carries a "query_embedding": a
    list of 32 token embeddings, the first of which is checked to be a list
    of 128 floats.
    """
    query = "try me"
    payload = {"inputs": query}
    # requests never times out by default; bound the wait so an unresponsive
    # endpoint fails this test instead of hanging the whole test run.
    response = requests.request(
        "POST", URL, json=payload, headers=HEADERS, timeout=60
    )
    # Check the status before decoding, so an error response is reported as
    # such rather than as a JSON or structure failure further down.
    assert response.status_code == 200
    response_data = response.json()

    # Check structure and input
    assert isinstance(response_data, list)
    assert len(response_data) == 1
    result = response_data[0]
    assert isinstance(result, dict)
    assert result.get("input") == query

    # Check query embedding (actually a list of embeddings, one per token in
    # the query)
    query_embedding = result.get("query_embedding")
    assert isinstance(query_embedding, list)
    assert len(query_embedding) == 32

    # Check first of the token embeddings
    first_token_embedding = query_embedding[0]
    assert isinstance(first_token_embedding, list)
    assert len(first_token_embedding) == 128
    assert all(isinstance(value, float) for value in first_token_embedding)
def test_batch_returns_expected_result():
    """Check the response shape for a batch (list) of text chunks.

    The endpoint is expected to answer with one dict per input chunk, in
    order. Each dict echoes the chunk under "input" and carries a
    "chunk_embedding" (a list of token embeddings) plus a "token_count" whose
    length must match both the embedding list and the expected count. The
    first token embedding of every chunk is checked to be a list of 128
    floats.
    """
    chunks = ["try me", "try me again and again and again"]
    # Including start and stop tokens, I presume. Not exactly clear!
    expected_token_counts = [11, 11]
    payload = {"inputs": chunks}
    # requests never times out by default; bound the wait so an unresponsive
    # endpoint fails this test instead of hanging the whole test run.
    response = requests.request(
        "POST", URL, json=payload, headers=HEADERS, timeout=60
    )
    # Check the status before decoding, so an error response is reported as
    # such rather than as a JSON or structure failure further down.
    assert response.status_code == 200
    response_data = response.json()

    # Check structure
    assert isinstance(response_data, list)
    assert len(response_data) == len(chunks)

    # The length check above guarantees zip() pairs every response item with
    # its input chunk and expected count without silently dropping any.
    for chunk, expected_count, response_chunk in zip(
        chunks, expected_token_counts, response_data
    ):
        # Check input
        assert isinstance(response_chunk, dict)
        assert response_chunk.get("input") == chunk

        # Check chunk embedding (actually a list of embeddings, one per token
        # in the chunk)
        chunk_embedding = response_chunk.get("chunk_embedding")
        token_count = response_chunk.get("token_count")
        assert isinstance(chunk_embedding, list)
        # NOTE(review): len() is applied to token_count, so it is treated as
        # a sized value here rather than a plain int — confirm against the
        # handler's actual response schema.
        assert len(chunk_embedding) == len(token_count)
        assert len(token_count) == expected_count

        # Check first of the token embeddings
        first_token_embedding = chunk_embedding[0]
        assert isinstance(first_token_embedding, list)
        assert len(first_token_embedding) == 128
        assert all(isinstance(value, float) for value in first_token_embedding)