nikhil_staging / src /data /dataset_select_rows_udf_test.py
nsthorat's picture
Push
e4f9cbe
raw
history blame
No virus
11.8 kB
"""Tests for dataset.select_rows(udf_col)."""
from typing import Iterable, Optional, cast
import numpy as np
import pytest
from typing_extensions import override
from ..embeddings.vector_store import VectorStore
from ..schema import UUID_COLUMN, VALUE_KEY, Field, Item, RichData, VectorKey, field
from ..signals.signal import (
TextEmbeddingModelSignal,
TextEmbeddingSignal,
TextSignal,
TextSplitterSignal,
clear_signal_registry,
register_signal,
)
from .dataset import BinaryFilterTuple, BinaryOp, Column, val
from .dataset_test_utils import TestDataMaker, enriched_item
from .dataset_utils import lilac_embedding, lilac_span
# Fixed (text, embedding vector) pairs that serve as the embedding lookup data for this module.
EMBEDDINGS: list[tuple[str, list[float]]] = [
    ('hello.', [1.0, 0.0, 0.0]),
    ('hello2.', [1.0, 1.0, 0.0]),
    ('hello world.', [1.0, 1.0, 1.0]),
    ('hello world2.', [2.0, 1.0, 1.0]),
]

# The same pairs keyed by text, so a text's vector can be looked up directly.
STR_EMBEDDINGS: dict[str, list[float]] = dict(EMBEDDINGS)
class TestEmbedding(TextEmbeddingSignal):
    """Embedding signal for tests that looks each text up in `STR_EMBEDDINGS`."""
    name = 'test_embedding'

    @override
    def compute(self, data: Iterable[RichData]) -> Iterable[Item]:
        """Yield, per input, a one-element list holding its embedding item.

        Each item is built with `lilac_embedding` using start 0, end `len(text)` and the
        vector stored for that text. A text missing from `STR_EMBEDDINGS` raises `KeyError`.
        """
        for text in data:
            span_end = len(text)
            vector = np.array(STR_EMBEDDINGS[cast(str, text)])
            yield [lilac_embedding(0, span_end, vector)]
class LengthSignal(TextSignal):
    """Test signal that emits the length of each input and counts the inputs it has seen."""
    name = 'length_signal'
    # Bumped once per input handled by `compute`; tests read it to check how much work ran.
    _call_count: int = 0

    def fields(self) -> Field:
        """Declare the output as a single 'int32' field."""
        return field('int32')

    def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]:
        """Lazily yield `len(value)` for every input, incrementing `_call_count` first."""
        for value in data:
            self._call_count += 1
            size = len(value)
            yield size
class TestSignal(TextSignal):
    """Test signal that reports each input's length as both an int and a float."""
    name = 'test_signal'

    @override
    def fields(self) -> Field:
        """Declare a struct output with an int32 'len' and a float32 'flen'."""
        return field(fields={'len': 'int32', 'flen': 'float32'})

    @override
    def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]:
        """Return a fully built list with one `{'len', 'flen'}` dict per input."""
        outputs: list[Optional[Item]] = []
        for text_content in data:
            length = len(text_content)
            outputs.append({'len': length, 'flen': float(length)})
        return outputs
class TestEmbeddingSumSignal(TextEmbeddingModelSignal):
    """Reduce each stored embedding to one float by summing its values."""
    name = 'test_embedding_sum'

    @override
    def fields(self) -> Field:
        """Declare the output as a single 'float32' field."""
        return field('float32')

    @override
    def vector_compute(self, keys: Iterable[VectorKey], vector_store: VectorStore) -> Iterable[Item]:
        """Yield the axis-1 sum of whatever `vector_store.get(keys)` returns, one value at a time.

        NOTE(review): presumably the store returns one row per key -- confirm against VectorStore.
        """
        vectors = vector_store.get(keys)
        yield from vectors.sum(axis=1).tolist()
class ComputedKeySignal(TextSignal):
    """Test signal whose key embeds the `is_computed_signal` flag it was asked with."""
    name = 'computed_key'

    @override
    def fields(self) -> Field:
        """Declare the output as a single 'int64' field."""
        return field('int64')

    @override
    def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]:
        """Yield the constant 1 for every input; the input value itself is ignored."""
        for _ in data:
            yield 1

    def key(self, is_computed_signal: Optional[bool] = False) -> str:
        """Return 'key_' followed by the string form of `is_computed_signal`."""
        return f'key_{is_computed_signal}'
