# nikhil_staging/src/data/dataset_select_rows_sort_test.py
# Hosting-page residue (not part of the module): nsthorat's picture | Push | e4f9cbe | raw |
# history blame | No virus | 24.1 kB
"""Tests for dataset.select_rows(sort_by=...)."""
from typing import Iterable, Optional, Sequence, cast
import numpy as np
import pytest
from typing_extensions import override
from ..embeddings.vector_store import VectorStore
from ..schema import UUID_COLUMN, Field, Item, RichData, VectorKey, field
from ..signals.signal import (
TextEmbeddingModelSignal,
TextEmbeddingSignal,
TextSignal,
clear_signal_registry,
register_signal,
)
from .dataset import BinaryOp, Column, SortOrder
from .dataset_test_utils import TestDataMaker, enriched_item
from .dataset_utils import lilac_embedding
class TestSignal(TextSignal):
  """Test-only text signal that reports the length and upper-case status of each input."""
  name = 'test_signal'

  def fields(self) -> Field:
    """Describe the output: an `int32` field `len` and a `boolean` field `is_all_cap`."""
    return field(fields={'len': 'int32', 'is_all_cap': 'boolean'})

  def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]:
    """Lazily yield one `{'len', 'is_all_cap'}` dict per input, in input order."""
    for text in data:
      length = len(text)
      all_caps = text.isupper()
      yield {'len': length, 'is_all_cap': all_caps}
class TestPrimitiveSignal(TextSignal):
  """Test-only text signal whose output is a single primitive: the input length plus one."""
  name = 'primitive_signal'

  def fields(self) -> Field:
    """Describe the output as a bare `int32` value (no nested fields)."""
    return field('int32')

  def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]:
    """Lazily yield `len(text) + 1` for every input, in input order."""
    for text in data:
      length = len(text)
      yield length + 1
class NestedArraySignal(TextSignal):
  """Test-only text signal that emits a doubly nested list of integers per input."""
  name = 'nested_array'

  def fields(self) -> Field:
    """Describe the output as a list of lists of `int32`."""
    return field(fields=[['int32']])

  def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]:
    """Lazily yield `[[len + 1], [len]]` for every input, in input order."""
    for text in data:
      size = len(text)
      yield [[size + 1], [size]]
@pytest.fixture(scope='module', autouse=True)
def setup_teardown() -> Iterable[None]:
  """Register the signals this module defines for its tests and clear the registry afterwards."""
  # `TopKEmbedding` is defined further down the module; the name is only looked up when the
  # fixture actually runs, by which point the whole module has been imported.
  signals_under_test = (TestSignal, TestPrimitiveSignal, NestedArraySignal, TopKEmbedding)
  for signal_cls in signals_under_test:
    register_signal(signal_cls)

  # Hand control to the tests of this module.
  yield

  # Teardown.
  clear_signal_registry()
def test_sort_by_source_no_alias_no_repeated(make_test_data: TestDataMaker) -> None:
  """Sort by non-repeated source leaves of several types, in both orders."""
  source_rows = [
    {
      UUID_COLUMN: '1',
      'erased': True,
      'score': 4.1,
      'document': {'num_pages': 4, 'header': {'title': 'c'}},
    },
    {
      UUID_COLUMN: '2',
      'erased': False,
      'score': 3.5,
      'document': {'num_pages': 5, 'header': {'title': 'b'}},
    },
    {
      UUID_COLUMN: '3',
      'erased': True,
      'score': 3.7,
      'document': {'num_pages': 3, 'header': {'title': 'a'}},
    },
  ]
  dataset = make_test_data(source_rows)

  # Each case: (path to sort by, sort order, UUIDs in the expected output order).
  cases = [
    # Sort by bool.
    ('erased', SortOrder.ASC, ['2', '1', '3']),
    ('erased', SortOrder.DESC, ['1', '3', '2']),
    # Sort by float.
    ('score', SortOrder.ASC, ['2', '3', '1']),
    ('score', SortOrder.DESC, ['1', '3', '2']),
    # Sort by nested int.
    ('document.num_pages', SortOrder.ASC, ['3', '1', '2']),
    ('document.num_pages', SortOrder.DESC, ['2', '1', '3']),
    # Sort by double nested string.
    ('document.header.title', SortOrder.ASC, ['3', '2', '1']),
    ('document.header.title', SortOrder.DESC, ['1', '2', '3']),
  ]
  for sort_path, order, expected_uuids in cases:
    result = dataset.select_rows(columns=[UUID_COLUMN], sort_by=[sort_path], sort_order=order)
    assert list(result) == [{UUID_COLUMN: uuid} for uuid in expected_uuids]
