Spaces:
Runtime error
Runtime error
"""Tests for dataset.select_rows(sort_by=...).""" | |
from typing import Iterable, Optional, Sequence, cast | |
import numpy as np | |
import pytest | |
from typing_extensions import override | |
from ..embeddings.vector_store import VectorStore | |
from ..schema import UUID_COLUMN, Field, Item, RichData, VectorKey, field | |
from ..signals.signal import ( | |
TextEmbeddingModelSignal, | |
TextEmbeddingSignal, | |
TextSignal, | |
clear_signal_registry, | |
register_signal, | |
) | |
from .dataset import BinaryOp, Column, SortOrder | |
from .dataset_test_utils import TestDataMaker, enriched_item | |
from .dataset_utils import lilac_embedding | |
class TestSignal(TextSignal):
  """Test signal that reports, per text, its length and whether it is all upper case."""
  name = 'test_signal'

  @override
  def fields(self) -> Field:
    """Declare a struct output with an int32 `len` and a boolean `is_all_cap`."""
    return field(fields={'len': 'int32', 'is_all_cap': 'boolean'})

  @override
  def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]:
    """Yield one `{'len', 'is_all_cap'}` dict per input text, in input order."""
    for text_content in data:
      yield {'len': len(text_content), 'is_all_cap': text_content.isupper()}
class TestPrimitiveSignal(TextSignal):
  """Test signal that returns a single int per text: its length plus one."""
  name = 'primitive_signal'

  @override
  def fields(self) -> Field:
    """Declare a primitive int32 output (no struct wrapper)."""
    return field('int32')

  @override
  def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]:
    """Yield `len(text) + 1` for each input text, in input order."""
    for text_content in data:
      yield len(text_content) + 1
class NestedArraySignal(TextSignal):
  """Test signal that returns a doubly nested list `[[len + 1], [len]]` per text."""
  name = 'nested_array'

  @override
  def fields(self) -> Field:
    """Declare a list-of-lists-of-int32 output."""
    return field(fields=[['int32']])

  @override
  def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]:
    """Yield `[[len(text) + 1], [len(text)]]` for each input text, in input order."""
    for text_content in data:
      yield [[len(text_content) + 1], [len(text_content)]]
@pytest.fixture(scope='module', autouse=True)
def setup_teardown() -> Iterable[None]:
  """Register the signals used by this module's tests, then clear the registry afterwards.

  `autouse=True` applies the fixture to every test in the module without the tests requesting it,
  and `scope='module'` performs the registration once for the whole module.
  """
  try:
    # Setup. `TopKEmbedding` is defined further down the module; that is fine because this body
    # only runs when pytest invokes the fixture, after the module has been fully imported.
    register_signal(TestSignal)
    register_signal(TestPrimitiveSignal)
    register_signal(NestedArraySignal)
    register_signal(TopKEmbedding)
    # Unit tests run here.
    yield
  finally:
    # Teardown. Also runs if one of the registrations above raised, so a partially populated
    # registry never leaks into other test modules.
    clear_signal_registry()
def test_sort_by_source_no_alias_no_repeated(make_test_data: TestDataMaker) -> None:
  """Sort by top-level and nested source leaves of several primitive types, in both orders."""
  dataset = make_test_data([{
    UUID_COLUMN: '1',
    'erased': True,
    'score': 4.1,
    'document': {'num_pages': 4, 'header': {'title': 'c'}},
  }, {
    UUID_COLUMN: '2',
    'erased': False,
    'score': 3.5,
    'document': {'num_pages': 5, 'header': {'title': 'b'}},
  }, {
    UUID_COLUMN: '3',
    'erased': True,
    'score': 3.7,
    'document': {'num_pages': 3, 'header': {'title': 'a'}},
  }])

  # Sort path -> (expected UUID order for ASC, expected UUID order for DESC).
  expected_ids = {
    # Bool.
    'erased': (['2', '1', '3'], ['1', '3', '2']),
    # Float.
    'score': (['2', '3', '1'], ['1', '3', '2']),
    # Nested int.
    'document.num_pages': (['3', '1', '2'], ['2', '1', '3']),
    # Double nested string.
    'document.header.title': (['3', '2', '1'], ['1', '2', '3']),
  }
  for sort_path, (asc_ids, desc_ids) in expected_ids.items():
    for order, ids in ((SortOrder.ASC, asc_ids), (SortOrder.DESC, desc_ids)):
      result = dataset.select_rows(columns=[UUID_COLUMN], sort_by=[sort_path], sort_order=order)
      assert list(result) == [{UUID_COLUMN: uuid} for uuid in ids]