@pytest.fixture(scope='module', autouse=True)
def setup_teardown() -> Iterable[None]:
    """Register this module's test signals before its tests and clear the registry afterwards."""
    # Setup: registration order matches the original one-call-per-signal sequence.
    module_signals = (
        LengthSignal,
        TestSplitter,
        TestEmbedding,
        TestSignal,
        TestEmbeddingSumSignal,
        ComputedKeySignal,
    )
    for signal_cls in module_signals:
        register_signal(signal_cls)

    # Hand control to the tests.
    yield

    # Teardown.
    clear_signal_registry()
def test_udf(make_test_data: TestDataMaker) -> None:
    """Selecting a `TestSignal` UDF column returns its struct output next to the source text."""
    rows = [{UUID_COLUMN: '1', 'text': 'hello'}, {UUID_COLUMN: '2', 'text': 'everybody'}]
    dataset = make_test_data(rows)

    udf_col = Column('text', signal_udf=TestSignal())
    selected = list(dataset.select_rows(['text', udf_col]))

    expected = [
        {UUID_COLUMN: '1', 'text': 'hello', 'test_signal(text)': {'len': 5, 'flen': 5.0}},
        {UUID_COLUMN: '2', 'text': 'everybody', 'test_signal(text)': {'len': 9, 'flen': 9.0}},
    ]
    assert selected == expected
def test_udf_with_filters(make_test_data: TestDataMaker) -> None:
    """A filter on the source column restricts the rows returned with the UDF column."""
    dataset = make_test_data([
        {UUID_COLUMN: '1', 'text': 'hello'},
        {UUID_COLUMN: '2', 'text': 'everybody'},
    ])
    udf_col = Column('text', signal_udf=TestSignal())

    # Keep only the row whose source text equals 'everybody'.
    only_everybody: list[BinaryFilterTuple] = [('text', BinaryOp.EQUALS, 'everybody')]
    selected = dataset.select_rows(['text', udf_col], filters=only_everybody)

    expected = [
        {UUID_COLUMN: '2', 'text': 'everybody', 'test_signal(text)': {'len': 9, 'flen': 9.0}},
    ]
    assert list(selected) == expected
def test_udf_with_uuid_filter(make_test_data: TestDataMaker) -> None:
    """A UUID filter limits both the rows returned and how many inputs the UDF processes."""
    dataset = make_test_data([
        {UUID_COLUMN: '1', 'text': 'hello'},
        {UUID_COLUMN: '2', 'text': 'everybody'},
    ])
    udf_col = Column('text', signal_udf=LengthSignal())

    def calls_so_far() -> int:
        # Read the signal off the column on every check rather than caching the instance.
        return cast(LengthSignal, udf_col.signal_udf)._call_count

    row_1 = {UUID_COLUMN: '1', 'text': 'hello', 'length_signal(text)': 5}
    row_2 = {UUID_COLUMN: '2', 'text': 'everybody', 'length_signal(text)': 9}

    # Select only UUID '1': a single input reaches the UDF.
    uuid_filter: list[BinaryFilterTuple] = [(UUID_COLUMN, BinaryOp.EQUALS, '1')]
    assert list(dataset.select_rows(['text', udf_col], filters=uuid_filter)) == [row_1]
    assert calls_so_far() == 1

    # Select only UUID '2': one more input, two in total.
    uuid_filter = [(UUID_COLUMN, BinaryOp.EQUALS, '2')]
    assert list(dataset.select_rows(['text', udf_col], filters=uuid_filter)) == [row_2]
    assert calls_so_far() == 2

    # No filters: both rows are processed again, four in total.
    assert list(dataset.select_rows(['text', udf_col])) == [row_1, row_2]
    assert calls_so_far() == 4
