# mediastorm/tests/test_embedder_cache.py
# Commit 2bf4d05 by remdms:
#   feat: add LRU-cached embed_query to avoid re-embedding repeated searches
import pytest
from mediastorm.vectorize.embedder import Embedder
@pytest.fixture(scope="module")
def embedder():
    """Provide one Embedder instance shared by every test in this module.

    Module scope means the Embedder is constructed once here rather than
    once per test; tests that patch it must restore it themselves.
    """
    shared_embedder = Embedder()
    return shared_embedder
def test_embed_query_returns_384_dimensions(embedder):
    """A single query embedding must have exactly 384 components."""
    expected_dims = 384
    query_vector = embedder.embed_query("documentary about war")
    assert len(query_vector) == expected_dims
def test_embed_query_is_normalized(embedder):
    """The query embedding should have unit Euclidean (L2) length."""
    query_vector = embedder.embed_query("human rights in Africa")
    # Square root of the sum of squared components = L2 norm.
    squared_total = sum(component ** 2 for component in query_vector)
    l2_length = squared_total ** 0.5
    assert l2_length == pytest.approx(1.0, abs=1e-4)
def test_embed_query_repeated_calls_return_identical_results(embedder):
    """Embedding the same text twice must yield equal vectors."""
    query = "award winning photography"
    # Two back-to-back calls, made in order, with identical input.
    first, second = [embedder.embed_query(query) for _ in range(2)]
    assert first == second
def test_embed_query_cache_avoids_recomputation(embedder):
    """A second embed_query call with the same text must not call embed_texts.

    embed_texts is temporarily replaced on the shared (module-scoped)
    embedder with a wrapper that counts invocations; only the first,
    uncached embed_query call should reach it. The original attribute
    state is restored exactly afterwards so later tests are unaffected.
    """
    # Clear the cache so we start fresh and an earlier test cannot have
    # pre-populated this query.
    embedder._embed_query_cached.cache_clear()

    call_count = 0
    # Record whether embed_texts is stored on the instance itself or only
    # resolved through the class, so teardown can restore that exact state.
    had_instance_attr = "embed_texts" in vars(embedder)
    original_embed_texts = embedder.embed_texts

    def counting_embed_texts(*args, **kwargs):
        # Forward all arguments unchanged so the wrapper does not depend on
        # embed_texts' exact signature.
        nonlocal call_count
        call_count += 1
        return original_embed_texts(*args, **kwargs)

    embedder.embed_texts = counting_embed_texts
    try:
        embedder.embed_query("cache test query")
        embedder.embed_query("cache test query")
    finally:
        if had_instance_attr:
            embedder.embed_texts = original_embed_texts
        else:
            # Remove the patch entirely instead of leaving a bound-method
            # instance attribute that would shadow the class method on the
            # shared fixture for the rest of the module.
            del embedder.embed_texts

    assert call_count == 1, f"embed_texts called {call_count} times, expected 1"
def test_embed_query_different_texts_produce_different_vectors(embedder):
    """Distinct query strings must not map to the same embedding."""
    war_vec, cooking_vec = [
        embedder.embed_query(text) for text in ("war documentary", "cooking show")
    ]
    assert war_vec != cooking_vec
def test_embed_query_matches_embed_texts_output(embedder):
    """embed_query(text) must agree with the first row of embed_texts([text])."""
    text = "journalism and press freedom"
    # Reference value: call embed_texts directly (this does not go through
    # embed_query's cache), and do it before calling embed_query.
    batch = embedder.embed_texts([text])
    reference = batch[0]
    assert embedder.embed_query(text) == pytest.approx(reference, abs=1e-6)