File size: 2,437 Bytes
4c96de6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import os
import pytest
import requests

# Headers sent with every request in the tests below: the body is JSON.
HEADERS = {
    "Content-Type": "application/json",
}

# Address the tests below POST their payloads to (a service on localhost).
URL = "http://localhost:4999/"

def test_returns_200():
    """Smoke test: POSTing a single-string ``inputs`` payload yields HTTP 200."""
    payload = {"inputs": "try me"}

    # requests applies no timeout by default; bound the wait so a hung service
    # fails this test instead of blocking the whole run indefinitely.
    response = requests.post(URL, json=payload, headers=HEADERS, timeout=60)

    assert response.status_code == 200

def test_query_returns_expected_result():
    """Check the response shape for a single query string.

    The assertions require a one-element list whose entry echoes the query
    under ``input`` and carries ``query_embedding``: a list of 32 token
    embeddings, the first of which must be a list of 128 floats.
    """
    query = "try me"
    payload = {"inputs": query}

    # requests applies no timeout by default; bound the wait so a hung service
    # fails this test instead of blocking the whole run indefinitely.
    response = requests.post(URL, json=payload, headers=HEADERS, timeout=60)

    # Check the status first so an error response is reported as such rather
    # than as a JSON-decoding or shape failure further down.
    assert response.status_code == 200, response.text
    response_data = response.json()

    # Check structure and input
    assert isinstance(response_data, list)
    assert len(response_data) == 1
    result = response_data[0]
    assert isinstance(result, dict)
    assert result.get("input") == query

    # Check query embedding (actually a list of embeddings, one per token in the query)
    query_embedding = result.get("query_embedding")
    assert isinstance(query_embedding, list)
    assert len(query_embedding) == 32

    # Check first of the token embeddings
    first_token_embedding = query_embedding[0]
    assert isinstance(first_token_embedding, list)
    assert len(first_token_embedding) == 128
    assert all(isinstance(value, float) for value in first_token_embedding)

def test_batch_returns_expected_result():
    """Check the response shape for a batch (list) of input chunks.

    The assertions require one result per submitted chunk, in order. Each
    result must echo its chunk under ``input`` and carry ``chunk_embedding``
    (a list of token embeddings) plus ``token_count``, whose lengths must
    agree. The first token embedding must be a list of 128 floats.
    """
    chunks = ["try me", "try me again and again and again"]
    # Including start and stop tokens, I presume. Not exactly clear!
    # NOTE(review): ``token_count`` is measured with len() below, so the
    # service presumably returns a sized value rather than an int - confirm.
    expected_token_counts = [11, 11]
    payload = {"inputs": chunks}

    # requests applies no timeout by default; bound the wait so a hung service
    # fails this test instead of blocking the whole run indefinitely.
    response = requests.post(URL, json=payload, headers=HEADERS, timeout=60)

    # Check the status first so an error response is reported as such rather
    # than as a JSON-decoding or shape failure further down.
    assert response.status_code == 200, response.text
    response_data = response.json()

    # Check structure
    assert isinstance(response_data, list)
    assert len(response_data) == len(chunks)

    # The length check above guarantees zip() pairs every chunk with a result.
    for chunk, expected_count, response_chunk in zip(
        chunks, expected_token_counts, response_data
    ):
        # Check input
        assert isinstance(response_chunk, dict)
        assert response_chunk.get("input") == chunk

        # Check chunk embedding (actually a list of embeddings, one per token in the chunk)
        chunk_embedding = response_chunk.get("chunk_embedding")
        token_count = response_chunk.get("token_count")
        assert isinstance(chunk_embedding, list)
        assert len(chunk_embedding) == len(token_count)
        assert len(token_count) == expected_count

        # Check first of the token embeddings
        first_token_embedding = chunk_embedding[0]
        assert isinstance(first_token_embedding, list)
        assert len(first_token_embedding) == 128
        assert all(isinstance(value, float) for value in first_token_embedding)