def test_udf_with_uuid_filter_repeated(make_test_data: TestDataMaker) -> None:
    """With a repeated text field, a UUID filter limits the UDF to that row's elements."""
    texts_1 = ['hello', 'hi']
    texts_2 = ['everybody', 'bye', 'test']
    dataset = make_test_data([
        {UUID_COLUMN: '1', 'text': texts_1},
        {UUID_COLUMN: '2', 'text': texts_2},
    ])
    udf_col = Column(('text', '*'), signal_udf=LengthSignal())

    def calls_so_far() -> int:
        # Read the signal off the column on every check rather than caching the instance.
        return cast(LengthSignal, udf_col.signal_udf)._call_count

    # Select only UUID '1': its two elements are the only inputs to the UDF.
    uuid_filter: list[BinaryFilterTuple] = [(UUID_COLUMN, BinaryOp.EQUALS, '1')]
    selected = dataset.select_rows(['text', udf_col], filters=uuid_filter)
    assert list(selected) == [
        {UUID_COLUMN: '1', 'text': ['hello', 'hi'], 'length_signal(text)': [5, 2]},
    ]
    assert calls_so_far() == 2

    # Select only UUID '2': three more elements, five in total.
    uuid_filter = [(UUID_COLUMN, BinaryOp.EQUALS, '2')]
    selected = dataset.select_rows(['text', udf_col], filters=uuid_filter)
    assert list(selected) == [
        {UUID_COLUMN: '2', 'text': ['everybody', 'bye', 'test'], 'length_signal(text)': [9, 3, 4]},
    ]
    assert calls_so_far() == 5
def test_udf_deeply_nested(make_test_data: TestDataMaker) -> None:
    """A UDF over a doubly repeated path keeps the nesting and runs once per leaf string."""
    dataset = make_test_data([
        {UUID_COLUMN: '1', 'text': [['hello'], ['hi', 'bye']]},
        {UUID_COLUMN: '2', 'text': [['everybody', 'bye'], ['test']]},
    ])
    udf_col = Column(('text', '*', '*'), signal_udf=LengthSignal())

    selected = list(dataset.select_rows([udf_col]))

    expected = [
        {UUID_COLUMN: '1', 'length_signal(text.*)': [[5], [2, 3]]},
        {UUID_COLUMN: '2', 'length_signal(text.*)': [[9, 3], [4]]},
    ]
    assert selected == expected
    # Six leaf strings across both rows.
    assert cast(LengthSignal, udf_col.signal_udf)._call_count == 6
def test_udf_with_embedding(make_test_data: TestDataMaker) -> None:
    """An embedding-based UDF reads a precomputed embedding, with and without an alias."""
    dataset = make_test_data([
        {UUID_COLUMN: '1', 'text': 'hello.'},
        {UUID_COLUMN: '2', 'text': 'hello2.'},
    ])
    dataset.compute_signal(TestEmbedding(), 'text')

    def expected_rows(udf_key: str) -> list[Item]:
        # Both selections return the same values; only the UDF column's name differs.
        return [
            {UUID_COLUMN: '1', f'text.{VALUE_KEY}': 'hello.', udf_key: [1.0]},
            {UUID_COLUMN: '2', f'text.{VALUE_KEY}': 'hello2.', udf_key: [2.0]},
        ]

    # Without an alias the column is named after the signal and the embedding path.
    udf_col = Column('text', signal_udf=TestEmbeddingSumSignal(embedding='test_embedding'))
    selected = dataset.select_rows([val('text'), udf_col])
    default_key = 'test_embedding_sum(text.test_embedding.*.embedding)'
    assert list(selected) == expected_rows(default_key)

    # Select rows with alias.
    aliased_col = Column(
        'text',
        signal_udf=TestEmbeddingSumSignal(embedding='test_embedding'),
        alias='emb_sum',
    )
    selected = dataset.select_rows([val('text'), aliased_col])
    assert list(selected) == expected_rows('emb_sum')
def test_udf_with_nested_embedding(make_test_data: TestDataMaker) -> None:
    """An embedding-based UDF over a repeated text field returns one result per element."""
    dataset = make_test_data([
        {UUID_COLUMN: '1', 'text': ['hello.', 'hello world.']},
        {UUID_COLUMN: '2', 'text': ['hello world2.', 'hello2.']},
    ])
    repeated_text = ('text', '*')
    dataset.compute_signal(TestEmbedding(), repeated_text)

    udf_col = Column(repeated_text, signal_udf=TestEmbeddingSumSignal(embedding='test_embedding'))
    selected = list(dataset.select_rows([val(repeated_text), udf_col]))

    text_key = f'text.*.{VALUE_KEY}'
    udf_key = 'test_embedding_sum(text.*.test_embedding.*.embedding)'
    expected = [
        {UUID_COLUMN: '1', text_key: ['hello.', 'hello world.'], udf_key: [[1.0], [3.0]]},
        {UUID_COLUMN: '2', text_key: ['hello world2.', 'hello2.'], udf_key: [[4.0], [2.0]]},
    ]
    assert selected == expected
