# File: nikhil_staging/src/data/dataset_compute_signal_chain_test.py
"""Tests for dataset.compute_signal() when signals are chained."""
import re
from typing import Iterable, List, Optional, cast
import numpy as np
import pytest
from pytest_mock import MockerFixture
from typing_extensions import override
from ..embeddings.vector_store import VectorStore
from ..schema import UUID_COLUMN, Field, Item, RichData, VectorKey, field, schema
from ..signals.signal import (
TextEmbeddingModelSignal,
TextEmbeddingSignal,
TextSignal,
TextSplitterSignal,
clear_signal_registry,
register_signal,
)
from .dataset import DatasetManifest
from .dataset_test_utils import (
TEST_DATASET_NAME,
TEST_NAMESPACE,
TestDataMaker,
enriched_embedding_span,
enriched_embedding_span_field,
enriched_item,
)
from .dataset_utils import lilac_embedding, lilac_span
# Rows covering several primitive dtypes (str/int/bool/float).
# NOTE(review): not referenced by any test visible in this file — presumably
# shared fixture data; confirm it is still needed before removing.
SIMPLE_ITEMS: list[Item] = [{
  UUID_COLUMN: '1',
  'str': 'a',
  'int': 1,
  'bool': False,
  'float': 3.0
}, {
  UUID_COLUMN: '2',
  'str': 'b',
  'int': 2,
  'bool': True,
  'float': 2.0
}, {
  UUID_COLUMN: '3',
  'str': 'b',
  'int': 2,
  'bool': True,
  'float': 1.0
}]
# (text, embedding) pairs; TestEmbedding returns these fixed vectors.
EMBEDDINGS: list[tuple[str, list[float]]] = [
  ('hello.', [1.0, 0.0, 0.0]),
  ('hello2.', [1.0, 1.0, 0.0]),
  ('hello world.', [1.0, 1.0, 1.0]),
  ('hello world2.', [2.0, 1.0, 1.0]),
]

# Text -> embedding lookup table derived from EMBEDDINGS.
STR_EMBEDDINGS: dict[str, list[float]] = dict(EMBEDDINGS)
class TestSplitter(TextSplitterSignal):
  """Splits a document into sentences by cutting on the period character."""
  name = 'test_splitter'

  @override
  def compute(self, data: Iterable[RichData]) -> Iterable[Item]:
    for doc in data:
      if not isinstance(doc, str):
        raise ValueError(f'Expected text to be a string, got {type(doc)} instead.')
      spans = []
      for part in doc.split('.'):
        if not part:
          continue
        sentence = f'{part.strip()}.'
        # Locate the (first occurrence of the) sentence to build its span.
        start = doc.index(sentence)
        spans.append(lilac_span(start, start + len(sentence)))
      yield spans
class TestEmbedding(TextEmbeddingSignal):
  """Returns a fixed, precomputed embedding for each known test string."""
  name = 'test_embedding'

  @override
  def compute(self, data: Iterable[RichData]) -> Iterable[Item]:
    """Look up the precomputed embedding for every input text."""
    for text in data:
      vector = np.array(STR_EMBEDDINGS[cast(str, text)])
      # One embedding span covering the whole input string.
      yield [lilac_embedding(0, len(text), vector)]
class TestEmbeddingSumSignal(TextEmbeddingModelSignal):
  """Sums each embedding vector into a single floating point value."""
  name = 'test_embedding_sum'

  @override
  def fields(self) -> Field:
    return field('float32')

  @override
  def vector_compute(self, keys: Iterable[VectorKey], vector_store: VectorStore) -> Iterable[Item]:
    """Yield, per key, the sum of the components of its embedding."""
    # Each row of the returned matrix is one embedding; sum across components.
    sums = vector_store.get(keys).sum(axis=1)
    yield from sums.tolist()
@pytest.fixture(scope='module', autouse=True)
def setup_teardown() -> Iterable[None]:
  """Register this module's test signals; clear the registry afterwards."""
  # Setup: registration order matches declaration order in this file.
  for signal_cls in (TestSplitter, TestEmbedding, TestEmbeddingSumSignal, NamedEntity):
    register_signal(signal_cls)

  yield  # Run the module's unit tests.

  # Teardown.
  clear_signal_registry()
