Spaces:
Runtime error
Runtime error
"""Tests for dataset.select_rows(udf_col).""" | |
from typing import Iterable, Optional, cast | |
import numpy as np | |
import pytest | |
from typing_extensions import override | |
from ..embeddings.vector_store import VectorDBIndex | |
from ..schema import ( | |
ROWID, | |
Field, | |
Item, | |
RichData, | |
SignalInputType, | |
VectorKey, | |
field, | |
lilac_embedding, | |
lilac_span, | |
) | |
from ..signal import ( | |
TextEmbeddingSignal, | |
TextSignal, | |
TextSplitterSignal, | |
VectorSignal, | |
clear_signal_registry, | |
register_signal, | |
) | |
from .dataset import BinaryFilterTuple, Column | |
from .dataset_test_utils import TestDataMaker, enriched_item | |
# Fixed (text, embedding) pairs served by the test embedding signal.
EMBEDDINGS: list[tuple[str, list[float]]] = [
    ('hello.', [1.0, 0.0, 0.0]),
    ('hello2.', [1.0, 1.0, 0.0]),
    ('hello world.', [1.0, 1.0, 1.0]),
    ('hello world2.', [2.0, 1.0, 1.0]),
]
# Text -> embedding lookup table. `dict()` over the pairs replaces the identity dict
# comprehension (ruff C416); keys, order and value objects are the same.
STR_EMBEDDINGS: dict[str, list[float]] = dict(EMBEDDINGS)
class TestEmbedding(TextEmbeddingSignal):
    """A test embed function.

    Looks each input text up in ``STR_EMBEDDINGS`` and emits a single embedding whose
    span covers the whole text.
    """

    name = 'test_embedding'

    @override
    def compute(self, data: Iterable[RichData]) -> Iterable[Item]:
        """Call the embedding function.

        Raises:
          KeyError: if an input text is not a key of ``STR_EMBEDDINGS``.
        """
        for example in data:
            text = cast(str, example)
            yield [lilac_embedding(0, len(text), np.array(STR_EMBEDDINGS[text]))]
class LengthSignal(TextSignal):
    """Test signal that returns the length of each text and counts how many it computed."""

    name = 'length_signal'

    # Class-level default of 0. The `+=` in `compute` creates a per-instance counter on first
    # use, so each instance counts only its own items. Tests read it to check how many texts
    # were actually computed.
    _call_count: int = 0

    @override
    def fields(self) -> Field:
        return field('int32')

    @override
    def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]:
        # Generator: the counter only advances as the caller consumes results.
        for text_content in data:
            self._call_count += 1
            yield len(text_content)
class TestSignal(TextSignal):
    """Test signal emitting each text's length as an int32 field and a float32 field."""

    name = 'test_signal'

    @override
    def fields(self) -> Field:
        return field(fields={'len': 'int32', 'flen': 'float32'})

    @override
    def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]:
        # Unlike the generator-based signals in this file, this one returns an eager list.
        return [{'len': len(text_content), 'flen': float(len(text_content))} for text_content in data]
class TestEmbeddingSumSignal(VectorSignal):
    """Sums the embeddings to return a single floating point value."""

    name = 'test_embedding_sum'
    input_type = SignalInputType.TEXT

    @override
    def fields(self) -> Field:
        return field('float32')

    @override
    def vector_compute(self, keys: Iterable[VectorKey],
                       vector_index: VectorDBIndex) -> Iterable[Item]:
        # The signal just sums the values of the embedding. Only the first vector span of
        # each key is used; TestEmbedding emits exactly one span per text.
        all_vector_spans = vector_index.get(keys)
        for vector_spans in all_vector_spans:
            yield vector_spans[0]['vector'].sum()
class ComputedKeySignal(TextSignal):
    """Test signal whose key embeds the ``is_computed_signal`` argument.

    Every input yields the constant 1; tests only inspect the output column name.
    """

    name = 'computed_key'

    @override
    def fields(self) -> Field:
        return field('int64')

    @override
    def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]:
        # The input value is irrelevant; one constant is emitted per item.
        for _ in data:
            yield 1

    @override
    def key(self, is_computed_signal: Optional[bool] = False) -> str:
        # Put the flag in the key (e.g. 'key_False') so the caller's argument is observable.
        return f'key_{is_computed_signal}'
@pytest.fixture(scope='function', autouse=True)
def setup_teardown() -> Iterable[None]:
    """Register the test signals before each test and clear the registry afterwards.

    This must be an ``autouse`` fixture. The function is not named ``test_*`` and no test
    requests it, so without the decorator pytest never runs it: nothing would be
    registered and the registry would never be cleared.
    """
    try:
        # Setup.
        for signal_cls in (LengthSignal, TestSplitter, TestEmbedding, TestSignal,
                           TestEmbeddingSumSignal, ComputedKeySignal):
            register_signal(signal_cls)
        # Unit test runs.
        yield
    finally:
        # Teardown. Also runs if a registration above raised, so a partially filled
        # registry does not leak into later tests.
        clear_signal_registry()