def test_udf_throws_without_precomputing(make_test_data: TestDataMaker) -> None:
    """Selecting an embedding-based UDF raises when its embedding was never computed."""
    dataset = make_test_data([
        {UUID_COLUMN: '1', 'text': 'hello.'},
        {UUID_COLUMN: '2', 'text': 'hello2.'},
    ])

    # `compute_signal` is deliberately not called, so 'test_embedding' does not exist.
    udf_col = Column('text', signal_udf=TestEmbeddingSumSignal(embedding='test_embedding'))
    columns = [val('text'), udf_col]

    expected_message = 'Embedding signal "test_embedding" is not computed'
    with pytest.raises(ValueError, match=expected_message):
        dataset.select_rows(columns)
class TestSplitter(TextSplitterSignal):
    """Split documents into sentences by splitting on period."""
    name = 'test_splitter'

    @override
    def compute(self, data: Iterable[RichData]) -> Iterable[Item]:
        """Yield, per input text, the list of spans covering each period-separated piece.

        Every piece returned by `text.split('.')` produces one span, including empty pieces
        (for example the one after a trailing period). The separators are not part of a span.

        Raises:
          ValueError: If an input is not a `str`.
        """
        for text in data:
            if not isinstance(text, str):
                raise ValueError(f'Expected text to be a string, got {type(text)} instead.')
            spans: list[Item] = []
            # Track the offset while walking the pieces. Looking each piece up with
            # `text.index(piece)` returns its FIRST occurrence, which gives wrong offsets for
            # repeated sentences and for empty pieces.
            start = 0
            for sentence in text.split('.'):
                end = start + len(sentence)
                spans.append(lilac_span(start, end))
                # Step over the '.' that separated this piece from the next one.
                start = end + 1
            yield spans
def test_udf_after_precomputed_split(make_test_data: TestDataMaker) -> None:
    """With combined columns, a UDF result is merged under 'text' next to a precomputed split."""
    text_1 = 'sentence 1. sentence 2 is longer'
    text_2 = 'sentence 1 is longer. sent2 is short'
    dataset = make_test_data([
        {UUID_COLUMN: '1', 'text': text_1},
        {UUID_COLUMN: '2', 'text': text_2},
    ])
    dataset.compute_signal(TestSplitter(), 'text')

    udf = Column('text', signal_udf=LengthSignal())
    selected = list(dataset.select_rows(['*', udf], combine_columns=True))

    enriched_1 = enriched_item('sentence 1. sentence 2 is longer', {
        'length_signal': 32,
        'test_splitter': [lilac_span(0, 10), lilac_span(11, 32)],
    })
    enriched_2 = enriched_item('sentence 1 is longer. sent2 is short', {
        'length_signal': 36,
        'test_splitter': [lilac_span(0, 20), lilac_span(21, 36)],
    })
    assert selected == [
        {UUID_COLUMN: '1', 'text': enriched_1},
        {UUID_COLUMN: '2', 'text': enriched_2},
    ]
def test_is_computed_signal_key(make_test_data: TestDataMaker) -> None:
    """A UDF column is named from the signal's `key()` output, here 'key_False'.

    `ComputedKeySignal.key` embeds its `is_computed_signal` argument in the returned string,
    so the asserted column name shows which value `select_rows` passed for a UDF.
    """
    dataset = make_test_data([{
        UUID_COLUMN: '1',
        'text': 'hello.',
    }, {
        UUID_COLUMN: '2',
        'text': 'hello2.',
    }])

    signal_col = Column('text', signal_udf=ComputedKeySignal())

    # No filters are applied: every row is expected back with the constant UDF value.
    result = dataset.select_rows(['text', signal_col])
    assert list(result) == [{
        UUID_COLUMN: '1',
        'text': 'hello.',
        'key_False(text)': 1
    }, {
        UUID_COLUMN: '2',
        'text': 'hello2.',
        'key_False(text)': 1
    }]