def test_sort_by_signal_no_alias_no_repeated(make_test_data: TestDataMaker) -> None:
  """Sort by leaves of a precomputed signal, addressed by their full path, in both orders."""
  texts_by_uuid = (('1', 'HEY'), ('2', 'everyone'), ('3', 'HI'))
  dataset = make_test_data([{UUID_COLUMN: uuid, 'text': text} for uuid, text in texts_by_uuid])
  dataset.compute_signal(TestSignal(), 'text')

  # Each case: (path to sort by, sort order, UUIDs in the expected output order).
  cases = [
    # Sort by `signal.len`.
    ('text.test_signal.len', SortOrder.ASC, ['3', '1', '2']),
    ('text.test_signal.len', SortOrder.DESC, ['2', '1', '3']),
    # Sort by `signal.is_all_cap`.
    ('text.test_signal.is_all_cap', SortOrder.ASC, ['2', '1', '3']),
    ('text.test_signal.is_all_cap', SortOrder.DESC, ['1', '3', '2']),
  ]
  for sort_path, order, expected_uuids in cases:
    result = dataset.select_rows(columns=[UUID_COLUMN], sort_by=[sort_path], sort_order=order)
    assert list(result) == [{UUID_COLUMN: uuid} for uuid in expected_uuids]
def test_sort_by_signal_alias_no_repeated(make_test_data: TestDataMaker) -> None:
  """Sort by `signal.len`, where `signal` is an alias of the computed `text.test_signal`."""

  def expected_row(uuid: str, length: int, all_caps: bool) -> Item:
    # Shape of a returned row: the UUID plus the aliased signal output.
    return {UUID_COLUMN: uuid, 'signal': {'len': length, 'is_all_cap': all_caps}}

  texts_by_uuid = (('1', 'HEY'), ('2', 'everyone'), ('3', 'HI'))
  dataset = make_test_data([{UUID_COLUMN: uuid, 'text': text} for uuid, text in texts_by_uuid])
  dataset.compute_signal(TestSignal(), 'text')

  signal_alias = Column('text.test_signal', alias='signal')
  shortest_first = [
    expected_row('3', 2, True),
    expected_row('1', 3, True),
    expected_row('2', 8, False),
  ]

  # Ascending by length.
  result = dataset.select_rows(
    columns=[signal_alias], sort_by=['signal.len'], sort_order=SortOrder.ASC)
  assert list(result) == shortest_first

  # Descending by length is the exact reverse, because all three lengths are distinct.
  result = dataset.select_rows(
    columns=[signal_alias], sort_by=['signal.len'], sort_order=SortOrder.DESC)
  assert list(result) == list(reversed(shortest_first))
def test_sort_by_enriched_alias_no_repeated(make_test_data: TestDataMaker) -> None:
  """Sort by a signal leaf reached through `document`, an alias of the enriched `text`."""

  def expected_row(uuid: str, text: str, length: int, all_caps: bool) -> Item:
    # Shape of a returned row: the UUID plus the aliased text enriched with the signal output.
    signal_output = {'test_signal': {'len': length, 'is_all_cap': all_caps}}
    return {UUID_COLUMN: uuid, 'document': enriched_item(text, signal_output)}

  texts_by_uuid = (('1', 'HEY'), ('2', 'everyone'), ('3', 'HI'))
  dataset = make_test_data([{UUID_COLUMN: uuid, 'text': text} for uuid, text in texts_by_uuid])
  dataset.compute_signal(TestSignal(), 'text')

  # Sort by `document.test_signal.is_all_cap` where 'document' is an alias to 'text'.
  text_alias = Column('text', alias='document')
  sort_path = 'document.test_signal.is_all_cap'

  # Ascending: the only non-capitalized text comes first.
  result = dataset.select_rows(columns=[text_alias], sort_by=[sort_path], sort_order=SortOrder.ASC)
  assert list(result) == [
    expected_row('2', 'everyone', 8, False),
    expected_row('1', 'HEY', 3, True),
    expected_row('3', 'HI', 2, True),
  ]

  # Descending: the two capitalized texts come first.
  result = dataset.select_rows(
    columns=[text_alias], sort_by=[sort_path], sort_order=SortOrder.DESC)
  assert list(result) == [
    expected_row('1', 'HEY', 3, True),
    expected_row('3', 'HI', 2, True),
    expected_row('2', 'everyone', 8, False),
  ]