def test_udf(make_test_data: TestDataMaker) -> None:
    """Selecting a UDF column returns the signal output next to its source column."""
    ds = make_test_data([{'text': 'hello'}, {'text': 'everybody'}])
    udf = Column('text', signal_udf=TestSignal())

    rows = list(ds.select_rows(['text', udf]))

    expected = [
        {'text': 'hello', 'text.test_signal': {'len': 5, 'flen': 5.0}},
        {'text': 'everybody', 'text.test_signal': {'len': 9, 'flen': 9.0}},
    ]
    assert rows == expected
def test_udf_with_filters(make_test_data: TestDataMaker) -> None:
    """A filter on the source column also restricts the rows returned with the UDF."""
    ds = make_test_data([{'text': 'hello'}, {'text': 'everybody'}])
    udf = Column('text', signal_udf=TestSignal())

    # Filter by source feature.
    only_everybody: list[BinaryFilterTuple] = [('text', 'equals', 'everybody')]
    rows = list(ds.select_rows(['text', udf], filters=only_everybody))

    assert rows == [{'text': 'everybody', 'text.test_signal': {'len': 9, 'flen': 9.0}}]
def test_udf_with_rowid_filter(make_test_data: TestDataMaker) -> None:
    """Rowid filters restrict which rows the UDF is computed for.

    The LengthSignal call counter is cumulative because the same UDF column is reused
    across all three queries.
    """
    ds = make_test_data([{'text': 'hello'}, {'text': 'everybody'}])
    udf = Column('text', signal_udf=LengthSignal())

    def call_count() -> int:
        # Read through the column on every check rather than caching the signal locally.
        return cast(LengthSignal, udf.signal_udf)._call_count

    # Filter by a specific rowid.
    row_one: list[BinaryFilterTuple] = [(ROWID, 'equals', '1')]
    rows = list(ds.select_rows([ROWID, 'text', udf], filters=row_one))
    assert rows == [{ROWID: '1', 'text': 'hello', 'text.length_signal': 5}]
    assert call_count() == 1

    row_two: list[BinaryFilterTuple] = [(ROWID, 'equals', '2')]
    rows = list(ds.select_rows([ROWID, 'text', udf], filters=row_two))
    assert rows == [{ROWID: '2', 'text': 'everybody', 'text.length_signal': 9}]
    assert call_count() == 1 + 1

    # No filters: both rows are computed again.
    rows = list(ds.select_rows([ROWID, 'text', udf]))
    assert rows == [
        {ROWID: '1', 'text': 'hello', 'text.length_signal': 5},
        {ROWID: '2', 'text': 'everybody', 'text.length_signal': 9},
    ]
    assert call_count() == 2 + 2
def test_udf_with_rowid_filter_repeated(make_test_data: TestDataMaker) -> None:
    """Rowid filters on a repeated field compute the UDF once per element of the kept row."""
    ds = make_test_data([{'text': ['hello', 'hi']}, {'text': ['everybody', 'bye', 'test']}])
    udf = Column(('text', '*'), signal_udf=LengthSignal())

    def call_count() -> int:
        return cast(LengthSignal, udf.signal_udf)._call_count

    # Filter by a specific rowid.
    row_one: list[BinaryFilterTuple] = [(ROWID, 'equals', '1')]
    rows = list(ds.select_rows([ROWID, 'text', udf], filters=row_one))
    assert rows == [{ROWID: '1', 'text': ['hello', 'hi'], 'text.length_signal': [5, 2]}]
    assert call_count() == 2

    # Filter by a specific rowid.
    row_two: list[BinaryFilterTuple] = [(ROWID, 'equals', '2')]
    rows = list(ds.select_rows([ROWID, 'text', udf], filters=row_two))
    assert rows == [{
        ROWID: '2',
        'text': ['everybody', 'bye', 'test'],
        'text.length_signal': [9, 3, 4],
    }]
    # Cumulative: 2 elements from the first query plus 3 from this one.
    assert call_count() == 2 + 3
def test_udf_deeply_nested(make_test_data: TestDataMaker) -> None:
    """A UDF on a doubly repeated path preserves the nested list structure of its input."""
    ds = make_test_data([
        {'text': [['hello'], ['hi', 'bye']]},
        {'text': [['everybody', 'bye'], ['test']]},
    ])
    udf = Column(('text', '*', '*'), signal_udf=LengthSignal())

    rows = list(ds.select_rows([udf]))

    assert rows == [
        {'text.length_signal': [[5], [2, 3]]},
        {'text.length_signal': [[9, 3], [4]]},
    ]
    # One call per leaf string across both rows (1 + 2 + 2 + 1).
    assert cast(LengthSignal, udf.signal_udf)._call_count == 6