def test_sort_by_signal_no_alias_no_repeated(make_test_data: TestDataMaker) -> None:
  """Sort by the leaves of a computed signal, addressed through its full (unaliased) path."""
  dataset = make_test_data([
    {UUID_COLUMN: '1', 'text': 'HEY'},
    {UUID_COLUMN: '2', 'text': 'everyone'},
    {UUID_COLUMN: '3', 'text': 'HI'},
  ])
  dataset.compute_signal(TestSignal(), 'text')

  # Sort path -> (expected UUID order for ASC, expected UUID order for DESC).
  expected_ids = {
    # Lengths: HI=2, HEY=3, everyone=8.
    'text.test_signal.len': (['3', '1', '2'], ['2', '1', '3']),
    # Only 'everyone' is not all upper case.
    'text.test_signal.is_all_cap': (['2', '1', '3'], ['1', '3', '2']),
  }
  for sort_path, (asc_ids, desc_ids) in expected_ids.items():
    for order, ids in ((SortOrder.ASC, asc_ids), (SortOrder.DESC, desc_ids)):
      result = dataset.select_rows(columns=[UUID_COLUMN], sort_by=[sort_path], sort_order=order)
      assert list(result) == [{UUID_COLUMN: uuid} for uuid in ids]
def test_sort_by_signal_alias_no_repeated(make_test_data: TestDataMaker) -> None:
  """Sort by a leaf of a computed signal that is selected under an alias."""
  dataset = make_test_data([
    {UUID_COLUMN: '1', 'text': 'HEY'},
    {UUID_COLUMN: '2', 'text': 'everyone'},
    {UUID_COLUMN: '3', 'text': 'HI'},
  ])
  dataset.compute_signal(TestSignal(), 'text')

  # Select `text.test_signal` as `signal`, then sort by `signal.len`.
  signal_alias = Column('text.test_signal', alias='signal')
  hey = {UUID_COLUMN: '1', 'signal': {'len': 3, 'is_all_cap': True}}
  everyone = {UUID_COLUMN: '2', 'signal': {'len': 8, 'is_all_cap': False}}
  hi = {UUID_COLUMN: '3', 'signal': {'len': 2, 'is_all_cap': True}}

  for order, expected in ((SortOrder.ASC, [hi, hey, everyone]),
                          (SortOrder.DESC, [everyone, hey, hi])):
    result = dataset.select_rows(columns=[signal_alias], sort_by=['signal.len'], sort_order=order)
    assert list(result) == expected
def test_sort_by_enriched_alias_no_repeated(make_test_data: TestDataMaker) -> None:
  """Sort by a signal leaf reached through an alias of the enriched source column."""
  dataset = make_test_data([
    {UUID_COLUMN: '1', 'text': 'HEY'},
    {UUID_COLUMN: '2', 'text': 'everyone'},
    {UUID_COLUMN: '3', 'text': 'HI'},
  ])
  dataset.compute_signal(TestSignal(), 'text')

  # 'document' is an alias to 'text', so `document.test_signal.is_all_cap` names the signal leaf.
  text_alias = Column('text', alias='document')
  hey = {
    UUID_COLUMN: '1',
    'document': enriched_item('HEY', {'test_signal': {'len': 3, 'is_all_cap': True}}),
  }
  everyone = {
    UUID_COLUMN: '2',
    'document': enriched_item('everyone', {'test_signal': {'len': 8, 'is_all_cap': False}}),
  }
  hi = {
    UUID_COLUMN: '3',
    'document': enriched_item('HI', {'test_signal': {'len': 2, 'is_all_cap': True}}),
  }

  for order, expected in ((SortOrder.ASC, [everyone, hey, hi]),
                          (SortOrder.DESC, [hey, hi, everyone])):
    result = dataset.select_rows(
      columns=[text_alias], sort_by=['document.test_signal.is_all_cap'], sort_order=order)
    assert list(result) == expected