def test_manual_embedding_signal(make_test_data: TestDataMaker, mocker: MockerFixture) -> None:
  """Computing the embedding explicitly, then the chained sum signal, reuses the embedding."""
  dataset = make_test_data([{
    UUID_COLUMN: '1',
    'text': 'hello.',
  }, {
    UUID_COLUMN: '2',
    'text': 'hello2.',
  }])
  embed_mock = mocker.spy(TestEmbedding, 'compute')
  # First compute the embedding manually.
  embedding_signal = TestEmbedding()
  dataset.compute_signal(embedding_signal, 'text')
  # Then compute the model signal that depends on that embedding.
  embedding_sum_signal = TestEmbeddingSumSignal(embedding=TestEmbedding.name)
  dataset.compute_signal(embedding_sum_signal, 'text')
  # Make sure the embedding signal is not called twice.
  assert embed_mock.call_count == 1
  # The sum signal's output is nested under the embedding field it chains from.
  assert dataset.manifest() == DatasetManifest(
    namespace=TEST_NAMESPACE,
    dataset_name=TEST_DATASET_NAME,
    data_schema=schema({
      UUID_COLUMN: 'string',
      'text': field(
        'string',
        fields={
          'test_embedding': field(
            signal=embedding_signal.dict(),
            fields=[
              enriched_embedding_span_field(
                {'test_embedding_sum': field('float32', embedding_sum_signal.dict())})
            ])
        }),
    }),
    num_items=2)
  result = dataset.select_rows()
  # Sums follow STR_EMBEDDINGS: [1,0,0] -> 1.0 and [1,1,0] -> 2.0.
  expected_result = [{
    UUID_COLUMN: '1',
    'text': enriched_item(
      'hello.', {'test_embedding': [enriched_embedding_span(0, 6, {'test_embedding_sum': 1.0})]})
  }, {
    UUID_COLUMN: '2',
    'text': enriched_item(
      'hello2.', {'test_embedding': [enriched_embedding_span(0, 7, {'test_embedding_sum': 2.0})]})
  }]
  assert list(result) == expected_result
def test_auto_embedding_signal(make_test_data: TestDataMaker, mocker: MockerFixture) -> None:
  """Computing only the sum signal also computes its upstream embedding, exactly once."""
  dataset = make_test_data([{
    UUID_COLUMN: '1',
    'text': 'hello.',
  }, {
    UUID_COLUMN: '2',
    'text': 'hello2.',
  }])
  embed_mock = mocker.spy(TestEmbedding, 'compute')
  # The embedding is automatically computed from the TestEmbeddingSumSignal.
  embedding_sum_signal = TestEmbeddingSumSignal(embedding=TestEmbedding.name)
  dataset.compute_signal(embedding_sum_signal, 'text')
  # Make sure the embedding signal is not called twice.
  assert embed_mock.call_count == 1
  # The manifest records the implicitly-computed embedding signal as well.
  assert dataset.manifest() == DatasetManifest(
    namespace=TEST_NAMESPACE,
    dataset_name=TEST_DATASET_NAME,
    data_schema=schema({
      UUID_COLUMN: 'string',
      'text': field(
        'string',
        fields={
          'test_embedding': field(
            signal=embedding_sum_signal._embedding_signal.dict(),
            fields=[
              enriched_embedding_span_field(
                {'test_embedding_sum': field('float32', embedding_sum_signal.dict())})
            ])
        }),
    }),
    num_items=2)
  result = dataset.select_rows()
  # Same expected rows as the manual-embedding test above.
  expected_result = [{
    UUID_COLUMN: '1',
    'text': enriched_item(
      'hello.', {'test_embedding': [enriched_embedding_span(0, 6, {'test_embedding_sum': 1.0})]})
  }, {
    UUID_COLUMN: '2',
    'text': enriched_item(
      'hello2.', {'test_embedding': [enriched_embedding_span(0, 7, {'test_embedding_sum': 2.0})]})
  }]
  assert list(result) == expected_result
# Matches simple email-like tokens such as 'nik@test': letters, '@', letters.
ENTITY_REGEX = r'[A-Za-z]+@[A-Za-z]+'
class NamedEntity(TextSignal):
  """Finds spans of text matching ENTITY_REGEX."""
  name = 'entity'

  @override
  def fields(self) -> Field:
    return field(fields=['string_span'])

  @override
  def compute(self, data: Iterable[RichData]) -> Iterable[Optional[List[Item]]]:
    """Yield one list of entity spans per input, or None for non-string input."""
    for text in data:
      if not isinstance(text, str):
        yield None
        continue
      matches = re.finditer(ENTITY_REGEX, text)
      yield [lilac_span(match.start(0), match.end(0)) for match in matches]
def test_entity_on_split_signal(make_test_data: TestDataMaker) -> None:
  """An entity signal can be computed over the spans produced by a splitter signal."""
  text = 'Hello nik@test. Here are some other entities like pii@gmail and all@lilac.'
  dataset = make_test_data([{UUID_COLUMN: '1', 'text': text}])
  entity = NamedEntity()
  # Split first, then compute entities on every split span via the '*' wildcard.
  dataset.compute_signal(TestSplitter(), 'text')
  dataset.compute_signal(entity, ('text', 'test_splitter', '*'))
  result = dataset.select_rows(['text'])
  assert list(result) == [{
    UUID_COLUMN: '1',
    'text': enriched_item(
      text, {
        'test_splitter': [
          # Sentence 'Hello nik@test.' with entity 'nik@test' at [6, 14).
          lilac_span(0, 15, {'entity': [lilac_span(6, 14)]}),
          # Second sentence contains 'pii@gmail' and 'all@lilac'.
          lilac_span(16, 74, {'entity': [
            lilac_span(50, 59),
            lilac_span(64, 73),
          ]}),
        ]
      })
  }]