"""Tests for dataset.compute_signal() when signals are chained."""
import re
from typing import Iterable, List, Optional, cast

import numpy as np
import pytest
from pytest_mock import MockerFixture
from typing_extensions import override

from ..embeddings.vector_store import VectorStore
from ..schema import UUID_COLUMN, Field, Item, RichData, VectorKey, field, schema
from ..signals.signal import (
  TextEmbeddingModelSignal,
  TextEmbeddingSignal,
  TextSignal,
  TextSplitterSignal,
  clear_signal_registry,
  register_signal,
)
from .dataset import DatasetManifest
from .dataset_test_utils import (
  TEST_DATASET_NAME,
  TEST_NAMESPACE,
  TestDataMaker,
  enriched_embedding_span,
  enriched_embedding_span_field,
  enriched_item,
)
from .dataset_utils import lilac_embedding, lilac_span
# Three small rows covering each primitive dtype exercised by the dataset tests.
SIMPLE_ITEMS: list[Item] = [
  {UUID_COLUMN: '1', 'str': 'a', 'int': 1, 'bool': False, 'float': 3.0},
  {UUID_COLUMN: '2', 'str': 'b', 'int': 2, 'bool': True, 'float': 2.0},
  {UUID_COLUMN: '3', 'str': 'b', 'int': 2, 'bool': True, 'float': 1.0},
]
# (text, embedding) fixture pairs consumed by TestEmbedding below.
EMBEDDINGS: list[tuple[str, list[float]]] = [
  ('hello.', [1.0, 0.0, 0.0]),
  ('hello2.', [1.0, 1.0, 0.0]),
  ('hello world.', [1.0, 1.0, 1.0]),
  ('hello world2.', [2.0, 1.0, 1.0]),
]

# Text -> embedding lookup built from the pairs above.
STR_EMBEDDINGS: dict[str, list[float]] = dict(EMBEDDINGS)
class TestSplitter(TextSplitterSignal):
  """Split documents into sentences by splitting on period."""
  name = 'test_splitter'

  def compute(self, data: Iterable[RichData]) -> Iterable[Item]:
    """Yield one list of sentence spans per input document.

    Raises:
      ValueError: if an input is not a string.
    """
    for text in data:
      if not isinstance(text, str):
        raise ValueError(f'Expected text to be a string, got {type(text)} instead.')
      sentences = [f'{sentence.strip()}.' for sentence in text.split('.') if sentence]
      spans: list[Item] = []
      # Track a running search offset so a repeated sentence maps to its own
      # occurrence. The previous `text.index(sentence)` always matched the FIRST
      # occurrence, producing duplicate spans for inputs like 'a. a.'.
      offset = 0
      for sentence in sentences:
        start = text.index(sentence, offset)
        end = start + len(sentence)
        spans.append(lilac_span(start, end))
        offset = end
      yield spans
class TestEmbedding(TextEmbeddingSignal):
  """A test embed function."""
  name = 'test_embedding'

  def compute(self, data: Iterable[RichData]) -> Iterable[Item]:
    """Call the embedding function."""
    for example in data:
      text = cast(str, example)
      # Look up the canned vector and emit a single span covering the whole text.
      vector = np.array(STR_EMBEDDINGS[text])
      yield [lilac_embedding(0, len(text), vector)]
class TestEmbeddingSumSignal(TextEmbeddingModelSignal):
  """Sums the embeddings to return a single floating point value."""
  name = 'test_embedding_sum'

  def fields(self) -> Field:
    return field('float32')

  def vector_compute(self, keys: Iterable[VectorKey], vector_store: VectorStore) -> Iterable[Item]:
    # Reduce each embedding vector to the sum of its components.
    sums = vector_store.get(keys).sum(axis=1)
    yield from sums.tolist()
@pytest.fixture(scope='module', autouse=True)
def setup_teardown() -> Iterable[None]:
  """Register the test signals for this module and clear the registry afterwards.

  Without the `@pytest.fixture(autouse=True)` decoration pytest never invokes this
  generator, so the signals would not be registered for the tests below.
  """
  # Setup.
  register_signal(TestSplitter)
  register_signal(TestEmbedding)
  register_signal(TestEmbeddingSumSignal)
  register_signal(NamedEntity)

  # Unit test runs.
  yield

  # Teardown.
  clear_signal_registry()