def test_udf_with_embedding(make_test_data: TestDataMaker) -> None:
    """A vector UDF reads a precomputed embedding; an alias renames its output column."""
    ds = make_test_data([{'text': 'hello.'}, {'text': 'hello2.'}])
    ds.compute_signal(TestEmbedding(), 'text')

    summed = Column('text', signal_udf=TestEmbeddingSumSignal(embedding='test_embedding'))
    rows = list(ds.select_rows(['text', summed]))
    expected: list[Item] = [
        {'text': 'hello.', 'text.test_embedding_sum(embedding=test_embedding)': 1.0},
        {'text': 'hello2.', 'text.test_embedding_sum(embedding=test_embedding)': 2.0},
    ]
    assert rows == expected

    # Select rows with alias.
    aliased = Column(
        'text', signal_udf=TestEmbeddingSumSignal(embedding='test_embedding'), alias='emb_sum')
    rows = list(ds.select_rows(['text', aliased]))
    expected = [{'text': 'hello.', 'emb_sum': 1.0}, {'text': 'hello2.', 'emb_sum': 2.0}]
    assert rows == expected
def test_udf_with_nested_embedding(make_test_data: TestDataMaker) -> None:
    """A vector UDF over a repeated text field returns one sum per list element."""
    ds = make_test_data([
        {'text': ['hello.', 'hello world.']},
        {'text': ['hello world2.', 'hello2.']},
    ])
    ds.compute_signal(TestEmbedding(), ('text', '*'))

    summed = Column(('text', '*'), signal_udf=TestEmbeddingSumSignal(embedding='test_embedding'))
    rows = list(ds.select_rows([('text', '*'), summed]))

    expected: list[Item] = [
        {
            'text.*': ['hello.', 'hello world.'],
            'text.test_embedding_sum(embedding=test_embedding)': [1.0, 3.0],
        },
        {
            'text.*': ['hello world2.', 'hello2.'],
            'text.test_embedding_sum(embedding=test_embedding)': [4.0, 2.0],
        },
    ]
    assert rows == expected
def test_udf_throws_without_precomputing(make_test_data: TestDataMaker) -> None:
    """Selecting a vector UDF before its embedding is computed raises ValueError."""
    ds = make_test_data([{'text': 'hello.'}, {'text': 'hello2.'}])

    # Embedding is not precomputed, yet we ask for the embedding.
    summed = Column('text', signal_udf=TestEmbeddingSumSignal(embedding='test_embedding'))
    with pytest.raises(ValueError, match=r"No embedding found for path \('text',\)"):
        ds.select_rows(['text', summed])
class TestSplitter(TextSplitterSignal):
    """Split documents into sentences by splitting on period."""

    name = 'test_splitter'

    @override
    def compute(self, data: Iterable[RichData]) -> Iterable[Item]:
        """Yield one list of sentence spans per input text.

        Each span covers the text between two periods. The periods themselves are excluded
        and any leading whitespace is kept. A trailing period yields a final empty span at
        the end of the text.

        Raises:
          ValueError: if an input is not a string.
        """
        for text in data:
            if not isinstance(text, str):
                raise ValueError(f'Expected text to be a string, got {type(text)} instead.')
            result: list[Item] = []
            # Track a running offset. `text.index(sentence)` would return the first
            # occurrence, giving a repeated sentence the span of its earlier copy.
            start = 0
            for sentence in text.split('.'):
                end = start + len(sentence)
                result.append(lilac_span(start, end))
                start = end + 1  # Skip over the '.' separator.
            yield result
def test_udf_after_precomputed_split(make_test_data: TestDataMaker) -> None:
    """A UDF combined with a precomputed splitter is merged into the same enriched item."""
    first_text = 'sentence 1. sentence 2 is longer'
    second_text = 'sentence 1 is longer. sent2 is short'
    ds = make_test_data([{'text': first_text}, {'text': second_text}])
    ds.compute_signal(TestSplitter(), 'text')

    length_udf = Column('text', signal_udf=LengthSignal())
    rows = list(ds.select_rows(['*', length_udf], combine_columns=True))

    first_expected = enriched_item(first_text, {
        'length_signal': 32,
        'test_splitter': [lilac_span(0, 10), lilac_span(11, 32)],
    })
    second_expected = enriched_item(second_text, {
        'length_signal': 36,
        'test_splitter': [lilac_span(0, 20), lilac_span(21, 36)],
    })
    assert rows == [{'text': first_expected}, {'text': second_expected}]
def test_is_computed_signal_key(make_test_data: TestDataMaker) -> None:
    """The UDF output column is named after the signal's key(), here 'key_False'."""
    texts = ('hello.', 'hello2.')
    ds = make_test_data([{'text': text} for text in texts])
    udf = Column('text', signal_udf=ComputedKeySignal())

    rows = list(ds.select_rows(['text', udf]))

    assert rows == [{'text': text, 'text.key_False': 1} for text in texts]