def test_sort_by_udf_alias_no_repeated(make_test_data: TestDataMaker) -> None:
  """Sort ascending by `udf.len`, where `udf` aliases `TestSignal` applied to `text` as a UDF."""

  def expected_row(uuid: str, text: str, length: int, all_caps: bool) -> Item:
    # Shape of a returned row: the source columns plus the aliased UDF output.
    return {UUID_COLUMN: uuid, 'text': text, 'udf': {'len': length, 'is_all_cap': all_caps}}

  texts_by_uuid = (('1', 'HEY'), ('2', 'everyone'), ('3', 'HI'))
  dataset = make_test_data([{UUID_COLUMN: uuid, 'text': text} for uuid, text in texts_by_uuid])

  # Equivalent to: SELECT `TestSignal(text) AS udf`.
  text_udf = Column('text', signal_udf=TestSignal(), alias='udf')

  # The signal is never precomputed here; it is supplied only through the UDF column.
  result = dataset.select_rows(['*', text_udf], sort_by=['udf.len'], sort_order=SortOrder.ASC)
  assert list(result) == [
    expected_row('3', 'HI', 2, True),
    expected_row('1', 'HEY', 3, True),
    expected_row('2', 'everyone', 8, False),
  ]
def test_sort_by_udf_no_alias_no_repeated(make_test_data: TestDataMaker) -> None:
  """Sort by `text.test_signal.len` produced by an un-aliased UDF, with combined columns."""

  def expected_row(uuid: str, text: str, length: int, all_caps: bool) -> Item:
    # With `combine_columns=True` the UDF output is expected merged under the `text` column.
    signal_output = {'test_signal': {'len': length, 'is_all_cap': all_caps}}
    return {UUID_COLUMN: uuid, 'text': enriched_item(text, signal_output)}

  texts_by_uuid = (('1', 'HEY'), ('2', 'everyone'), ('3', 'HI'))
  dataset = make_test_data([{UUID_COLUMN: uuid, 'text': text} for uuid, text in texts_by_uuid])

  text_udf = Column('text', signal_udf=TestSignal())
  # Path of the value produced by executing the udf `TestSignal(text)`.
  sort_path = ('text', 'test_signal', 'len')
  shortest_first = [
    expected_row('3', 'HI', 2, True),
    expected_row('1', 'HEY', 3, True),
    expected_row('2', 'everyone', 8, False),
  ]

  # Sort ascending.
  result = dataset.select_rows(['*', text_udf],
                               sort_by=[sort_path],
                               sort_order=SortOrder.ASC,
                               combine_columns=True)
  assert list(result) == shortest_first

  # Sort descending: the exact reverse, because all three lengths are distinct.
  result = dataset.select_rows(['*', text_udf],
                               sort_by=[sort_path],
                               sort_order=SortOrder.DESC,
                               combine_columns=True)
  assert list(result) == list(reversed(shortest_first))
def test_sort_by_primitive_udf_alias_no_repeated(make_test_data: TestDataMaker) -> None:
  """Sort ascending by `udf`, an alias of a UDF that returns a bare integer per row."""
  texts_by_uuid = (('1', 'HEY'), ('2', 'everyone'), ('3', 'HI'))
  dataset = make_test_data([{UUID_COLUMN: uuid, 'text': text} for uuid, text in texts_by_uuid])

  # Equivalent to: SELECT `TestPrimitiveSignal(text) AS udf`.
  text_udf = Column('text', signal_udf=TestPrimitiveSignal(), alias='udf')

  # The alias itself is the sort path, since the udf output is a primitive (text length + 1).
  result = dataset.select_rows(['*', text_udf], sort_by=['udf'], sort_order=SortOrder.ASC)

  expected_in_order = (('3', 'HI', 3), ('1', 'HEY', 4), ('2', 'everyone', 9))
  assert list(result) == [{
    UUID_COLUMN: uuid,
    'text': text,
    'udf': value
  } for uuid, text, value in expected_in_order]