def test_sort_by_udf_alias_no_repeated(make_test_data: TestDataMaker) -> None:
  """Sort by a leaf of an aliased signal UDF that is computed at query time."""
  dataset = make_test_data([
    {UUID_COLUMN: '1', 'text': 'HEY'},
    {UUID_COLUMN: '2', 'text': 'everyone'},
    {UUID_COLUMN: '3', 'text': 'HI'},
  ])
  # Equivalent to: SELECT `TestSignal(text) AS udf`.
  text_udf = Column('text', signal_udf=TestSignal(), alias='udf')

  # Sort by `udf.len`, where `udf` is an alias to `TestSignal(text)`.
  result = dataset.select_rows(['*', text_udf], sort_by=['udf.len'], sort_order=SortOrder.ASC)

  expected = [
    {UUID_COLUMN: '3', 'text': 'HI', 'udf': {'len': 2, 'is_all_cap': True}},
    {UUID_COLUMN: '1', 'text': 'HEY', 'udf': {'len': 3, 'is_all_cap': True}},
    {UUID_COLUMN: '2', 'text': 'everyone', 'udf': {'len': 8, 'is_all_cap': False}},
  ]
  assert list(result) == expected
def test_sort_by_udf_no_alias_no_repeated(make_test_data: TestDataMaker) -> None:
  """Sort by an unaliased signal UDF leaf addressed by a tuple path, with combined columns."""
  dataset = make_test_data([
    {UUID_COLUMN: '1', 'text': 'HEY'},
    {UUID_COLUMN: '2', 'text': 'everyone'},
    {UUID_COLUMN: '3', 'text': 'HI'},
  ])
  text_udf = Column('text', signal_udf=TestSignal())

  def enriched_row(uuid: str, text: str, length: int, is_all_cap: bool) -> Item:
    # With `combine_columns=True` the signal output is nested under the source text.
    signal = {'test_signal': {'len': length, 'is_all_cap': is_all_cap}}
    return {UUID_COLUMN: uuid, 'text': enriched_item(text, signal)}

  hey = enriched_row('1', 'HEY', 3, True)
  everyone = enriched_row('2', 'everyone', 8, False)
  hi = enriched_row('3', 'HI', 2, True)

  # Sort by `text.test_signal.len`, produced by executing the udf `TestSignal(text)`.
  for order, expected in ((SortOrder.ASC, [hi, hey, everyone]),
                          (SortOrder.DESC, [everyone, hey, hi])):
    result = dataset.select_rows(['*', text_udf],
                                 sort_by=[('text', 'test_signal', 'len')],
                                 sort_order=order,
                                 combine_columns=True)
    assert list(result) == expected
def test_sort_by_primitive_udf_alias_no_repeated(make_test_data: TestDataMaker) -> None:
  """Sort directly by the primitive value returned by an aliased signal UDF."""
  dataset = make_test_data([
    {UUID_COLUMN: '1', 'text': 'HEY'},
    {UUID_COLUMN: '2', 'text': 'everyone'},
    {UUID_COLUMN: '3', 'text': 'HI'},
  ])
  # Equivalent to: SELECT `TestPrimitiveSignal(text) AS udf`.
  text_udf = Column('text', signal_udf=TestPrimitiveSignal(), alias='udf')

  # The udf value is `len(text) + 1`: HI=3, HEY=4, everyone=9.
  result = dataset.select_rows(['*', text_udf], sort_by=['udf'], sort_order=SortOrder.ASC)

  expected = [
    {UUID_COLUMN: '3', 'text': 'HI', 'udf': 3},
    {UUID_COLUMN: '1', 'text': 'HEY', 'udf': 4},
    {UUID_COLUMN: '2', 'text': 'everyone', 'udf': 9},
  ]
  assert list(result) == expected
def test_sort_by_source_non_leaf_errors(make_test_data: TestDataMaker) -> None:
  """Sorting by a repeated field itself, rather than a leaf inside it, raises a ValueError."""
  dataset = make_test_data([
    {UUID_COLUMN: '1', 'vals': [7, 1]},
    {UUID_COLUMN: '2', 'vals': [3, 4]},
    {UUID_COLUMN: '3', 'vals': [9, 0]},
  ])

  # `vals` is a list, so the path `vals` names the repeated field rather than a leaf value.
  with pytest.raises(ValueError, match='Unable to sort by path'):
    dataset.select_rows(columns=[UUID_COLUMN], sort_by=['vals'], sort_order=SortOrder.ASC)