def test_manual_embedding_signal(make_test_data: TestDataMaker, mocker: MockerFixture) -> None:
  """Computing the embedding explicitly first should not re-run it for the chained signal."""
  dataset = make_test_data([{UUID_COLUMN: '1', 'text': 'hello.'},
                            {UUID_COLUMN: '2', 'text': 'hello2.'}])
  compute_spy = mocker.spy(TestEmbedding, 'compute')

  embedding = TestEmbedding()
  dataset.compute_signal(embedding, 'text')
  sum_signal = TestEmbeddingSumSignal(embedding=TestEmbedding.name)
  dataset.compute_signal(sum_signal, 'text')

  # Make sure the embedding signal is not called twice.
  assert compute_spy.call_count == 1

  assert dataset.manifest() == DatasetManifest(
    namespace=TEST_NAMESPACE,
    dataset_name=TEST_DATASET_NAME,
    data_schema=schema({
      UUID_COLUMN: 'string',
      'text': field(
        'string',
        fields={
          'test_embedding': field(
            signal=embedding.dict(),
            fields=[
              enriched_embedding_span_field(
                {'test_embedding_sum': field('float32', sum_signal.dict())})
            ])
        }),
    }),
    num_items=2)

  # Each row carries one embedding span over the full text with the summed value.
  expected = [{
    UUID_COLUMN: uuid,
    'text': enriched_item(
      text,
      {'test_embedding': [enriched_embedding_span(0, len(text), {'test_embedding_sum': total})]})
  } for uuid, text, total in [('1', 'hello.', 1.0), ('2', 'hello2.', 2.0)]]
  assert list(dataset.select_rows()) == expected
def test_auto_embedding_signal(make_test_data: TestDataMaker, mocker: MockerFixture) -> None:
  """The chained signal computes its embedding dependency automatically, exactly once."""
  dataset = make_test_data([{UUID_COLUMN: '1', 'text': 'hello.'},
                            {UUID_COLUMN: '2', 'text': 'hello2.'}])
  compute_spy = mocker.spy(TestEmbedding, 'compute')

  # The embedding is automatically computed from the TestEmbeddingSumSignal.
  sum_signal = TestEmbeddingSumSignal(embedding=TestEmbedding.name)
  dataset.compute_signal(sum_signal, 'text')

  # Make sure the embedding signal is not called twice.
  assert compute_spy.call_count == 1

  assert dataset.manifest() == DatasetManifest(
    namespace=TEST_NAMESPACE,
    dataset_name=TEST_DATASET_NAME,
    data_schema=schema({
      UUID_COLUMN: 'string',
      'text': field(
        'string',
        fields={
          'test_embedding': field(
            signal=sum_signal._embedding_signal.dict(),
            fields=[
              enriched_embedding_span_field(
                {'test_embedding_sum': field('float32', sum_signal.dict())})
            ])
        }),
    }),
    num_items=2)

  # Each row carries one embedding span over the full text with the summed value.
  expected = [{
    UUID_COLUMN: uuid,
    'text': enriched_item(
      text,
      {'test_embedding': [enriched_embedding_span(0, len(text), {'test_embedding_sum': total})]})
  } for uuid, text, total in [('1', 'hello.', 1.0), ('2', 'hello2.', 2.0)]]
  assert list(dataset.select_rows()) == expected
# Matches simple email-like tokens, e.g. 'nik@test'.
ENTITY_REGEX = r'[A-Za-z]+@[A-Za-z]+'


class NamedEntity(TextSignal):
  """Find special entities."""
  name = 'entity'

  def fields(self) -> Field:
    return field(fields=['string_span'])

  def compute(self, data: Iterable[RichData]) -> Iterable[Optional[List[Item]]]:
    """Yield a list of entity spans per input, or None for non-string inputs."""
    for text in data:
      if not isinstance(text, str):
        # Non-text inputs produce no entity spans.
        yield None
        continue
      matches = re.finditer(ENTITY_REGEX, text)
      yield [lilac_span(match.start(0), match.end(0)) for match in matches]
def test_entity_on_split_signal(make_test_data: TestDataMaker) -> None:
  """Entities computed over splitter output are nested under each split span."""
  text = 'Hello nik@test. Here are some other entities like pii@gmail and all@lilac.'
  dataset = make_test_data([{UUID_COLUMN: '1', 'text': text}])

  dataset.compute_signal(TestSplitter(), 'text')
  # Run the entity signal over each split produced by the splitter.
  dataset.compute_signal(NamedEntity(), ('text', 'test_splitter', '*'))

  # First sentence holds 'nik@test'; second holds 'pii@gmail' and 'all@lilac'.
  first_sentence = lilac_span(0, 15, {'entity': [lilac_span(6, 14)]})
  second_sentence = lilac_span(16, 74, {'entity': [
    lilac_span(50, 59),
    lilac_span(64, 73),
  ]})
  result = dataset.select_rows(['text'])
  assert list(result) == [{
    UUID_COLUMN: '1',
    'text': enriched_item(text, {'test_splitter': [first_sentence, second_sentence]})
  }]