Spaces:
Runtime error
Runtime error
File size: 2,732 Bytes
e4f9cbe |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 |
"""Test the semantic search signal."""
from typing import Iterable, cast
import numpy as np
import pytest
from pytest import approx
from pytest_mock import MockerFixture
from typing_extensions import override
from ..data.dataset_utils import lilac_embedding
from ..embeddings.vector_store import VectorStore
from ..schema import Item, RichData, VectorKey
from .semantic_similarity import SemanticSimilaritySignal
from .signal import TextEmbeddingSignal, clear_signal_registry, register_signal
# Fixed 3-d vectors, looked up by row key, that back the test vector store.
EMBEDDINGS: dict[VectorKey, list[float]] = {
    ('1',): [1.0, 0.0, 0.0],
    ('2',): [0.9, 0.1, 0.0],
    ('3',): [0.0, 0.0, 1.0],
}
# Fixed 3-d vectors, looked up by the raw text, that back the test embedding.
STR_EMBEDDINGS: dict[str, list[float]] = {
    'hello': [1.0, 0.0, 0.0],
    'hello world': [0.9, 0.1, 0.0],
    'far': [0.0, 0.0, 1.0],
}
class TestVectorStore(VectorStore):
    """A test vector store with fixed embeddings.

    Lookups are answered from the module-level `EMBEDDINGS` table; `add` stores nothing.
    """

    # The `Test` name prefix matches pytest's default class-discovery pattern. Opt out
    # explicitly so this helper is never treated as a test class.
    __test__ = False

    @override
    def keys(self) -> list[VectorKey]:
        """Return an empty list; this store does not report any keys."""
        return []

    @override
    def add(self, keys: list[VectorKey], embeddings: np.ndarray) -> None:
        """Do nothing: the vectors are fixed for the test vector store."""

    @override
    def get(self, keys: Iterable[VectorKey]) -> np.ndarray:
        """Return the fixed embeddings for `keys` as an array, one row per key, in order.

        Raises:
            KeyError: If a key has no entry in `EMBEDDINGS`.
        """
        # A falsy `keys` argument is treated as "no keys".
        keys = keys or []
        return np.array([EMBEDDINGS[row_id] for row_id in keys])
class TestEmbedding(TextEmbeddingSignal):
    """A test embed function.

    Each input string is mapped to its fixed vector in the module-level `STR_EMBEDDINGS` table.
    """

    # The `Test` name prefix matches pytest's default class-discovery pattern. Opt out
    # explicitly so this helper is never treated as a test class.
    __test__ = False

    name = 'test_embedding'

    @override
    def compute(self, data: Iterable[RichData]) -> Iterable[Item]:
        """Embed the examples, use a hashmap to the vector for simplicity.

        Yields, per example, a one-element list holding a single embedding built over the
        range `0..len(example)`.

        Raises:
            KeyError: If an example has no entry in `STR_EMBEDDINGS`.
        """
        for example in data:
            yield [lilac_embedding(0, len(example), np.array(STR_EMBEDDINGS[cast(str, example)]))]
@pytest.fixture(scope='module', autouse=True)
def setup_teardown() -> Iterable[None]:
    """Register `TestEmbedding` before the module's tests and clear the registry afterwards."""
    # Before the tests: register the test embedding signal.
    register_signal(TestEmbedding)

    # Hand control over to the tests.
    yield

    # After the tests: clear the signal registry.
    clear_signal_registry()
def test_semantic_similarity_compute_keys(mocker: MockerFixture) -> None:
    """Score the three stored keys against the query 'hello' through the vector store."""
    store = TestVectorStore()
    compute_spy = mocker.spy(TestEmbedding, 'compute')

    similarity = SemanticSimilaritySignal(query='hello', embedding=TestEmbedding.name)
    row_keys = [('1',), ('2',), ('3',)]
    scores = list(similarity.vector_compute(row_keys, store))

    # The embedding is only computed once, for the search query.
    assert compute_spy.call_count == 1

    expected = [1.0, approx(0.938, 1e-3), approx(0.417, 1e-3)]
    assert scores == expected
def test_semantic_similarity_compute_data(mocker: MockerFixture) -> None:
    """Score the raw texts of `STR_EMBEDDINGS` against the query 'hello'."""
    embed_mock = mocker.spy(TestEmbedding, 'compute')
    signal = SemanticSimilaritySignal(query='hello', embedding=TestEmbedding.name)

    # Compute over the text.
    scores = list(signal.compute(STR_EMBEDDINGS.keys()))

    # Embeddings should be called only 2 times, once for the search, once for the query itself.
    assert embed_mock.call_count == 2

    # Scores are floating-point results, so compare with a tolerance instead of exact
    # equality; this matches the other similarity test in this module.
    assert scores == [approx(1.0), approx(0.9), approx(0.0)]
|