def test_sort_by_source_no_alias_repeated(make_test_data: TestDataMaker) -> None:
  """Sort by a leaf nested two repeated levels deep in the source data."""

  def row(uuid: str) -> Item:
    # Builds brand-new dicts and lists on every call, so the expected rows never share objects
    # with the rows handed to `make_test_data`.
    score_groups = {'1': [[7, 1], [1, 7]], '2': [[3, 4]], '3': [[9, 0]]}[uuid]
    return {
      UUID_COLUMN: uuid,
      'vals': [[{'score': score} for score in group] for group in score_groups],
    }

  dataset = make_test_data([row('1'), row('2'), row('3')])

  # Scores per row: 1 -> 7, 1, 1, 7; 2 -> 3, 4; 3 -> 9, 0. Row 3 holds both the smallest (0) and
  # the largest (9) score; the expected order is the same for ASC and DESC.
  for order in (SortOrder.ASC, SortOrder.DESC):
    result = dataset.select_rows(
      columns=[UUID_COLUMN, 'vals'], sort_by=['vals.*.*.score'], sort_order=order)
    assert list(result) == [row('3'), row('1'), row('2')]
def test_sort_by_source_alias_repeated(make_test_data: TestDataMaker) -> None:
  """Sort by the leaves of a doubly repeated source field selected under an alias."""

  def nested_vals(uuid: str) -> list[list[int]]:
    # The literal is re-evaluated on every call, so each call returns fresh lists.
    return {'1': [[7, 1], [1, 7]], '2': [[3], [11]], '3': [[9, 0]]}[uuid]

  dataset = make_test_data([{UUID_COLUMN: uuid, 'vals': nested_vals(uuid)} for uuid in ('1', '2', '3')])

  # Sort by the repeated leaves of `vals`, selected as `scores`.
  for order, ids in ((SortOrder.ASC, ['3', '1', '2']), (SortOrder.DESC, ['2', '3', '1'])):
    result = dataset.select_rows(
      columns=[UUID_COLUMN, Column('vals', alias='scores')],
      sort_by=['scores.*.*'],
      sort_order=order)
    assert list(result) == [{UUID_COLUMN: uuid, 'scores': nested_vals(uuid)} for uuid in ids]
def test_sort_by_udf_alias_repeated(make_test_data: TestDataMaker) -> None:
  """Sort by the leaves of an aliased signal UDF whose output is a nested list."""
  dataset = make_test_data([
    {UUID_COLUMN: '1', 'text': 'HEY'},
    {UUID_COLUMN: '2', 'text': 'everyone'},
    {UUID_COLUMN: '3', 'text': 'HI'},
  ])
  # Equivalent to: SELECT `NestedArraySignal(text) AS udf`.
  text_udf = Column('text', signal_udf=NestedArraySignal(), alias='udf')

  hey = {UUID_COLUMN: '1', 'text': 'HEY', 'udf': [[4], [3]]}
  everyone = {UUID_COLUMN: '2', 'text': 'everyone', 'udf': [[9], [8]]}
  hi = {UUID_COLUMN: '3', 'text': 'HI', 'udf': [[3], [2]]}

  # Sort by `udf.*.*`, where `udf` is an alias to `NestedArraySignal(text)`.
  for order, expected in ((SortOrder.ASC, [hi, hey, everyone]),
                          (SortOrder.DESC, [everyone, hey, hi])):
    result = dataset.select_rows(['*', text_udf], sort_by=['udf.*.*'], sort_order=order)
    assert list(result) == expected
