"""Test the public REST API for concepts.""" import uuid from pathlib import Path from typing import Iterable, cast import numpy as np import pytest from fastapi.testclient import TestClient from pydantic import parse_obj_as from pytest_mock import MockerFixture from typing_extensions import override from .concepts.concept import ( DRAFT_MAIN, Concept, Example, ExampleIn, ExampleOrigin, LogisticEmbeddingModel, ) from .concepts.db_concept import ConceptInfo, ConceptUpdate from .config import CONFIG from .data.dataset_utils import lilac_embedding from .router_concept import ( ConceptModelInfo, CreateConceptOptions, MergeConceptDraftOptions, ScoreBody, ScoreExample, ScoreResponse, ) from .schema import Item, RichData, SignalInputType from .server import app from .signals.signal import TextEmbeddingSignal, clear_signal_registry, register_signal from .test_utils import fake_uuid client = TestClient(app) EMBEDDINGS: list[tuple[str, list[float]]] = [('hello', [1.0, 0.0, 0.0]), ('hello2', [1.0, 1.0, 0.0]), ('hello world', [1.0, 1.0, 1.0]), ('hello world2', [2.0, 1.0, 1.0])] STR_EMBEDDINGS: dict[str, list[float]] = {text: embedding for text, embedding in EMBEDDINGS} class TestEmbedding(TextEmbeddingSignal): """A test embed function.""" name = 'test_embedding' @override def compute(self, data: Iterable[RichData]) -> Iterable[Item]: """Call the embedding function.""" for example in data: yield [lilac_embedding(0, len(example), np.array(STR_EMBEDDINGS[cast(str, example)]))] @pytest.fixture(scope='module', autouse=True) def setup_teardown() -> Iterable[None]: # Setup. register_signal(TestEmbedding) # Unit test runs. yield # Teardown. clear_signal_registry() @pytest.fixture(scope='function', autouse=True) def setup_data_dir(tmp_path: Path, mocker: MockerFixture) -> None: mocker.patch.dict(CONFIG, {'LILAC_DATA_PATH': str(tmp_path)}) def test_concept_create() -> None: url = '/api/v1/concepts/' response = client.get(url) assert response.status_code == 200 assert response.json() == [] # Create a concept. url = '/api/v1/concepts/create' create_concept = CreateConceptOptions( namespace='concept_namespace', name='concept', type=SignalInputType.TEXT) response = client.post(url, json=create_concept.dict()) assert response.status_code == 200 assert Concept.parse_obj(response.json()) == Concept( namespace='concept_namespace', concept_name='concept', type=SignalInputType.TEXT, data={}, version=0) # Make sure list shows us the new concept. url = '/api/v1/concepts/' response = client.get(url) assert response.status_code == 200 assert parse_obj_as(list[ConceptInfo], response.json()) == [ ConceptInfo( namespace='concept_namespace', name='concept', type=SignalInputType.TEXT, drafts=[DRAFT_MAIN]) ] def test_concept_edits(mocker: MockerFixture) -> None: mock_uuid = mocker.patch.object(uuid, 'uuid4', autospec=True) # Create the concept. response = client.post( '/api/v1/concepts/create', json=CreateConceptOptions( namespace='concept_namespace', name='concept', type=SignalInputType.TEXT).dict()) # Make sure we can add an example. mock_uuid.return_value = fake_uuid(b'1') url = '/api/v1/concepts/concept_namespace/concept' concept_update = ConceptUpdate(insert=[ ExampleIn( label=True, text='hello', origin=ExampleOrigin( dataset_namespace='dataset_namespace', dataset_name='dataset', dataset_row_id='d1')) ]) response = client.post(url, json=concept_update.dict()) assert response.status_code == 200 assert Concept.parse_obj(response.json()) == Concept( namespace='concept_namespace', concept_name='concept', type='text', data={ fake_uuid(b'1').hex: Example( id=fake_uuid(b'1').hex, label=True, text='hello', origin=ExampleOrigin( dataset_namespace='dataset_namespace', dataset_name='dataset', dataset_row_id='d1')) }, version=1) url = '/api/v1/concepts/' response = client.get(url) assert response.status_code == 200 assert parse_obj_as(list[ConceptInfo], response.json()) == [ ConceptInfo( namespace='concept_namespace', name='concept', type=SignalInputType.TEXT, drafts=[DRAFT_MAIN]) ] # Add another example. mock_uuid.return_value = fake_uuid(b'2') url = '/api/v1/concepts/concept_namespace/concept' concept_update = ConceptUpdate(insert=[ ExampleIn( label=True, text='hello2', origin=ExampleOrigin( dataset_namespace='dataset_namespace', dataset_name='dataset', dataset_row_id='d2')) ]) response = client.post(url, json=concept_update.dict()) assert response.status_code == 200 assert Concept.parse_obj(response.json()) == Concept( namespace='concept_namespace', concept_name='concept', type='text', data={ fake_uuid(b'1').hex: Example( id=fake_uuid(b'1').hex, label=True, text='hello', origin=ExampleOrigin( dataset_namespace='dataset_namespace', dataset_name='dataset', dataset_row_id='d1')), fake_uuid(b'2').hex: Example( id=fake_uuid(b'2').hex, label=True, text='hello2', origin=ExampleOrigin( dataset_namespace='dataset_namespace', dataset_name='dataset', dataset_row_id='d2')) }, version=2) # Edit both examples. url = '/api/v1/concepts/concept_namespace/concept' concept_update = ConceptUpdate(update=[ # Switch the label. Example(id=fake_uuid(b'1').hex, label=False, text='hello'), # Switch the text. Example(id=fake_uuid(b'2').hex, label=True, text='hello world'), ]) response = client.post(url, json=concept_update.dict()) assert response.status_code == 200 assert Concept.parse_obj(response.json()) == Concept( namespace='concept_namespace', concept_name='concept', type='text', data={ fake_uuid(b'1').hex: Example(id=fake_uuid(b'1').hex, label=False, text='hello'), fake_uuid(b'2').hex: Example(id=fake_uuid(b'2').hex, label=True, text='hello world') }, version=3) # Delete the first example. url = '/api/v1/concepts/concept_namespace/concept' concept_update = ConceptUpdate(remove=[fake_uuid(b'1').hex]) response = client.post(url, json=concept_update.dict()) assert response.status_code == 200 assert Concept.parse_obj(response.json()) == Concept( namespace='concept_namespace', concept_name='concept', type='text', data={fake_uuid(b'2').hex: Example(id=fake_uuid(b'2').hex, label=True, text='hello world')}, version=4) # The concept still exists. url = '/api/v1/concepts/' response = client.get(url) assert response.status_code == 200 assert parse_obj_as(list[ConceptInfo], response.json()) == [ ConceptInfo( namespace='concept_namespace', name='concept', type=SignalInputType.TEXT, drafts=[DRAFT_MAIN]) ] def test_concept_drafts(mocker: MockerFixture) -> None: mock_uuid = mocker.patch.object(uuid, 'uuid4', autospec=True) # Create the concept. response = client.post( '/api/v1/concepts/create', json=CreateConceptOptions( namespace='concept_namespace', name='concept', type=SignalInputType.TEXT).dict()) # Add examples, some drafts. mock_uuid.side_effect = [fake_uuid(b'1'), fake_uuid(b'2'), fake_uuid(b'3'), fake_uuid(b'4')] url = '/api/v1/concepts/concept_namespace/concept' concept_update = ConceptUpdate(insert=[ ExampleIn(label=True, text='in concept'), ExampleIn(label=False, text='out of concept'), ExampleIn(label=False, text='in concept', draft='test_draft'), ExampleIn(label=False, text='out of concept draft', draft='test_draft') ]) response = client.post(url, json=concept_update.dict()) assert response.status_code == 200 # Make sure list shows us the drafts url = '/api/v1/concepts/' response = client.get(url) assert response.status_code == 200 assert parse_obj_as(list[ConceptInfo], response.json()) == [ ConceptInfo( namespace='concept_namespace', name='concept', type=SignalInputType.TEXT, drafts=[DRAFT_MAIN, 'test_draft']) ] # Make sure when we request main, we only get data in main. url = '/api/v1/concepts/concept_namespace/concept' response = client.get(url) assert response.status_code == 200 assert Concept.parse_obj(response.json()) == Concept( namespace='concept_namespace', concept_name='concept', type='text', data={ # Only main are returned. fake_uuid(b'1').hex: Example(id=fake_uuid(b'1').hex, label=True, text='in concept'), fake_uuid(b'2').hex: Example(id=fake_uuid(b'2').hex, label=False, text='out of concept') }, version=1) # Make sure when we request the draft, we get the draft data deduped with main. url = '/api/v1/concepts/concept_namespace/concept?draft=test_draft' response = client.get(url) assert response.status_code == 200 assert Concept.parse_obj(response.json()) == Concept( namespace='concept_namespace', concept_name='concept', type='text', data={ # b'1' is deduped with b'3'. fake_uuid(b'2').hex: Example(id=fake_uuid(b'2').hex, label=False, text='out of concept'), # ID 3 is a duplicate of main's 1. fake_uuid(b'3').hex: Example( id=fake_uuid(b'3').hex, label=False, text='in concept', draft='test_draft'), fake_uuid(b'4').hex: Example( id=fake_uuid(b'4').hex, label=False, text='out of concept draft', draft='test_draft') }, version=1) # Merge the draft. response = client.post( '/api/v1/concepts/concept_namespace/concept/merge_draft', json=MergeConceptDraftOptions(draft='test_draft').dict()) assert response.status_code == 200 # Make sure we get the merged drafts. url = '/api/v1/concepts/concept_namespace/concept' response = client.get(url) assert response.status_code == 200 assert Concept.parse_obj(response.json()).dict() == Concept( namespace='concept_namespace', concept_name='concept', type='text', data={ # b'1' is deduped with b'3'. fake_uuid(b'2').hex: Example(id=fake_uuid(b'2').hex, label=False, text='out of concept'), # ID 3 is a duplicate of main's 1. fake_uuid(b'3').hex: Example(id=fake_uuid(b'3').hex, label=False, text='in concept'), fake_uuid(b'4').hex: Example( id=fake_uuid(b'4').hex, label=False, text='out of concept draft') }, version=2).dict() def test_concept_model_sync(mocker: MockerFixture) -> None: mock_uuid = mocker.patch.object(uuid, 'uuid4', autospec=True) # Create the concept. response = client.post( '/api/v1/concepts/create', json=CreateConceptOptions( namespace='concept_namespace', name='concept', type=SignalInputType.TEXT).dict()) # Add two examples. mock_uuid.side_effect = [fake_uuid(b'1'), fake_uuid(b'2')] url = '/api/v1/concepts/concept_namespace/concept' concept_update = ConceptUpdate(insert=[ ExampleIn( label=True, text='hello', origin=ExampleOrigin( dataset_namespace='dataset_namespace', dataset_name='dataset', dataset_row_id='d1')), ExampleIn( label=False, text='hello world', origin=ExampleOrigin( dataset_namespace='dataset_namespace', dataset_name='dataset', dataset_row_id='d2')) ]) response = client.post(url, json=concept_update.dict()) assert response.status_code == 200 # Get the concept model. url = '/api/v1/concepts/concept_namespace/concept/model/test_embedding' response = client.get(url) assert response.status_code == 200 assert ConceptModelInfo.parse_obj(response.json()) == ConceptModelInfo( namespace='concept_namespace', concept_name='concept', embedding_name='test_embedding', version=1) # Score an example. mock_score_emb = mocker.patch.object(LogisticEmbeddingModel, 'score_embeddings', autospec=True) mock_score_emb.return_value = np.array([0.9, 1.0]) url = '/api/v1/concepts/concept_namespace/concept/model/test_embedding/score' score_body = ScoreBody(examples=[ScoreExample(text='hello world'), ScoreExample(text='hello')]) response = client.post(url, json=score_body.dict()) assert response.status_code == 200 assert ScoreResponse.parse_obj(response.json()) == ScoreResponse( scores=[0.9, 1.0], # The model should already be synced. model_synced=False) def test_concept_edits_error_before_create(mocker: MockerFixture) -> None: url = '/api/v1/concepts/concept_namespace/concept' concept_update = ConceptUpdate(insert=[ ExampleIn( label=True, text='hello', origin=ExampleOrigin( dataset_namespace='dataset_namespace', dataset_name='dataset', dataset_row_id='d1')) ]) response = client.post(url, json=concept_update.dict()) assert response.is_error is True assert response.status_code == 500 def test_concept_edits_wrong_type(mocker: MockerFixture) -> None: # Create the concept. response = client.post( '/api/v1/concepts/create', json=CreateConceptOptions( namespace='concept_namespace', name='concept', type=SignalInputType.IMAGE).dict()) url = '/api/v1/concepts/concept_namespace/concept' concept_update = ConceptUpdate(insert=[ ExampleIn( label=True, text='hello', origin=ExampleOrigin( dataset_namespace='dataset_namespace', dataset_name='dataset', dataset_row_id='d1')) ]) response = client.post(url, json=concept_update.dict()) assert response.is_error is True assert response.status_code == 500