Spaces:
Runtime error
Runtime error
File size: 4,013 Bytes
e4f9cbe |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 |
"""Test signal base class."""
from typing import Iterable, Optional
import pytest
from typing_extensions import override
from ..embeddings.vector_store import VectorStore
from ..schema import Field, Item, RichData, SignalInputType, VectorKey, field
from .signal import (
Signal,
TextEmbeddingModelSignal,
TextEmbeddingSignal,
TextSplitterSignal,
clear_signal_registry,
get_signal_by_type,
get_signal_cls,
get_signals_by_type,
register_signal,
resolve_signal,
)
class TestSignal(Signal):
    """A test signal."""
    # Class variables: they identify the signal in the registry and are NOT
    # serialized by `.dict()` (see `test_signal_serialization`).
    name = 'test_signal'
    input_type = SignalInputType.TEXT

    # Pydantic field: the only per-instance value that is serialized.
    query: str

    @override
    def fields(self) -> Field:
        """Return this signal's output schema, `field('float32')`."""
        return field('float32')

    @override
    def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]:
        """Stub: discard the input and return an empty list."""
        del data
        return []
class TestTextSplitter(TextSplitterSignal):
    """A test text splitter."""
    # Registry key; looked up via `get_signal_by_type(TestTextSplitter.name, ...)`.
    name = 'test_splitter'

    @override
    def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]:
        """Stub splitter: ignore the input and hand back an empty list."""
        del data
        return []
class TestTextEmbedding(TextEmbeddingSignal):
    """A test text embedding."""
    # Registry key; also the value expected in the `embedding` enum of
    # `TestTextEmbeddingModelSignal`'s schema.
    name = 'test_embedding'

    @override
    def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]:
        """Stub embedding: ignore the input and hand back an empty list."""
        del data
        return []
class TestTextEmbeddingModelSignal(TextEmbeddingModelSignal):
    """A test text embedding model."""
    # Registry key. The `embedding` field this model is constructed with is
    # declared on the base class, not here.
    name = 'test_embedding_model'

    @override
    def fields(self) -> Field:
        """Return this signal's output schema, `field('float32')`."""
        return field('float32')

    @override
    def vector_compute(self, keys: Iterable[VectorKey], vector_store: VectorStore) -> Iterable[Item]:
        """Stub: discard the keys and vector store and return an empty list.

        Only registration and schema generation are exercised by these tests, so
        no embedding values are read or aggregated here.
        """
        del keys, vector_store
        return []
@pytest.fixture(scope='module', autouse=True)
def setup_teardown() -> Iterable[None]:
    """Register this module's test signals once, then clear the registry afterwards."""
    signal_classes = (
        TestSignal,
        TestTextSplitter,
        TestTextEmbedding,
        TestTextEmbeddingModelSignal,
    )
    for signal_cls in signal_classes:
        register_signal(signal_cls)

    # Every test in this module runs while the fixture is suspended here.
    yield

    clear_signal_registry()
def test_signal_serialization() -> None:
    """Serializing a signal keeps its name and fields but drops class variables."""
    serialized = TestSignal(query='test').dict()
    # `name`/`input_type` are class variables and must not leak into the dict.
    expected = {'signal_name': 'test_signal', 'query': 'test'}
    assert serialized == expected
def test_get_signal_cls() -> None:
    """Looking up a registered signal by its name yields the signal class."""
    looked_up = get_signal_cls('test_signal')
    assert TestSignal == looked_up
def test_resolve_signal() -> None:
    """resolve_signal accepts either a signal instance or its serialized dict."""
    original = TestSignal(query='hello')

    # An instance is returned as-is (compared by equality).
    assert resolve_signal(original) == original

    # The `.dict()` form resolves to a signal equal to the original.
    serialized = original.dict()
    assert resolve_signal(serialized) == original
def test_get_signal_by_type() -> None:
    """A name plus its matching base type resolves to the registered class."""
    cases = [
        (TestTextSplitter, TextSplitterSignal),
        (TestTextEmbedding, TextEmbeddingSignal),
    ]
    for signal_cls, base_cls in cases:
        assert get_signal_by_type(signal_cls.name, base_cls) == signal_cls
def test_get_signal_by_type_validation() -> None:
    """Unknown names and base-type mismatches both raise ValueError."""
    missing_msg = 'Signal "invalid_signal" not found in the registry'
    with pytest.raises(ValueError, match=missing_msg):
        get_signal_by_type('invalid_signal', TextSplitterSignal)

    # Requesting a splitter as an embedding names the splitter's actual class.
    mismatch_msg = f'"{TestTextSplitter.name}" is a `{TestTextSplitter.__name__}`'
    with pytest.raises(ValueError, match=mismatch_msg):
        get_signal_by_type(TestTextSplitter.name, TextEmbeddingSignal)
def test_get_signals_by_type() -> None:
    """Each base type lists exactly the registered signals of that type."""
    splitters = get_signals_by_type(TextSplitterSignal)
    assert splitters == [TestTextSplitter]

    embeddings = get_signals_by_type(TextEmbeddingSignal)
    assert embeddings == [TestTextEmbedding]
def test_signal_type_enum() -> None:
    """The `embedding` property's schema enumerates the registered embeddings."""
    model_signal = TestTextEmbeddingModelSignal(embedding='test_embedding')
    embedding_schema = model_signal.schema()['properties']['embedding']
    assert embedding_schema['enum'] == [TestTextEmbedding.name]
class TestSignalNoDisplayName(Signal):
    # Signal with no `display_name`: `test_signal_title_schema` expects its schema
    # title to fall back to the class name. (A class docstring is deliberately
    # avoided here, since pydantic would copy it into the schema.)
    name = 'signal_no_name'
class TestSignalDisplayName(Signal):
    # Signal with an explicit `display_name`: `test_signal_title_schema` expects
    # its schema title to be this display name rather than the class name.
    name = 'signal_display_name'
    display_name = 'test display name'
def test_signal_title_schema() -> None:
    """Schema title is the display name when set, otherwise the class name."""
    fallback_title = TestSignalNoDisplayName.schema()['title']
    assert fallback_title == TestSignalNoDisplayName.__name__

    explicit_title = TestSignalDisplayName.schema()['title']
    assert explicit_title == 'test display name'
|