def test_sort_by_complex_signal_udf_alias_called_on_repeated(make_test_data: TestDataMaker) -> None:
  """Sort by a leaf of an aliased struct-valued UDF applied to a repeated text field."""
  dataset = make_test_data([
    {UUID_COLUMN: '1', 'texts': [{'text': 'eardrop'}, {'text': 'I'}]},
    {UUID_COLUMN: '2', 'texts': [{'text': 'hey'}, {'text': 'CARS'}]},
    {UUID_COLUMN: '3', 'texts': [{'text': 'everyone'}, {'text': ''}]},
  ])
  # Equivalent to: SELECT `TestSignal(texts.*.text) AS udf`.
  texts_udf = Column('texts.*.text', signal_udf=TestSignal(), alias='udf')

  def entry(text: str, length: int, is_all_cap: bool) -> Item:
    # One element of `texts` with the signal output nested under the source text.
    signal = {'test_signal': {'len': length, 'is_all_cap': is_all_cap}}
    return {'text': enriched_item(text, signal)}

  # Sort by `udf.len`, where `udf` is an alias to `TestSignal(texts.*.text)`.
  result = dataset.select_rows(['*', texts_udf],
                               sort_by=['udf.len'],
                               sort_order=SortOrder.ASC,
                               combine_columns=True)
  assert list(result) == [
    {UUID_COLUMN: '3', 'texts': [entry('everyone', 8, False), entry('', 0, False)]},
    {UUID_COLUMN: '1', 'texts': [entry('eardrop', 7, False), entry('I', 1, True)]},
    {UUID_COLUMN: '2', 'texts': [entry('hey', 3, False), entry('CARS', 4, True)]},
  ]
def test_sort_by_primitive_signal_udf_alias_called_on_repeated(
    make_test_data: TestDataMaker) -> None:
  """Sort by an aliased primitive-valued UDF applied to a repeated text field."""
  dataset = make_test_data([
    {UUID_COLUMN: '1', 'texts': [{'text': 'eardrop'}, {'text': 'I'}]},
    {UUID_COLUMN: '2', 'texts': [{'text': 'hey'}, {'text': 'CARS'}]},
    {UUID_COLUMN: '3', 'texts': [{'text': 'everyone'}, {'text': ''}]},
  ])
  # Equivalent to: SELECT `TestPrimitiveSignal(texts.*.text) AS udf`.
  texts_udf = Column('texts.*.text', signal_udf=TestPrimitiveSignal(), alias='udf')

  def entry(text: str, value: int) -> Item:
    # One element of `texts` with the primitive signal value nested under the source text.
    return {'text': enriched_item(text, {'primitive_signal': value})}

  expected = [
    {UUID_COLUMN: '3', 'texts': [entry('everyone', 9), entry('', 1)]},
    {UUID_COLUMN: '1', 'texts': [entry('eardrop', 8), entry('I', 2)]},
    {UUID_COLUMN: '2', 'texts': [entry('hey', 4), entry('CARS', 5)]},
  ]

  # Sort by `udf`, where `udf` is an alias to `TestPrimitiveSignal(texts.*.text)`. Row 3 holds both
  # the smallest (1) and the largest (9) value; the expected order is the same for ASC and DESC.
  for order in (SortOrder.ASC, SortOrder.DESC):
    result = dataset.select_rows(['*', texts_udf],
                                 sort_by=['udf'],
                                 sort_order=order,
                                 combine_columns=True)
    assert list(result) == expected
class TopKEmbedding(TextEmbeddingSignal):
  """A test embedding that turns strings like `'8_1'` into one 1-d vector per token.

  The text is split on `_`; every token must parse as an int, and it becomes the single component
  of the vector attached to the span covering that token in the original text.
  """
  name = 'topk_embedding'

  @override
  def compute(self, data: Iterable[RichData]) -> Iterable[Item]:
    """Call the embedding function, yielding a list of embedding spans per input text."""
    for example in data:
      text = cast(str, example)
      emb_spans: list[Item] = []
      offset = 0
      for token in text.split('_'):
        # Span boundaries follow the token's real position, so multi-character tokens get
        # correct spans too (single-digit tokens give (0, 1), (2, 3), ... as before).
        start, end = offset, offset + len(token)
        emb_spans.append(lilac_embedding(start, end, np.array([int(token)])))
        # Skip past this token and the '_' separator that follows it.
        offset = end + 1
      yield emb_spans
