Spaces:
Runtime error
Runtime error
"""Test for the concept label signal.""" | |
import pathlib | |
from typing import Generator, Type | |
import pytest | |
from pytest_mock import MockerFixture | |
from ..concepts.concept import ExampleIn | |
from ..concepts.db_concept import ConceptDB, ConceptUpdate, DiskConceptDB, DiskConceptModelDB | |
from ..config import CONFIG | |
from ..data.dataset_duckdb import DatasetDuckDB | |
from ..data.dataset_utils import lilac_span | |
from ..db_manager import set_default_dataset_cls | |
from ..schema import SignalInputType | |
from .concept_labels import ConceptLabelsSignal | |
from .signal import clear_signal_registry | |
# Concept database implementations that the tests in this module are run against.
ALL_CONCEPT_DBS = [DiskConceptDB]
# Concept-model database implementations (kept alongside for parity with ALL_CONCEPT_DBS).
ALL_CONCEPT_MODEL_DBS = [DiskConceptModelDB]
@pytest.fixture(autouse=True)
def set_data_path(tmp_path: pathlib.Path, mocker: MockerFixture) -> None:
    """Redirect `LILAC_DATA_PATH` to a per-test temporary directory.

    Registered as an autouse fixture so that every test in this module runs with
    `CONFIG['LILAC_DATA_PATH']` patched to pytest's `tmp_path`. Without the fixture
    decorator this function is never invoked and the patch never takes effect.

    Args:
      tmp_path: Pytest-provided temporary directory, unique per test.
      mocker: The pytest-mock fixture; its `patch.dict` is reverted when the test ends.
    """
    mocker.patch.dict(CONFIG, {'LILAC_DATA_PATH': str(tmp_path)})
@pytest.fixture(scope='module', autouse=True)
def setup_teardown() -> Generator:
    """Install the default dataset class before the tests and clear signals afterwards.

    This is a setup/yield/teardown generator, so it must be registered as a pytest
    fixture to run at all. Module scope is used so the signal registry is cleared
    once, after the last test in this file, rather than between individual tests.
    """
    # Setup.
    set_default_dataset_cls(DatasetDuckDB)

    # Unit test runs.
    yield

    # Teardown.
    clear_signal_registry()
def test_concept_does_not_exist() -> None:
    """Computing labels for a concept that was never created raises a ValueError."""
    missing_concept_signal = ConceptLabelsSignal(
        namespace='test', concept_name='concept_doesnt_exist')
    inputs = ['a new data point', 'not in concept']
    expected_error = 'Concept "test/concept_doesnt_exist" does not exist'

    with pytest.raises(ValueError, match=expected_error):
        # Materialize the result so the error surfaces even if compute() is lazy.
        list(missing_concept_signal.compute(inputs))
@pytest.mark.parametrize('concept_db_cls', ALL_CONCEPT_DBS)
def test_concept_labels(concept_db_cls: Type[ConceptDB]) -> None:
    """Labeled concept examples found inside the input text are emitted as spans.

    Parametrized over `ALL_CONCEPT_DBS`; without the parametrization pytest has no
    way to supply `concept_db_cls` and the test errors before it runs.

    Args:
      concept_db_cls: The concept database implementation under test.
    """
    concept_db = concept_db_cls()
    namespace = 'test'
    concept_name = 'test_concept'
    concept_db.create(namespace=namespace, name=concept_name, type=SignalInputType.TEXT)

    train_data = [
        ExampleIn(label=False, text='no in concept'),
        ExampleIn(label=True, text='yes in concept'),
        # This should never show since we request the main draft.
        ExampleIn(label=False, text='this is unrelated', draft='test_draft'),
    ]
    concept_db.edit(namespace, concept_name, ConceptUpdate(insert=train_data))

    signal = ConceptLabelsSignal(namespace='test', concept_name='test_concept')
    results = list(
        signal.compute([
            'this is no in concept',
            'this is yes in concept',
            'this is no in concept. filler. this is yes in concept.',
            'this is unrelated',
        ]))

    # Each example text starts after the 8-character prefix 'this is '; in the third
    # input the second example starts at offset 39.
    assert results == [
        [lilac_span(8, 8 + len('no in concept'), {'label': False})],
        [lilac_span(8, 8 + len('yes in concept'), {'label': True})],
        [
            lilac_span(8, 8 + len('no in concept'), {'label': False}),
            lilac_span(39, 39 + len('yes in concept'), {'label': True}),
        ],
        # This example is in the draft, which was not requested.
        None,
    ]
