"""Implementation-agnostic tests of the Dataset DB API.""" from typing import Iterable, Optional, cast import numpy as np import pytest from typing_extensions import override from ..schema import UUID_COLUMN, VALUE_KEY, Field, Item, RichData, field, schema from ..signals.signal import TextEmbeddingSignal, TextSignal, clear_signal_registry, register_signal from .dataset import Column, DatasetManifest, val from .dataset_test_utils import TEST_DATASET_NAME, TEST_NAMESPACE, TestDataMaker, enriched_item from .dataset_utils import lilac_embedding SIMPLE_ITEMS: list[Item] = [{ UUID_COLUMN: '1', 'str': 'a', 'int': 1, 'bool': False, 'float': 3.0 }, { UUID_COLUMN: '2', 'str': 'b', 'int': 2, 'bool': True, 'float': 2.0 }, { UUID_COLUMN: '3', 'str': 'b', 'int': 2, 'bool': True, 'float': 1.0 }] EMBEDDINGS: list[tuple[str, list[float]]] = [('hello.', [1.0, 0.0, 0.0]), ('hello2.', [1.0, 1.0, 0.0]), ('hello world.', [1.0, 1.0, 1.0]), ('hello world2.', [2.0, 1.0, 1.0])] STR_EMBEDDINGS: dict[str, list[float]] = {text: embedding for text, embedding in EMBEDDINGS} class TestEmbedding(TextEmbeddingSignal): """A test embed function.""" name = 'test_embedding' @override def compute(self, data: Iterable[RichData]) -> Iterable[Item]: """Call the embedding function.""" for example in data: yield [lilac_embedding(0, len(example), np.array(STR_EMBEDDINGS[cast(str, example)]))] class LengthSignal(TextSignal): name = 'length_signal' _call_count: int = 0 def fields(self) -> Field: return field('int32') def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]: for text_content in data: self._call_count += 1 yield len(text_content) class TestSignal(TextSignal): name = 'test_signal' @override def fields(self) -> Field: return field(fields={'len': 'int32', 'flen': 'float32'}) @override def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]: return [{'len': len(text_content), 'flen': float(len(text_content))} for text_content in data] @pytest.fixture(scope='module', autouse=True) def setup_teardown() -> Iterable[None]: # Setup. register_signal(TestSignal) register_signal(LengthSignal) register_signal(SignalWithQuoteInIt) register_signal(SignalWithDoubleQuoteInIt) # Unit test runs. yield # Teardown. clear_signal_registry() def test_select_all_columns(make_test_data: TestDataMaker) -> None: dataset = make_test_data(SIMPLE_ITEMS) result = dataset.select_rows() assert list(result) == SIMPLE_ITEMS def test_select_subcols_with_dot_seperator(make_test_data: TestDataMaker) -> None: items: list[Item] = [{ UUID_COLUMN: '1', 'people': [{ 'name': 'A', 'address': { 'zip': 1 } }, { 'name': 'B', 'address': { 'zip': 2 } }] }, { UUID_COLUMN: '2', 'people': [{ 'name': 'C', 'address': { 'zip': 3 } }] }] dataset = make_test_data(items) result = dataset.select_rows(['people.*.name', 'people.*.address.zip']) assert list(result) == [{ UUID_COLUMN: '1', 'people.*.name': ['A', 'B'], 'people.*.address.zip': [1, 2] }, { UUID_COLUMN: '2', 'people.*.name': ['C'], 'people.*.address.zip': [3] }] result = dataset.select_rows(['people.*.address.zip'], combine_columns=True) assert list(result) == [{ UUID_COLUMN: '1', 'people': [{ 'address': { 'zip': 1 } }, { 'address': { 'zip': 2 } }] }, { UUID_COLUMN: '2', 'people': [{ 'address': { 'zip': 3 } }] }] result = dataset.select_rows(['people']) assert list(result) == items def test_select_subcols_with_escaped_dot(make_test_data: TestDataMaker) -> None: items: list[Item] = [{ UUID_COLUMN: '1', 'people.new': [{ 'name': 'A' }, { 'name': 'B' }] }, { UUID_COLUMN: '2', 'people.new': [{ 'name': 'C' }] }] dataset = make_test_data(items) result = dataset.select_rows(['"people.new".*.name']) assert list(result) == [{ UUID_COLUMN: '1', 'people.new.*.name': ['A', 'B'], }, { UUID_COLUMN: '2', 'people.new.*.name': ['C'], }] # Escape name even though it does not need to be. result = dataset.select_rows(['"people.new".*."name"']) assert list(result) == [{ UUID_COLUMN: '1', 'people.new.*.name': ['A', 'B'], }, { UUID_COLUMN: '2', 'people.new.*.name': ['C'], }] def test_select_star(make_test_data: TestDataMaker) -> None: items: list[Item] = [{ UUID_COLUMN: '1', 'name': 'A', 'info': { 'age': 40 } }, { UUID_COLUMN: '2', 'name': 'B', 'info': { 'age': 42 } }] dataset = make_test_data(items) # Select *. result = dataset.select_rows(['*']) assert list(result) == items # Select (*,). result = dataset.select_rows([('*',)]) assert list(result) == items # Select *, plus a redundant `info` column. result = dataset.select_rows(['*', 'info']) assert list(result) == [{ UUID_COLUMN: '1', 'name': 'A', 'info': { 'age': 40 }, 'info_2': { 'age': 40 }, }, { UUID_COLUMN: '2', 'name': 'B', 'info': { 'age': 42 }, 'info_2': { 'age': 42 }, }] # Select * plus an inner `info.age` column. result = dataset.select_rows(['*', ('info', 'age')]) assert list(result) == [{ UUID_COLUMN: '1', 'name': 'A', 'info': { 'age': 40 }, 'info.age': 40 }, { UUID_COLUMN: '2', 'name': 'B', 'info': { 'age': 42 }, 'info.age': 42 }] def test_select_star_with_combine_cols(make_test_data: TestDataMaker) -> None: items: list[Item] = [{ UUID_COLUMN: '1', 'name': 'A', 'info': { 'age': 40 } }, { UUID_COLUMN: '2', 'name': 'B', 'info': { 'age': 42 } }] dataset = make_test_data(items) # Select *. result = dataset.select_rows(['*'], combine_columns=True) assert list(result) == items # Select *, plus a redundant `info` column. result = dataset.select_rows(['*', 'info'], combine_columns=True) assert list(result) == items # Select * plus an inner `info.age` column. result = dataset.select_rows(['*', ('info', 'age')], combine_columns=True) assert list(result) == items # Select *, plus redundant `name`, plus a udf. udf = Column('name', signal_udf=TestSignal()) result = dataset.select_rows(['*', 'name', udf], combine_columns=True) assert list(result) == [{ UUID_COLUMN: '1', 'name': enriched_item('A', {'test_signal': { 'len': 1, 'flen': 1.0 }}), 'info': { 'age': 40 } }, { UUID_COLUMN: '2', 'name': enriched_item('B', {'test_signal': { 'len': 1, 'flen': 1.0 }}), 'info': { 'age': 42 } }] def test_select_ids(make_test_data: TestDataMaker) -> None: dataset = make_test_data(SIMPLE_ITEMS) result = dataset.select_rows([UUID_COLUMN]) assert list(result) == [{UUID_COLUMN: '1'}, {UUID_COLUMN: '2'}, {UUID_COLUMN: '3'}] def test_select_ids_with_limit_and_offset(make_test_data: TestDataMaker) -> None: items: list[Item] = [{UUID_COLUMN: str(i)} for i in range(10, 20)] dataset = make_test_data(items) result = dataset.select_rows([UUID_COLUMN], offset=1, limit=3) assert list(result) == [{UUID_COLUMN: '11'}, {UUID_COLUMN: '12'}, {UUID_COLUMN: '13'}] result = dataset.select_rows([UUID_COLUMN], offset=7, limit=2) assert list(result) == [{UUID_COLUMN: '17'}, {UUID_COLUMN: '18'}] result = dataset.select_rows([UUID_COLUMN], offset=9, limit=200) assert list(result) == [{UUID_COLUMN: '19'}] result = dataset.select_rows([UUID_COLUMN], offset=10, limit=200) assert list(result) == [] def test_columns(make_test_data: TestDataMaker) -> None: dataset = make_test_data(SIMPLE_ITEMS) result = dataset.select_rows(['str', 'float']) assert list(result) == [{ UUID_COLUMN: '1', 'str': 'a', 'float': 3.0 }, { UUID_COLUMN: '2', 'str': 'b', 'float': 2.0 }, { UUID_COLUMN: '3', 'str': 'b', 'float': 1.0 }] def test_merge_values(make_test_data: TestDataMaker) -> None: dataset = make_test_data([{ UUID_COLUMN: '1', 'text': 'hello' }, { UUID_COLUMN: '2', 'text': 'everybody' }]) test_signal = TestSignal() dataset.compute_signal(test_signal, 'text') length_signal = LengthSignal() dataset.compute_signal(length_signal, 'text') result = dataset.select_rows(['text']) assert list(result) == [{ UUID_COLUMN: '1', 'text': enriched_item('hello', { 'length_signal': 5, 'test_signal': { 'len': 5, 'flen': 5.0 } }) }, { UUID_COLUMN: '2', 'text': enriched_item('everybody', { 'length_signal': 9, 'test_signal': { 'len': 9, 'flen': 9.0 } }), }] # Test subselection. result = dataset.select_rows( [val('text'), ('text', 'test_signal', 'flen'), ('text', 'test_signal', 'len')]) assert list(result) == [{ UUID_COLUMN: '1', f'text.{VALUE_KEY}': 'hello', 'text.test_signal.flen': 5.0, 'text.test_signal.len': 5 }, { UUID_COLUMN: '2', f'text.{VALUE_KEY}': 'everybody', 'text.test_signal.flen': 9.0, 'text.test_signal.len': 9 }] # Test subselection with combine_columns=True. result = dataset.select_rows( ['text', ('text', 'test_signal', 'flen'), ('text', 'test_signal', 'len')], combine_columns=True) assert list(result) == [{ UUID_COLUMN: '1', 'text': enriched_item('hello', { 'length_signal': 5, 'test_signal': { 'len': 5, 'flen': 5.0 } }) }, { UUID_COLUMN: '2', 'text': enriched_item('everybody', { 'length_signal': 9, 'test_signal': { 'len': 9, 'flen': 9.0 } }), }] # Test subselection with aliasing. result = dataset.select_rows( columns=[val('text'), Column(('text', 'test_signal', 'len'), alias='metadata')]) assert list(result) == [{ UUID_COLUMN: '1', f'text.{VALUE_KEY}': 'hello', 'metadata': 5 }, { UUID_COLUMN: '2', f'text.{VALUE_KEY}': 'everybody', 'metadata': 9 }] result = dataset.select_rows(columns=[Column(('text'), alias='text_enrichment')]) assert list(result) == [{ UUID_COLUMN: '1', 'text_enrichment': enriched_item('hello', { 'length_signal': 5, 'test_signal': { 'len': 5, 'flen': 5.0 } }) }, { UUID_COLUMN: '2', 'text_enrichment': enriched_item('everybody', { 'length_signal': 9, 'test_signal': { 'len': 9, 'flen': 9.0 } }) }] def test_merge_array_values(make_test_data: TestDataMaker) -> None: dataset = make_test_data([{ UUID_COLUMN: '1', 'texts': ['hello', 'everybody'] }, { UUID_COLUMN: '2', 'texts': ['a', 'bc', 'def'] }]) test_signal = TestSignal() dataset.compute_signal(test_signal, ('texts', '*')) length_signal = LengthSignal() dataset.compute_signal(length_signal, ('texts', '*')) assert dataset.manifest() == DatasetManifest( namespace=TEST_NAMESPACE, dataset_name=TEST_DATASET_NAME, data_schema=schema({ UUID_COLUMN: 'string', 'texts': [ field( 'string', fields={ 'length_signal': field('int32', length_signal.dict()), 'test_signal': field( signal=test_signal.dict(), fields={ 'len': 'int32', 'flen': 'float32' }) }) ], }), num_items=2) result = dataset.select_rows(['texts']) assert list(result) == [{ UUID_COLUMN: '1', 'texts': [ enriched_item('hello', { 'length_signal': 5, 'test_signal': { 'len': 5, 'flen': 5.0 } }), enriched_item('everybody', { 'length_signal': 9, 'test_signal': { 'len': 9, 'flen': 9.0 } }) ], }, { UUID_COLUMN: '2', 'texts': [ enriched_item('a', { 'length_signal': 1, 'test_signal': { 'len': 1, 'flen': 1.0 } }), enriched_item('bc', { 'length_signal': 2, 'test_signal': { 'len': 2, 'flen': 2.0 } }), enriched_item('def', { 'length_signal': 3, 'test_signal': { 'len': 3, 'flen': 3.0 } }) ], }] # Test subselection. result = dataset.select_rows( [val(('texts', '*')), ('texts', '*', 'length_signal'), ('texts', '*', 'test_signal', 'flen')]) assert list(result) == [{ UUID_COLUMN: '1', f'texts.*.{VALUE_KEY}': ['hello', 'everybody'], 'texts.*.test_signal.flen': [5.0, 9.0], 'texts.*.length_signal': [5, 9] }, { UUID_COLUMN: '2', f'texts.*.{VALUE_KEY}': ['a', 'bc', 'def'], 'texts.*.test_signal.flen': [1.0, 2.0, 3.0], 'texts.*.length_signal': [1, 2, 3] }] def test_combining_columns(make_test_data: TestDataMaker) -> None: dataset = make_test_data([{ UUID_COLUMN: '1', 'text': 'hello', 'extra': { 'text': { 'length_signal': 5, 'test_signal': { 'len': 5, 'flen': 5.0 } } } }, { UUID_COLUMN: '2', 'text': 'everybody', 'extra': { 'text': { 'length_signal': 9, 'test_signal': { 'len': 9, 'flen': 9.0 } } } }]) # Sub-select text and test_signal. result = dataset.select_rows(['text', ('extra', 'text', 'test_signal')], combine_columns=True) assert list(result) == [{ UUID_COLUMN: '1', 'text': 'hello', 'extra': { 'text': { 'test_signal': { 'len': 5, 'flen': 5.0 } } } }, { UUID_COLUMN: '2', 'text': 'everybody', 'extra': { 'text': { 'test_signal': { 'len': 9, 'flen': 9.0 } } } }] # Sub-select text and length_signal. result = dataset.select_rows(['text', ('extra', 'text', 'length_signal')], combine_columns=True) assert list(result) == [{ UUID_COLUMN: '1', 'text': 'hello', 'extra': { 'text': { 'length_signal': 5 } } }, { UUID_COLUMN: '2', 'text': 'everybody', 'extra': { 'text': { 'length_signal': 9 } } }] # Sub-select length_signal only. result = dataset.select_rows([('extra', 'text', 'length_signal')], combine_columns=True) assert list(result) == [{ UUID_COLUMN: '1', 'extra': { 'text': { 'length_signal': 5 } } }, { UUID_COLUMN: '2', 'extra': { 'text': { 'length_signal': 9 } } }] # Aliases are ignored when combing columns. len_col = Column(('extra', 'text', 'length_signal'), alias='hello') result = dataset.select_rows([len_col], combine_columns=True) assert list(result) == [{ UUID_COLUMN: '1', 'extra': { 'text': { 'length_signal': 5 } } }, { UUID_COLUMN: '2', 'extra': { 'text': { 'length_signal': 9 } } }] # Works with UDFs and aliases are ignored. udf_col = Column('text', alias='ignored', signal_udf=LengthSignal()) result = dataset.select_rows(['text', udf_col], combine_columns=True) assert list(result) == [{ UUID_COLUMN: '1', 'text': enriched_item('hello', {'length_signal': 5}) }, { UUID_COLUMN: '2', 'text': enriched_item('everybody', {'length_signal': 9}) }] def test_source_joined_with_named_signal(make_test_data: TestDataMaker) -> None: dataset = make_test_data(SIMPLE_ITEMS) assert dataset.manifest() == DatasetManifest( namespace=TEST_NAMESPACE, dataset_name=TEST_DATASET_NAME, data_schema=schema({ UUID_COLUMN: 'string', 'str': 'string', 'int': 'int32', 'bool': 'boolean', 'float': 'float32', }), num_items=3) test_signal = TestSignal() dataset.compute_signal(test_signal, 'str') # Check the enriched dataset manifest has 'text' enriched. assert dataset.manifest() == DatasetManifest( namespace=TEST_NAMESPACE, dataset_name=TEST_DATASET_NAME, data_schema=schema({ UUID_COLUMN: 'string', 'str': field( 'string', fields={ 'test_signal': field( signal=test_signal.dict(), fields={ 'len': 'int32', 'flen': 'float32' }) }), 'int': 'int32', 'bool': 'boolean', 'float': 'float32', }), num_items=3) # Select both columns, without val() on str. result = dataset.select_rows(['str', Column(('str', 'test_signal'), alias='test_signal_on_str')]) assert list(result) == [{ UUID_COLUMN: '1', 'str': enriched_item('a', {'test_signal': { 'len': 1, 'flen': 1.0 }}), 'test_signal_on_str': { 'len': 1, 'flen': 1.0 } }, { UUID_COLUMN: '2', 'str': enriched_item('b', {'test_signal': { 'len': 1, 'flen': 1.0 }}), 'test_signal_on_str': { 'len': 1, 'flen': 1.0 } }, { UUID_COLUMN: '3', 'str': enriched_item('b', {'test_signal': { 'len': 1, 'flen': 1.0 }}), 'test_signal_on_str': { 'len': 1, 'flen': 1.0 } }] # Select both columns, with val() on str. result = dataset.select_rows( [val('str'), Column(('str', 'test_signal'), alias='test_signal_on_str')]) assert list(result) == [{ UUID_COLUMN: '1', f'str.{VALUE_KEY}': 'a', 'test_signal_on_str': { 'len': 1, 'flen': 1.0 } }, { UUID_COLUMN: '2', f'str.{VALUE_KEY}': 'b', 'test_signal_on_str': { 'len': 1, 'flen': 1.0 } }, { UUID_COLUMN: '3', f'str.{VALUE_KEY}': 'b', 'test_signal_on_str': { 'len': 1, 'flen': 1.0 } }] def test_invalid_column_paths(make_test_data: TestDataMaker) -> None: dataset = make_test_data([{ UUID_COLUMN: '1', 'text': enriched_item('hello', {'test_signal': { 'len': 5 }}), 'text2': [ enriched_item('hello', {'test_signal': { 'len': 5 }}), enriched_item('hi', {'test_signal': { 'len': 2 }}) ], }]) with pytest.raises(ValueError, match='Path part "invalid" not found in the dataset'): dataset.select_rows([('text', 'test_signal', 'invalid')]) with pytest.raises(ValueError, match='Selecting a specific index of a repeated field'): dataset.select_rows([('text2', '4', 'test_signal')]) def test_signal_with_quote(make_test_data: TestDataMaker) -> None: dataset = make_test_data([{ UUID_COLUMN: '1', 'text': 'hello', }, { UUID_COLUMN: '2', 'text': 'world', }]) dataset.compute_signal(SignalWithQuoteInIt(), 'text') dataset.compute_signal(SignalWithDoubleQuoteInIt(), 'text') result = dataset.select_rows(['text']) assert list(result) == [{ UUID_COLUMN: '1', 'text': enriched_item('hello', { "test'signal": True, 'test"signal': True }) }, { UUID_COLUMN: '2', 'text': enriched_item('world', { "test'signal": True, 'test"signal': True }), }] class SignalWithQuoteInIt(TextSignal): name = "test'signal" @override def fields(self) -> Field: return field('boolean') @override def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]: for d in data: yield True class SignalWithDoubleQuoteInIt(TextSignal): name = 'test"signal' @override def fields(self) -> Field: return field('boolean') @override def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]: for d in data: yield True