class TopKSignal(TextEmbeddingModelSignal):
  """Test signal that scores each embedded span by its dot product with a fixed query vector."""
  name = 'topk_signal'
  # Fixed query vector `[1]`; the methods below only read it.
  _query = np.array([1])

  @override
  def fields(self) -> Field:
    """Declare a float32 score output."""
    return field('float32')

  @override
  def vector_compute(self, keys: Iterable[VectorKey],
                     vector_store: VectorStore) -> Iterable[Optional[Item]]:
    """Return the dot product of each key's embedding with the query, as a flat Python list."""
    # NOTE(review): presumably `vector_store.get` returns a (num_keys, dim) array, so the dot with
    # the 1-d query yields one score per key; `reshape(-1)` flattens the result either way.
    text_embeddings = vector_store.get(keys)
    dot_products = text_embeddings.dot(self._query).reshape(-1)
    return dot_products.tolist()

  @override
  def vector_compute_topk(
      self,
      topk: int,
      vector_store: VectorStore,
      keys: Optional[Iterable[VectorKey]] = None) -> Sequence[tuple[VectorKey, Optional[Item]]]:
    """Delegate top-k retrieval for the fixed query to the vector store."""
    return vector_store.topk(self._query, topk, keys)
def test_sort_by_topk_embedding_udf(make_test_data: TestDataMaker) -> None:
  """Sort descending by a top-k embedding UDF with two different `limit` values."""
  dataset = make_test_data([
    {UUID_COLUMN: '1', 'scores': '8_1'},
    {UUID_COLUMN: '2', 'scores': '3_5'},
    {UUID_COLUMN: '3', 'scores': '9_7'},
  ])
  dataset.compute_signal(TopKEmbedding(), 'scores')

  # Equivalent to: SELECT `TopKSignal(scores, embedding='...') AS udf`.
  text_udf = Column('scores', signal_udf=TopKSignal(embedding='topk_embedding'), alias='udf')

  def expected_row(uuid: str, scores: str, udf: list[float]) -> Item:
    spans = [lilac_embedding(0, 1, None), lilac_embedding(2, 3, None)]
    return {
      UUID_COLUMN: uuid,
      'scores': enriched_item(scores, {'topk_embedding': spans}),
      'udf': udf,
    }

  row_1 = expected_row('1', '8_1', [8.0, 1.0])
  row_2 = expected_row('2', '3_5', [3.0, 5.0])
  row_3 = expected_row('3', '9_7', [9.0, 7.0])

  # Sort by `udf`, where `udf` is an alias to `TopKSignal(scores, embedding='...')`.
  # With limit=3 only rows 3 and 1 come back; with limit=4 row 2 is included as well.
  for limit, expected in ((3, [row_3, row_1]), (4, [row_3, row_1, row_2])):
    result = dataset.select_rows(['*', text_udf],
                                 sort_by=['udf'],
                                 sort_order=SortOrder.DESC,
                                 limit=limit)
    assert list(result) == expected
def test_sort_by_topk_udf_with_filter(make_test_data: TestDataMaker) -> None:
  """A filter is honoured when sorting by a top-k embedding UDF."""
  dataset = make_test_data([
    {UUID_COLUMN: '1', 'scores': '8_1', 'active': True},
    {UUID_COLUMN: '2', 'scores': '3_5', 'active': True},
    {UUID_COLUMN: '3', 'scores': '9_7', 'active': False},
  ])
  dataset.compute_signal(TopKEmbedding(), 'scores')

  # Equivalent to: SELECT `TopKSignal(scores, embedding='...') AS udf`.
  text_udf = Column('scores', signal_udf=TopKSignal(embedding='topk_embedding'), alias='udf')

  def expected_row(uuid: str, scores: str, udf: list[float]) -> Item:
    spans = [lilac_embedding(0, 1, None), lilac_embedding(2, 3, None)]
    return {
      UUID_COLUMN: uuid,
      'active': True,
      'scores': enriched_item(scores, {'topk_embedding': spans}),
      'udf': udf,
    }

  # Sort by `udf`, where `udf` is an alias to `TopKSignal(scores, embedding='...')`.
  result = dataset.select_rows(['*', text_udf],
                               sort_by=['udf'],
                               filters=[('active', BinaryOp.EQUALS, True)],
                               sort_order=SortOrder.DESC,
                               limit=2)

  # Row '3' has the highest topk score but must be absent because it is not active.
  assert list(result) == [
    expected_row('1', '8_1', [8.0, 1.0]),
    expected_row('2', '3_5', [3.0, 5.0]),
  ]