File size: 4,013 Bytes
e4f9cbe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
"""Test signal base class."""
from typing import Iterable, Optional

import pytest
from typing_extensions import override

from ..embeddings.vector_store import VectorStore
from ..schema import Field, Item, RichData, SignalInputType, VectorKey, field
from .signal import (
  Signal,
  TextEmbeddingModelSignal,
  TextEmbeddingSignal,
  TextSplitterSignal,
  clear_signal_registry,
  get_signal_by_type,
  get_signal_cls,
  get_signals_by_type,
  register_signal,
  resolve_signal,
)


class TestSignal(Signal):
  """A test signal."""

  # Minimal `Signal` subclass used by the tests in this module; it is registered
  # under the name 'test_signal' by the module-level fixture.

  # Pydantic fields
  name = 'test_signal'
  input_type = SignalInputType.TEXT

  # The only instance-level field; the tests supply it at construction time.
  query: str

  @override
  def fields(self) -> Field:
    """Return the field definition built by `field('float32')`."""
    return field('float32')

  @override
  def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]:
    """Ignore `data` and return an empty list; this stub produces no items."""
    # The input is deliberately discarded: the tests never run a real computation.
    del data
    return []


class TestTextSplitter(TextSplitterSignal):
  """A test text splitter."""
  # Stub splitter registered under 'test_splitter'; used by the by-type lookup tests.
  name = 'test_splitter'

  @override
  def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]:
    """Ignore `data` and return an empty list; this stub produces no items."""
    # The input is deliberately discarded: only registry lookups are exercised.
    del data
    return []


class TestTextEmbedding(TextEmbeddingSignal):
  """A test text embedding."""
  # Stub embedding registered under 'test_embedding'; the enum test expects this
  # name to appear in the schema of `TestTextEmbeddingModelSignal`.
  name = 'test_embedding'

  @override
  def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]:
    """Ignore `data` and return an empty list; this stub produces no items."""
    # The input is deliberately discarded: no embedding is actually computed.
    del data
    return []


class TestTextEmbeddingModelSignal(TextEmbeddingModelSignal):
  """A test text embedding model."""
  # Stub embedding-model signal registered under 'test_embedding_model'; the tests
  # construct it with `embedding='test_embedding'` and inspect its schema.
  name = 'test_embedding_model'

  @override
  def fields(self) -> Field:
    """Return the field definition built by `field('float32')`."""
    return field('float32')

  @override
  def vector_compute(self, keys: Iterable[VectorKey], vector_store: VectorStore) -> Iterable[Item]:
    """Ignore `keys` and `vector_store` and return an empty list."""
    # Stub: both inputs are discarded and no items are produced.
    del keys, vector_store
    return []


@pytest.fixture(scope='module', autouse=True)
def setup_teardown() -> Iterable[None]:
  """Register the test signals once for this module, then clear the registry."""
  test_signal_classes = (
    TestSignal,
    TestTextSplitter,
    TestTextEmbedding,
    TestTextEmbeddingModelSignal,
  )

  # Setup: register each test signal, in declaration order.
  for signal_cls in test_signal_classes:
    register_signal(signal_cls)

  # Hand control over to the tests of this module.
  yield

  # Teardown: empty the signal registry again.
  clear_signal_registry()


def test_signal_serialization() -> None:
  """Check the dict form of a signal: its name plus its declared fields."""
  serialized = TestSignal(query='test').dict()

  # Class-level attributes (e.g. `input_type`) must not leak into the output.
  expected = {'signal_name': 'test_signal', 'query': 'test'}
  assert serialized == expected


def test_get_signal_cls() -> None:
  """Look up a registered signal class by its registry name."""
  looked_up_cls = get_signal_cls('test_signal')
  assert TestSignal == looked_up_cls


def test_resolve_signal() -> None:
  """Resolve a signal both from an instance and from its dict form."""
  original = TestSignal(query='hello')

  # An already-constructed signal resolves to something equal to itself.
  from_instance = resolve_signal(original)
  assert from_instance == original

  # The serialized dict resolves back to a signal equal to the original.
  from_dict = resolve_signal(original.dict())
  assert from_dict == original


def test_get_signal_by_type() -> None:
  """Fetch registered signals by name, constrained to the expected base type."""
  splitter_cls = get_signal_by_type(TestTextSplitter.name, TextSplitterSignal)
  assert splitter_cls == TestTextSplitter

  embedding_cls = get_signal_by_type(TestTextEmbedding.name, TextEmbeddingSignal)
  assert embedding_cls == TestTextEmbedding


def test_get_signal_by_type_validation() -> None:
  """Check the errors raised for an unknown name and for a mismatched type."""
  # A name that was never registered.
  not_found_pattern = 'Signal "invalid_signal" not found in the registry'
  with pytest.raises(ValueError, match=not_found_pattern):
    get_signal_by_type('invalid_signal', TextSplitterSignal)

  # A registered splitter requested as an embedding signal.
  wrong_type_pattern = f'"{TestTextSplitter.name}" is a `{TestTextSplitter.__name__}`'
  with pytest.raises(ValueError, match=wrong_type_pattern):
    get_signal_by_type(TestTextSplitter.name, TextEmbeddingSignal)


def test_get_signals_by_type() -> None:
  """List the registered signal classes for each base type."""
  splitters = get_signals_by_type(TextSplitterSignal)
  assert splitters == [TestTextSplitter]

  embeddings = get_signals_by_type(TextEmbeddingSignal)
  assert embeddings == [TestTextEmbedding]


def test_signal_type_enum() -> None:
  """Check that the `embedding` schema property enumerates the registered embedding."""
  schema = TestTextEmbeddingModelSignal(embedding='test_embedding').schema()
  embedding_property = schema['properties']['embedding']

  # Only the test embedding is registered, so it is the sole enum value.
  assert embedding_property['enum'] == [TestTextEmbedding.name]


class TestSignalNoDisplayName(Signal):
  # Signal that sets no `display_name`; `test_signal_title_schema` expects its
  # schema title to be the class name.
  name = 'signal_no_name'


class TestSignalDisplayName(Signal):
  # Signal with an explicit `display_name`; `test_signal_title_schema` expects
  # that value to be used as the schema title.
  name = 'signal_display_name'
  display_name = 'test display name'


def test_signal_title_schema() -> None:
  """Check the schema title with and without an explicit `display_name`."""
  # Without a display name the title is the class name.
  default_title = TestSignalNoDisplayName.schema()['title']
  assert default_title == TestSignalNoDisplayName.__name__

  # With a display name the title is that display name.
  custom_title = TestSignalDisplayName.schema()['title']
  assert custom_title == 'test display name'