def test_sort_by_source_non_leaf_errors(make_test_data: TestDataMaker) -> None:
  """Sorting by the repeated field `vals` itself (not `vals.*`) must raise a `ValueError`."""
  vals_by_uuid = (('1', [7, 1]), ('2', [3, 4]), ('3', [9, 0]))
  dataset = make_test_data([{UUID_COLUMN: uuid, 'vals': vals} for uuid, vals in vals_by_uuid])

  # Only the message prefix is pinned; the rest of the error text is free to change.
  with pytest.raises(ValueError, match='Unable to sort by path'):
    dataset.select_rows(columns=[UUID_COLUMN], sort_by=['vals'], sort_order=SortOrder.ASC)
def test_sort_by_source_no_alias_repeated(make_test_data: TestDataMaker) -> None:
  """Sort by `vals.*.*.score`, a leaf nested under two levels of repetition."""

  def make_row(uuid: str, *groups: list[int]) -> Item:
    # Build a fresh row every time so input and expected rows never share objects.
    return {UUID_COLUMN: uuid, 'vals': [[{'score': score} for score in group] for group in groups]}

  dataset = make_test_data([
    make_row('1', [7, 1], [1, 7]),
    make_row('2', [3, 4]),
    make_row('3', [9, 0]),
  ])

  # Both orders expect the same sequence. This is consistent with ASC ranking a row by its
  # smallest score (0 < 1 < 3) and DESC by its largest (9 > 7 > 4) -- NOTE(review): confirm
  # against the `select_rows` implementation.
  expected = [
    make_row('3', [9, 0]),
    make_row('1', [7, 1], [1, 7]),
    make_row('2', [3, 4]),
  ]
  for order in (SortOrder.ASC, SortOrder.DESC):
    result = dataset.select_rows(
      columns=[UUID_COLUMN, 'vals'], sort_by=['vals.*.*.score'], sort_order=order)
    assert list(result) == expected
def test_sort_by_source_alias_repeated(make_test_data: TestDataMaker) -> None:
  """Sort by `scores.*.*`, where `scores` is an alias of the doubly repeated `vals`."""
  vals_by_uuid = {
    '1': [[7, 1], [1, 7]],
    '2': [[3], [11]],
    '3': [[9, 0]],
  }
  dataset = make_test_data([{UUID_COLUMN: uuid, 'vals': vals} for uuid, vals in vals_by_uuid.items()])

  # Expected UUID order per sort order. Consistent with ASC ranking rows by their smallest
  # value (0 < 1 < 3) and DESC by their largest (11 > 9 > 7) -- NOTE(review): confirm against
  # the `select_rows` implementation.
  cases = [
    (SortOrder.ASC, ['3', '1', '2']),
    (SortOrder.DESC, ['2', '3', '1']),
  ]
  for order, expected_uuids in cases:
    result = dataset.select_rows(
      columns=[UUID_COLUMN, Column('vals', alias='scores')],
      sort_by=['scores.*.*'],
      sort_order=order)
    assert list(result) == [{
      UUID_COLUMN: uuid,
      'scores': vals_by_uuid[uuid]
    } for uuid in expected_uuids]
def test_sort_by_udf_alias_repeated(make_test_data: TestDataMaker) -> None:
  """Sort by `udf.*.*`, where `udf` aliases a UDF that returns a nested list per row."""

  def expected_row(uuid: str, text: str) -> Item:
    # `NestedArraySignal` yields `[[len + 1], [len]]` for each text.
    size = len(text)
    return {UUID_COLUMN: uuid, 'text': text, 'udf': [[size + 1], [size]]}

  texts_by_uuid = (('1', 'HEY'), ('2', 'everyone'), ('3', 'HI'))
  dataset = make_test_data([{UUID_COLUMN: uuid, 'text': text} for uuid, text in texts_by_uuid])

  # Equivalent to: SELECT `NestedArraySignal(text) AS udf`.
  text_udf = Column('text', signal_udf=NestedArraySignal(), alias='udf')
  shortest_first = [expected_row('3', 'HI'), expected_row('1', 'HEY'), expected_row('2', 'everyone')]

  # Ascending.
  result = dataset.select_rows(['*', text_udf], sort_by=['udf.*.*'], sort_order=SortOrder.ASC)
  assert list(result) == shortest_first

  # Descending is the exact reverse for this data.
  result = dataset.select_rows(['*', text_udf], sort_by=['udf.*.*'], sort_order=SortOrder.DESC)
  assert list(result) == list(reversed(shortest_first))
