nikhil_staging / src /signals /signal_test.py
nsthorat's picture
Push
e4f9cbe
raw
history blame
No virus
4.01 kB
"""Test signal base class."""
from typing import Iterable, Optional
import pytest
from typing_extensions import override
from ..embeddings.vector_store import VectorStore
from ..schema import Field, Item, RichData, SignalInputType, VectorKey, field
from .signal import (
Signal,
TextEmbeddingModelSignal,
TextEmbeddingSignal,
TextSplitterSignal,
clear_signal_registry,
get_signal_by_type,
get_signal_cls,
get_signals_by_type,
register_signal,
resolve_signal,
)
class TestSignal(Signal):
    """A test signal."""

    # Class-level signal metadata, followed by the single declared field.
    name = 'test_signal'
    input_type = SignalInputType.TEXT

    query: str

    @override
    def fields(self) -> Field:
        """Return the field spec built by `field('float32')`."""
        return field('float32')

    @override
    def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]:
        """Discard `data` and return an empty list."""
        del data  # Unused: this stub never inspects its input.
        return []
class TestTextSplitter(TextSplitterSignal):
    """A test text splitter."""

    name = 'test_splitter'

    @override
    def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]:
        """Discard `data` and return an empty list."""
        del data  # Unused: this stub never inspects its input.
        return []
class TestTextEmbedding(TextEmbeddingSignal):
    """A test text embedding."""

    name = 'test_embedding'

    @override
    def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]:
        """Discard `data` and return an empty list."""
        del data  # Unused: this stub never inspects its input.
        return []
class TestTextEmbeddingModelSignal(TextEmbeddingModelSignal):
    """A test text embedding model."""

    name = 'test_embedding_model'

    @override
    def fields(self) -> Field:
        """Return the field spec built by `field('float32')`."""
        return field('float32')

    @override
    def vector_compute(
        self, keys: Iterable[VectorKey], vector_store: VectorStore
    ) -> Iterable[Item]:
        """Discard `keys` and `vector_store` and return an empty list."""
        # Stub implementation: no embedding is read and no value is computed.
        del keys, vector_store
        return []
@pytest.fixture(scope='module', autouse=True)
def setup_teardown() -> Iterable[None]:
    """Register the test signal classes, run the tests, then clear the registry.

    Declared with `scope='module'` and `autouse=True`, so no test has to
    request it explicitly.
    """
    # Setup: register each test signal class, in this fixed order.
    signal_classes = (
        TestSignal,
        TestTextSplitter,
        TestTextEmbedding,
        TestTextEmbeddingModelSignal,
    )
    for signal_cls in signal_classes:
        register_signal(signal_cls)

    # Hand control over to the unit tests.
    yield

    # Teardown.
    clear_signal_registry()
def test_signal_serialization() -> None:
    """Check the exact `dict()` output of a `TestSignal` instance."""
    serialized = TestSignal(query='test').dict()
    # Class-level attributes such as `input_type` are expected to be absent.
    expected = {'signal_name': 'test_signal', 'query': 'test'}
    assert serialized == expected
def test_get_signal_cls() -> None:
    """Test that looking up 'test_signal' returns the `TestSignal` class."""
    looked_up_cls = get_signal_cls('test_signal')
    assert TestSignal == looked_up_cls
def test_resolve_signal() -> None:
    """Test resolving a signal from an instance and from its dict form."""
    original = TestSignal(query='hello')

    # Passing a signal instance yields an equal signal.
    from_instance = resolve_signal(original)
    assert from_instance == original

    # Passing the signal's `dict()` form also yields an equal signal.
    from_dict = resolve_signal(original.dict())
    assert from_dict == original
def test_get_signal_by_type() -> None:
    """Check that name + base-type lookups return the matching test class."""
    # (expected class, base type passed to the lookup)
    cases = (
        (TestTextSplitter, TextSplitterSignal),
        (TestTextEmbedding, TextEmbeddingSignal),
    )
    for expected_cls, base_type in cases:
        assert get_signal_by_type(expected_cls.name, base_type) == expected_cls
def test_get_signal_by_type_validation() -> None:
    """Check the `ValueError`s raised for an unknown name and a type mismatch."""
    # A name that was never registered.
    not_found_pattern = 'Signal "invalid_signal" not found in the registry'
    with pytest.raises(ValueError, match=not_found_pattern):
        get_signal_by_type('invalid_signal', TextSplitterSignal)

    # A registered splitter requested as an embedding signal.
    wrong_type_pattern = f'"{TestTextSplitter.name}" is a `{TestTextSplitter.__name__}`'
    with pytest.raises(ValueError, match=wrong_type_pattern):
        get_signal_by_type(TestTextSplitter.name, TextEmbeddingSignal)
def test_get_signals_by_type() -> None:
    """Check that listing by base type returns exactly the matching test class."""
    splitter_classes = get_signals_by_type(TextSplitterSignal)
    assert splitter_classes == [TestTextSplitter]

    embedding_classes = get_signals_by_type(TextEmbeddingSignal)
    assert embedding_classes == [TestTextEmbedding]
def test_signal_type_enum() -> None:
    """Check that the `embedding` schema property enumerates the test embedding."""
    model_signal = TestTextEmbeddingModelSignal(embedding='test_embedding')
    embedding_property = model_signal.schema()['properties']['embedding']
    # The enum is expected to hold exactly the `TestTextEmbedding` name.
    assert embedding_property['enum'] == [TestTextEmbedding.name]
class TestSignalNoDisplayName(Signal):
    # Test signal that sets only `name` and no `display_name`.
    name = 'signal_no_name'
class TestSignalDisplayName(Signal):
    # Test signal that sets both `name` and an explicit `display_name`.
    name = 'signal_display_name'
    display_name = 'test display name'
def test_signal_title_schema() -> None:
    """Check the schema `title` with and without a `display_name`."""
    # Without a `display_name`, the title is expected to be the class name.
    title_without_display_name = TestSignalNoDisplayName.schema()['title']
    assert title_without_display_name == TestSignalNoDisplayName.__name__

    # With a `display_name`, the title is expected to be that display name.
    title_with_display_name = TestSignalDisplayName.schema()['title']
    assert title_with_display_name == 'test display name'