@pytest.mark.parametrize('concept_db_cls', ALL_CONCEPT_DBS)
def test_concept_labels_draft(concept_db_cls: Type[ConceptDB]) -> None:
    """Requesting a draft yields spans for main examples and for that draft's examples.

    Parametrized over `ALL_CONCEPT_DBS`; without the parametrization pytest has no
    way to supply `concept_db_cls` and the test errors before it runs.

    Args:
      concept_db_cls: The concept database implementation under test.
    """
    concept_db = concept_db_cls()
    namespace = 'test'
    concept_name = 'test_concept'
    concept_db.create(namespace=namespace, name=concept_name, type=SignalInputType.TEXT)

    concept_update = ConceptUpdate(insert=[
        ExampleIn(label=True, text='in concept'),
        ExampleIn(label=False, text='out of concept'),
        ExampleIn(label=True, text='in draft', draft='test_draft'),
        ExampleIn(label=False, text='out draft', draft='test_draft'),
    ])
    concept_db.edit(namespace, concept_name, concept_update)

    signal = ConceptLabelsSignal(
        namespace='test', concept_name='test_concept', draft='test_draft')
    results = list(
        signal.compute(['this is in concept', 'this is in draft', 'this is out draft']))

    # Spans for draft examples carry a 'draft' key; the main-draft span does not.
    assert results == [
        [lilac_span(8, 8 + len('in concept'), {'label': True})],
        [lilac_span(8, 8 + len('in draft'), {
            'label': True,
            'draft': 'test_draft'
        })],
        [lilac_span(8, 8 + len('out draft'), {
            'label': False,
            'draft': 'test_draft'
        })],
    ]
@pytest.mark.parametrize('concept_db_cls', ALL_CONCEPT_DBS)
def test_concept_labels_key(concept_db_cls: Type[ConceptDB]) -> None:
    """The signal's default key is '<namespace>/<concept_name>/labels'.

    Parametrized over `ALL_CONCEPT_DBS`; without the parametrization pytest has no
    way to supply `concept_db_cls` and the test errors before it runs.

    Args:
      concept_db_cls: The concept database implementation under test.
    """
    concept_db = concept_db_cls()
    namespace = 'test'
    concept_name = 'test_concept'
    concept_db.create(namespace=namespace, name=concept_name, type=SignalInputType.TEXT)

    signal = ConceptLabelsSignal(namespace='test', concept_name='test_concept')
    assert signal.key() == 'test/test_concept/labels'
@pytest.mark.parametrize('concept_db_cls', ALL_CONCEPT_DBS)
def test_concept_labels_compute_signal_key(concept_db_cls: Type[ConceptDB]) -> None:
    """With `is_computed_signal=True` the key gains a version suffix ('/v0' here).

    Parametrized over `ALL_CONCEPT_DBS`; without the parametrization pytest has no
    way to supply `concept_db_cls` and the test errors before it runs.

    Args:
      concept_db_cls: The concept database implementation under test.
    """
    concept_db = concept_db_cls()
    namespace = 'test'
    concept_name = 'test_concept'
    concept_db.create(namespace=namespace, name=concept_name, type=SignalInputType.TEXT)

    signal = ConceptLabelsSignal(namespace='test', concept_name='test_concept')
    assert signal.key(is_computed_signal=True) == 'test/test_concept/labels/v0'