def test_sort_by_complex_signal_udf_alias_called_on_repeated(make_test_data: TestDataMaker) -> None:
  """Sort ascending by `udf.len`, where `udf` aliases `TestSignal` run on `texts.*.text`."""

  def enriched_text(text: str, length: int, all_caps: bool) -> Item:
    # One element of `texts` after the UDF output has been combined into it.
    signal_output = {'test_signal': {'len': length, 'is_all_cap': all_caps}}
    return {'text': enriched_item(text, signal_output)}

  texts_by_uuid = (('1', ['eardrop', 'I']), ('2', ['hey', 'CARS']), ('3', ['everyone', '']))
  dataset = make_test_data([{
    UUID_COLUMN: uuid,
    'texts': [{'text': text} for text in texts]
  } for uuid, texts in texts_by_uuid])

  # Equivalent to: SELECT `TestSignal(texts.*.text) AS udf`.
  texts_udf = Column('texts.*.text', signal_udf=TestSignal(), alias='udf')

  result = dataset.select_rows(['*', texts_udf],
                               sort_by=['udf.len'],
                               sort_order=SortOrder.ASC,
                               combine_columns=True)

  # The expected order is consistent with ranking each row by its shortest text (0 < 1 < 3).
  assert list(result) == [
    {
      UUID_COLUMN: '3',
      'texts': [enriched_text('everyone', 8, False), enriched_text('', 0, False)],
    },
    {
      UUID_COLUMN: '1',
      'texts': [enriched_text('eardrop', 7, False), enriched_text('I', 1, True)],
    },
    {
      UUID_COLUMN: '2',
      'texts': [enriched_text('hey', 3, False), enriched_text('CARS', 4, True)],
    },
  ]
def test_sort_by_primitive_signal_udf_alias_called_on_repeated(
    make_test_data: TestDataMaker) -> None:
  """Sort by `udf`, where `udf` aliases `TestPrimitiveSignal` run on `texts.*.text`."""

  def expected_row(uuid: str, *texts_and_values: tuple[str, int]) -> Item:
    # Each element of `texts` carries its text enriched with the primitive signal value.
    return {
      UUID_COLUMN: uuid,
      'texts': [{
        'text': enriched_item(text, {'primitive_signal': value})
      } for text, value in texts_and_values]
    }

  texts_by_uuid = (('1', ['eardrop', 'I']), ('2', ['hey', 'CARS']), ('3', ['everyone', '']))
  dataset = make_test_data([{
    UUID_COLUMN: uuid,
    'texts': [{'text': text} for text in texts]
  } for uuid, texts in texts_by_uuid])

  # Equivalent to: SELECT `TestPrimitiveSignal(texts.*.text) AS udf`.
  texts_udf = Column('texts.*.text', signal_udf=TestPrimitiveSignal(), alias='udf')

  # Both orders expect the same sequence. This is consistent with ASC ranking a row by its
  # smallest value (1 < 2 < 4) and DESC by its largest (9 > 8 > 5) -- NOTE(review): confirm
  # against the `select_rows` implementation.
  for order in (SortOrder.ASC, SortOrder.DESC):
    result = dataset.select_rows(['*', texts_udf],
                                 sort_by=['udf'],
                                 sort_order=order,
                                 combine_columns=True)
    assert list(result) == [
      expected_row('3', ('everyone', 9), ('', 1)),
      expected_row('1', ('eardrop', 8), ('I', 2)),
      expected_row('2', ('hey', 4), ('CARS', 5)),
    ]
class TopKEmbedding(TextEmbeddingSignal):
  """Test embedding that turns an underscore-separated string of integers into 1-d vectors."""
  name = 'topk_embedding'

  def compute(self, data: Iterable[RichData]) -> Iterable[Item]:
    """Yield, per input, one embedding span for each `_`-separated integer it contains.

    The i-th piece is given the span `(2 * i, 2 * i + 1)` and a one-element vector holding the
    piece parsed as an `int`.
    """
    for raw_example in data:
      pieces = cast(str, raw_example).split('_')
      spans: list[Item] = [
        lilac_embedding(2 * index, 2 * index + 1, np.array([int(piece)]))
        for index, piece in enumerate(pieces)
      ]
      yield spans
