Spaces:
Runtime error
Runtime error
File size: 4,759 Bytes
e4f9cbe |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
"""Test for the concept label signal."""
import pathlib
from typing import Generator, Type
import pytest
from pytest_mock import MockerFixture
from ..concepts.concept import ExampleIn
from ..concepts.db_concept import ConceptDB, ConceptUpdate, DiskConceptDB, DiskConceptModelDB
from ..config import CONFIG
from ..data.dataset_duckdb import DatasetDuckDB
from ..data.dataset_utils import lilac_span
from ..db_manager import set_default_dataset_cls
from ..schema import SignalInputType
from .concept_labels import ConceptLabelsSignal
from .signal import clear_signal_registry
# Concept DB implementations that the parametrized tests in this file are run against.
ALL_CONCEPT_DBS = [DiskConceptDB]
# Concept model DB implementations (not referenced by the tests visible in this file).
ALL_CONCEPT_MODEL_DBS = [DiskConceptModelDB]
@pytest.fixture(autouse=True)
def set_data_path(tmp_path: pathlib.Path, mocker: MockerFixture) -> None:
  """Patch CONFIG so that 'LILAC_DATA_PATH' is the test's temporary directory.

  Declared with autouse=True, so it is applied without being requested explicitly.
  """
  overrides = {'LILAC_DATA_PATH': str(tmp_path)}
  mocker.patch.dict(CONFIG, overrides)
@pytest.fixture(scope='module', autouse=True)
def setup_teardown() -> Generator:
  """Module-scoped setup/teardown around the tests in this file.

  Before the tests: DatasetDuckDB is registered as the default dataset class.
  After the tests: the signal registry is cleared.
  """
  # --- setup ---
  set_default_dataset_cls(DatasetDuckDB)

  # --- tests execute while the fixture is suspended here ---
  yield

  # --- teardown ---
  clear_signal_registry()
def test_concept_does_not_exist() -> None:
  """Computing labels for a concept that was never created must raise ValueError."""
  labels_signal = ConceptLabelsSignal(namespace='test', concept_name='concept_doesnt_exist')
  texts = ['a new data point', 'not in concept']
  expected_message = 'Concept "test/concept_doesnt_exist" does not exist'

  with pytest.raises(ValueError, match=expected_message):
    # list() forces full evaluation of the result in case compute() is lazy.
    list(labels_signal.compute(texts))
@pytest.mark.parametrize('concept_db_cls', ALL_CONCEPT_DBS)
def test_concept_labels(concept_db_cls: Type[ConceptDB]) -> None:
  """Check the label spans produced for examples inserted without a draft.

  Two examples are inserted without a draft and one under 'test_draft'. The signal is created
  without a draft argument, and the expected output contains spans only for the first two.
  """
  namespace, concept_name = 'test', 'test_concept'
  negative_text = 'no in concept'
  positive_text = 'yes in concept'

  db = concept_db_cls()
  db.create(namespace=namespace, name=concept_name, type=SignalInputType.TEXT)
  examples = [
    ExampleIn(label=False, text=negative_text),
    ExampleIn(label=True, text=positive_text),
    # This should never show since we request the main draft.
    ExampleIn(label=False, text='this is unrelated', draft='test_draft'),
  ]
  db.edit(namespace, concept_name, ConceptUpdate(insert=examples))

  labels_signal = ConceptLabelsSignal(namespace=namespace, concept_name=concept_name)
  texts = [
    'this is no in concept',
    'this is yes in concept',
    'this is no in concept. filler. this is yes in concept.',
    'this is unrelated',
  ]
  results = list(labels_signal.compute(texts))

  # In the first two inputs the example text follows the 8-character prefix 'this is '.
  # In the third input the second match ('yes in concept') begins at offset 39.
  expected = [
    [lilac_span(8, 8 + len(negative_text), {'label': False})],
    [lilac_span(8, 8 + len(positive_text), {'label': True})],
    [
      lilac_span(8, 8 + len(negative_text), {'label': False}),
      lilac_span(39, 39 + len(positive_text), {'label': True}),
    ],
    # This example is in the draft, which was not requested.
    None,
  ]
  assert results == expected
@pytest.mark.parametrize('concept_db_cls', ALL_CONCEPT_DBS)
def test_concept_labels_draft(concept_db_cls: Type[ConceptDB]) -> None:
  """Check the label spans produced when the signal is created with draft='test_draft'.

  The expected output covers an example inserted without a draft (its span metadata has no
  'draft' key) and both examples inserted under 'test_draft' (tagged 'draft': 'test_draft').
  """
  namespace, concept_name, draft = 'test', 'test_concept', 'test_draft'

  db = concept_db_cls()
  db.create(namespace=namespace, name=concept_name, type=SignalInputType.TEXT)
  examples = [
    ExampleIn(label=True, text='in concept'),
    ExampleIn(label=False, text='out of concept'),
    ExampleIn(label=True, text='in draft', draft=draft),
    ExampleIn(label=False, text='out draft', draft=draft),
  ]
  db.edit(namespace, concept_name, ConceptUpdate(insert=examples))

  labels_signal = ConceptLabelsSignal(namespace=namespace, concept_name=concept_name, draft=draft)
  texts = ['this is in concept', 'this is in draft', 'this is out draft']
  results = list(labels_signal.compute(texts))

  # Every expected match follows the 8-character prefix 'this is '.
  expected = [
    [lilac_span(8, 8 + len('in concept'), {'label': True})],
    [lilac_span(8, 8 + len('in draft'), {'label': True, 'draft': 'test_draft'})],
    [lilac_span(8, 8 + len('out draft'), {'label': False, 'draft': 'test_draft'})],
  ]
  assert results == expected
@pytest.mark.parametrize('concept_db_cls', ALL_CONCEPT_DBS)
def test_concept_labels_key(concept_db_cls: Type[ConceptDB]) -> None:
  """The signal's default key is '<namespace>/<concept_name>/labels'."""
  namespace, concept_name = 'test', 'test_concept'
  db = concept_db_cls()
  db.create(namespace=namespace, name=concept_name, type=SignalInputType.TEXT)

  labels_signal = ConceptLabelsSignal(namespace=namespace, concept_name=concept_name)
  key = labels_signal.key()

  assert key == 'test/test_concept/labels'
@pytest.mark.parametrize('concept_db_cls', ALL_CONCEPT_DBS)
def test_concept_labels_compute_signal_key(concept_db_cls: Type[ConceptDB]) -> None:
  """With is_computed_signal=True the key gains a '/v0' suffix for a freshly created concept."""
  namespace, concept_name = 'test', 'test_concept'
  db = concept_db_cls()
  db.create(namespace=namespace, name=concept_name, type=SignalInputType.TEXT)

  labels_signal = ConceptLabelsSignal(namespace=namespace, concept_name=concept_name)
  computed_key = labels_signal.key(is_computed_signal=True)

  assert computed_key == 'test/test_concept/labels/v0'
|