# File: nikhil_staging/src/signals/semantic_similarity_test.py
# Hosting-page residue kept as a comment: nsthorat's picture · Push · e4f9cbe · raw ·
# history blame · No virus · 2.73 kB
"""Tests for the semantic similarity signal."""
from typing import Iterable, cast
import numpy as np
import pytest
from pytest import approx
from pytest_mock import MockerFixture
from typing_extensions import override
from ..data.dataset_utils import lilac_embedding
from ..embeddings.vector_store import VectorStore
from ..schema import Item, RichData, VectorKey
from .semantic_similarity import SemanticSimilaritySignal
from .signal import TextEmbeddingSignal, clear_signal_registry, register_signal
# Fixed vectors returned by the test vector store, looked up by row key.
EMBEDDINGS: dict[VectorKey, list[float]] = {
    ('1',): [1.0, 0.0, 0.0],
    ('2',): [0.9, 0.1, 0.0],
    ('3',): [0.0, 0.0, 1.0],
}
# Fixed vectors returned by the test embedding, looked up by input text.
# Insertion order matters: tests iterate the keys and compare scores positionally.
STR_EMBEDDINGS: dict[str, list[float]] = {
    'hello': [1.0, 0.0, 0.0],
    'hello world': [0.9, 0.1, 0.0],
    'far': [0.0, 0.0, 1.0],
}
class TestVectorStore(VectorStore):
    """Vector store stub that serves the module-level `EMBEDDINGS` table.

    Nothing is ever stored: `add` does nothing, `keys` reports an empty list,
    and `get` looks each requested key up in `EMBEDDINGS`.
    """

    @override
    def keys(self) -> list[VectorKey]:
        """Return an empty list; this stub does not enumerate its keys."""
        return []

    @override
    def add(self, keys: list[VectorKey], embeddings: np.ndarray) -> None:
        """Do nothing; the vectors are fixed for the test vector store."""

    @override
    def get(self, keys: Iterable[VectorKey]) -> np.ndarray:
        """Return the fixed vectors for `keys` as one array, in request order.

        A falsy `keys` argument is treated as an empty request. A key that is
        missing from `EMBEDDINGS` raises `KeyError`.
        """
        vectors = [EMBEDDINGS[key] for key in (keys or [])]
        return np.array(vectors)
class TestEmbedding(TextEmbeddingSignal):
    """Embedding signal stub backed by the `STR_EMBEDDINGS` lookup table."""

    name = 'test_embedding'

    @override
    def compute(self, data: Iterable[RichData]) -> Iterable[Item]:
        """Yield a one-element list per input, wrapping its fixed vector.

        Each input is looked up in `STR_EMBEDDINGS` (raising `KeyError` when
        absent) and passed to `lilac_embedding` together with 0 and the input
        length -- presumably the start/end offsets of the embedded span.
        """
        for text in data:
            text_length = len(text)
            vector = np.array(STR_EMBEDDINGS[cast(str, text)])
            yield [lilac_embedding(0, text_length, vector)]
@pytest.fixture(scope='module', autouse=True)
def setup_teardown() -> Iterable[None]:
    """Register `TestEmbedding` before the module's tests, clear the registry after."""
    register_signal(TestEmbedding)
    # The tests run while the fixture is suspended here; the code after the
    # yield is the teardown.
    yield
    clear_signal_registry()
def test_semantic_similarity_compute_keys(mocker: MockerFixture) -> None:
    """Score vectors fetched by key and check the embedding runs only once."""
    store = TestVectorStore()
    compute_spy = mocker.spy(TestEmbedding, 'compute')

    row_keys = [('1',), ('2',), ('3',)]
    similarity = SemanticSimilaritySignal(query='hello', embedding=TestEmbedding.name)
    scores = list(similarity.vector_compute(row_keys, store))

    # Only the search query should need embedding here; the row vectors are
    # expected to come from the vector store.
    assert compute_spy.call_count == 1

    expected = [1.0, approx(0.938, rel=1e-3), approx(0.417, rel=1e-3)]
    assert scores == expected
def test_semantic_similarity_compute_data(mocker: MockerFixture) -> None:
    """Score raw text inputs and check how many times the embedding runs."""
    embed_mock = mocker.spy(TestEmbedding, 'compute')

    signal = SemanticSimilaritySignal(query='hello', embedding=TestEmbedding.name)
    # Compute over the text; the inputs are the keys of STR_EMBEDDINGS, in
    # insertion order.
    scores = list(signal.compute(STR_EMBEDDINGS.keys()))

    # Embeddings should be called only 2 times, once for the search, once for
    # the query itself.
    assert embed_mock.call_count == 2

    # The scores are floating-point results, so compare with a tolerance (as
    # the sibling keys test does) rather than with exact equality.
    assert scores == approx([1.0, 0.9, 0.0])