class TopKSignal(TextEmbeddingModelSignal):
  """Test signal scoring embeddings by their dot product with the fixed query vector `[1]`."""
  name = 'topk_signal'
  _query = np.array([1])

  def fields(self) -> Field:
    """Describe the output as a bare `float32` score."""
    return field('float32')

  @override
  def vector_compute(self, keys: Iterable[VectorKey],
                     vector_store: VectorStore) -> Iterable[Optional[Item]]:
    """Return the dot product of each stored embedding for `keys` with the query, as a list."""
    embeddings = vector_store.get(keys)
    scores = embeddings.dot(self._query)
    # Flatten to one dimension before converting to plain Python numbers.
    flat_scores = scores.reshape(-1)
    return flat_scores.tolist()

  @override
  def vector_compute_topk(
      self,
      topk: int,
      vector_store: VectorStore,
      keys: Optional[Iterable[VectorKey]] = None) -> Sequence[tuple[VectorKey, Optional[Item]]]:
    """Delegate the top-k search for the query vector to `vector_store.topk`."""
    top_results = vector_store.topk(self._query, topk, keys)
    return top_results
def test_sort_by_topk_embedding_udf(make_test_data: TestDataMaker) -> None:
  """Sort descending by a top-k embedding UDF and check how `limit` bounds the result."""

  def expected_row(uuid: str, scores: str, udf_values: list[float]) -> Item:
    # Every test string has two pieces, hence two embedding spans; vectors are expected as None.
    spans = [lilac_embedding(0, 1, None), lilac_embedding(2, 3, None)]
    return {
      UUID_COLUMN: uuid,
      'scores': enriched_item(scores, {'topk_embedding': spans}),
      'udf': udf_values
    }

  scores_by_uuid = (('1', '8_1'), ('2', '3_5'), ('3', '9_7'))
  dataset = make_test_data([{
    UUID_COLUMN: uuid,
    'scores': scores
  } for uuid, scores in scores_by_uuid])
  dataset.compute_signal(TopKEmbedding(), 'scores')

  # Equivalent to: SELECT `TopKSignal(scores, embedding='...') AS udf`.
  text_udf = Column('scores', signal_udf=TopKSignal(embedding='topk_embedding'), alias='udf')

  # With limit=3 only two rows are expected -- presumably because the top-k is taken over
  # embedding spans and two of the three best spans (9 and 7) belong to row '3'; verify against
  # the `select_rows` implementation.
  result = dataset.select_rows(['*', text_udf], sort_by=['udf'], sort_order=SortOrder.DESC, limit=3)
  assert list(result) == [
    expected_row('3', '9_7', [9.0, 7.0]),
    expected_row('1', '8_1', [8.0, 1.0]),
  ]

  # Same but set limit to 4: row '2' now makes it into the result as well.
  result = dataset.select_rows(['*', text_udf], sort_by=['udf'], sort_order=SortOrder.DESC, limit=4)
  assert list(result) == [
    expected_row('3', '9_7', [9.0, 7.0]),
    expected_row('1', '8_1', [8.0, 1.0]),
    expected_row('2', '3_5', [3.0, 5.0]),
  ]
def test_sort_by_topk_udf_with_filter(make_test_data: TestDataMaker) -> None:
  """Sort descending by a top-k embedding UDF while filtering on `active == True`."""

  def expected_row(uuid: str, scores: str, udf_values: list[float]) -> Item:
    # Only active rows are expected back; vectors inside the embedding spans are expected as None.
    spans = [lilac_embedding(0, 1, None), lilac_embedding(2, 3, None)]
    return {
      UUID_COLUMN: uuid,
      'active': True,
      'scores': enriched_item(scores, {'topk_embedding': spans}),
      'udf': udf_values
    }

  source_rows = (('1', '8_1', True), ('2', '3_5', True), ('3', '9_7', False))
  dataset = make_test_data([{
    UUID_COLUMN: uuid,
    'scores': scores,
    'active': active
  } for uuid, scores, active in source_rows])
  dataset.compute_signal(TopKEmbedding(), 'scores')

  # Equivalent to: SELECT `TopKSignal(scores, embedding='...') AS udf`.
  text_udf = Column('scores', signal_udf=TopKSignal(embedding='topk_embedding'), alias='udf')
  active_only = [('active', BinaryOp.EQUALS, True)]

  result = dataset.select_rows(['*', text_udf],
                               sort_by=['udf'],
                               filters=active_only,
                               sort_order=SortOrder.DESC,
                               limit=2)

  # We make sure that '3' is not in the result, because it is not active, even though it has the
  # highest topk score.
  assert list(result) == [
    expected_row('1', '8_1', [8.0, 1.0]),
    expected_row('2', '3_5', [3.0, 5.0]),
  ]