diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000000000000000000000000000000000000..6a4638bd03cac95a22b438eb6e9f7974cf0ae5be --- /dev/null +++ b/.dockerignore @@ -0,0 +1,10 @@ +# Python +**/__pycache__ +**/*.pyc +**/*.pyo +**/*.pyd +# Ignore unit tests. +**/*_test.py + +# Mac OS. +.DS_Store diff --git a/.env b/.env new file mode 100644 index 0000000000000000000000000000000000000000..ac30d08073bacc96ea00a099b432b065a02130c1 --- /dev/null +++ b/.env @@ -0,0 +1,40 @@ +# To overwrite these variables, create a .env.local file + +# The path to the directory where the data will be downloaded on the machine +LILAC_DATA_PATH=./data + +# Set to 1 for duckdb to use views instead of materialized tables (lower memory usage, but slower). +DUCKDB_USE_VIEWS=0 + +# Set to true to enable read-only mode, disabling the ability to add datasets & compute dataset +# signals. +# LILAC_AUTH_ENABLED=true + +# Variables that can be set in .env.local +# +# Get key from https://dashboard.cohere.ai/api-keys +# COHERE_API_KEY= + +# GCS_REGION= +# GCS_ACCESS_KEY= +# GCS_SECRET_KEY= + +# Get key from https://platform.openai.com/account/api-keys +# OPENAI_API_KEY= +# Get key from https://makersuite.google.com/app/apikey +# PALM_API_KEY= + +# HuggingFace demos: machine that uploads to HuggingFace. + +# For authenticating with HuggingFace to deploy to a Space. +# HF_USERNAME= +# The default repo to deploy to for a staging demo. Can be overridden by a command line flag. +# HF_STAGING_DEMO_REPO='HF_ORG/HF_REPO_NAME' + +# For Google login. This is generated from the Google Cloud Console for a web client. +# See: https://developers.google.com/identity/protocols/oauth2 +GOOGLE_CLIENT_ID='279475920249-i8llm8vbos1vj5m1qocir8narb3r0enu.apps.googleusercontent.com' +# The client secret of the above client. +# GOOGLE_CLIENT_SECRET= +# A random string for oauth sessions. +# LILAC_OAUTH_SECRET_KEY= diff --git a/.env.demo b/.env.demo new file mode 100644 index 0000000000000000000000000000000000000000..9ca75bdb691b4876cd6bb9e06353a3439f4ef23e --- /dev/null +++ b/.env.demo @@ -0,0 +1,4 @@ +LILAC_DATA_PATH='/data' +HF_HOME=/data/.huggingface +TRANSFORMERS_CACHE=/data/.cache +XDG_CACHE_HOME=/data/.cache diff --git a/.gitattributes b/.gitattributes deleted file mode 100644 index b1ed9f7644c8ed12e23213c88e0ac7bec1795db0..0000000000000000000000000000000000000000 --- a/.gitattributes +++ /dev/null @@ -1,3 +0,0 @@ -dist/lilac-0.0.13-py3-none-any.whl filter=lfs diff=lfs merge=lfs -text -data/.cache/lilac/concept/lilac/profanity/gte-small.pkl filter=lfs diff=lfs merge=lfs -text -data/.cache/lilac/concept/lilac/toxicity/gte-small.pkl filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..148915ec8621238411071d49d06fb9e6d7efc04a --- /dev/null +++ b/.gitignore @@ -0,0 +1,5 @@ +__pycache__/ +**/*.pyc +**/*.pyo +**/*.pyd +**/*_test.py diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..14188a408fd8edec86a7646bfc7fd26368d80ab2 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,39 @@ +# NOTE: When we upgrade to 3.11 we can use a slimmer docker image which comes with gcc.
+FROM python:3.9-bullseye + +# Allow statements and log messages to immediately appear in the Knative logs +ENV PYTHONUNBUFFERED True + +# See: https://huggingface.co/docs/hub/spaces-sdks-docker#permissions +RUN useradd -m -u 1000 user +USER user +ENV HOME=/home/user \ + PATH=/home/user/.local/bin:$PATH + +# Set the working directory in the container. +WORKDIR $HOME/app + +# Install the dependencies. This requires exporting requirements.txt from poetry first, which +# happens from ./build_docker.sh. +COPY --chown=user /dist ./dist/ +RUN pip install --no-index --find-links=dist lilac + +#COPY --chown=user requirements.txt . +#RUN pip install --no-cache-dir -r requirements.txt + +COPY --chown=user .env . +COPY --chown=user .env.demo . +# Copy the README so we can read the datasets from the HuggingFace config. +COPY --chown=user README.md . +COPY --chown=user LICENSE . + +# Copy python files. +#COPY --chown=user /lilac ./lilac/ + +COPY --chown=user docker_start.sh docker_start.py ./ + +# Make a local data directory for non-persistent storage demos. +RUN mkdir -p ./data +RUN chown -R user ./data + +CMD ["bash", "docker_start.sh"] diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..ee2d284c1bc2a474a74d00cd52ee7492f2a42c57 --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. 
+ + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2023 Lilac AI Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e76a83626e5193b1d5de4ca968d959b89ffef2b5 --- /dev/null +++ b/README.md @@ -0,0 +1,10 @@ +--- +app_port: 5432 +colorFrom: purple +colorTo: purple +datasets: [] +emoji: "\U0001F337" +sdk: docker +title: Lilac + +--- \ No newline at end of file diff --git a/data/.cache/lilac/concept/lilac/legal-termination/gte-small.pkl b/data/.cache/lilac/concept/lilac/legal-termination/gte-small.pkl deleted file mode 100644 index d66d70df3d047ca6a3a0fa8d45a9553a0d0f759b..0000000000000000000000000000000000000000 Binary files a/data/.cache/lilac/concept/lilac/legal-termination/gte-small.pkl and /dev/null differ diff --git a/data/.cache/lilac/concept/lilac/negative-sentiment/gte-small.pkl b/data/.cache/lilac/concept/lilac/negative-sentiment/gte-small.pkl deleted file mode 100644 index dfb2d44858b464765a7362c7c36eda7130c4b8e5..0000000000000000000000000000000000000000 Binary files a/data/.cache/lilac/concept/lilac/negative-sentiment/gte-small.pkl and /dev/null differ diff --git a/data/.cache/lilac/concept/lilac/non-english/gte-small.pkl b/data/.cache/lilac/concept/lilac/non-english/gte-small.pkl deleted file mode 100644 index 797f43f8bd5aa438c63d7da98f9cc298a47a81ab..0000000000000000000000000000000000000000 Binary files a/data/.cache/lilac/concept/lilac/non-english/gte-small.pkl and /dev/null differ diff --git a/data/.cache/lilac/concept/lilac/positive-sentiment/gte-small.pkl b/data/.cache/lilac/concept/lilac/positive-sentiment/gte-small.pkl deleted file mode 100644 index ba98d34570c8c63cc1275c3949759c05959638cf..0000000000000000000000000000000000000000 Binary files a/data/.cache/lilac/concept/lilac/positive-sentiment/gte-small.pkl and /dev/null differ diff --git a/data/.cache/lilac/concept/lilac/profanity/gte-small.pkl b/data/.cache/lilac/concept/lilac/profanity/gte-small.pkl deleted file mode 100644 index ca6a17557893c30ec2447b479a61b5d3f9316d0f..0000000000000000000000000000000000000000 --- a/data/.cache/lilac/concept/lilac/profanity/gte-small.pkl +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:ed7340614b1dea910ddeb26bbda0167b1f4fe2479071a62a70b63c18bc6232d0 -size 1672960 diff --git a/data/.cache/lilac/concept/lilac/question/gte-small.pkl b/data/.cache/lilac/concept/lilac/question/gte-small.pkl deleted file mode 100644 index 9f6cb33d71165049846bbb6bbbb402acdfb6a177..0000000000000000000000000000000000000000 Binary files a/data/.cache/lilac/concept/lilac/question/gte-small.pkl and /dev/null differ diff --git a/data/.cache/lilac/concept/lilac/source-code/gte-small.pkl b/data/.cache/lilac/concept/lilac/source-code/gte-small.pkl deleted file mode 100644 index 855ed9a84d4eda5d1401bdbb632d48244d6b3b77..0000000000000000000000000000000000000000 Binary files a/data/.cache/lilac/concept/lilac/source-code/gte-small.pkl and /dev/null differ diff --git a/data/.cache/lilac/concept/lilac/toxicity/gte-small.pkl b/data/.cache/lilac/concept/lilac/toxicity/gte-small.pkl deleted file mode 100644 index 67108f91af650be214eeec055a7e96e6df9d0d40..0000000000000000000000000000000000000000 --- a/data/.cache/lilac/concept/lilac/toxicity/gte-small.pkl +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:f2af2736f3d749391a431f9c24d3fc78cf8e58457cc4f0d1ce770185b92d879c -size 1886446 diff --git a/dist/lilac-0.0.13-py3-none-any.whl b/dist/lilac-0.0.13-py3-none-any.whl deleted file mode 100644 index 
1bff43691cfc621c08d99710805573c65b7427fc..0000000000000000000000000000000000000000 --- a/dist/lilac-0.0.13-py3-none-any.whl +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:54b11ac42070eb7829d429b070238a2f1f3a4f3adedbfe40d9f9119b570e8311 -size 1119815 diff --git a/docker_start.py b/docker_start.py new file mode 100644 index 0000000000000000000000000000000000000000..f5b508f5db62d1ce857364f9d95af0e258d4e109 --- /dev/null +++ b/docker_start.py @@ -0,0 +1,107 @@ +"""Startup work before running the web server.""" + +import os +import shutil +from typing import TypedDict + +import yaml +from huggingface_hub import scan_cache_dir, snapshot_download + +from lilac.concepts.db_concept import CONCEPTS_DIR, DiskConceptDB, get_concept_output_dir +from lilac.env import data_path, env +from lilac.utils import get_datasets_dir, get_lilac_cache_dir, log + + +def delete_old_files() -> None: + """Delete old files from the cache.""" + # Scan cache + try: + scan = scan_cache_dir() + except BaseException: + # Cache was not found. + return + + # Select revisions to delete + to_delete = [] + for repo in scan.repos: + latest_revision = max(repo.revisions, key=lambda x: x.last_modified) + to_delete.extend( + [revision.commit_hash for revision in repo.revisions if revision != latest_revision]) + strategy = scan.delete_revisions(*to_delete) + + # Delete them + log(f'Will delete {len(to_delete)} old revisions and save {strategy.expected_freed_size_str}') + strategy.execute() + + +class HfSpaceConfig(TypedDict): + """The huggingface space config, defined in README.md. + + See: + https://huggingface.co/docs/hub/spaces-config-reference + """ + title: str + datasets: list[str] + + +def main() -> None: + """Download dataset files from the HF space that was uploaded before building the image.""" + # SPACE_ID is the HuggingFace Space ID environment variable that is automatically set by HF. + repo_id = env('SPACE_ID', None) + if not repo_id: + return + + delete_old_files() + + with open(os.path.abspath('README.md')) as f: + # Strip the '---' for the huggingface readme config. + readme = f.read().strip().strip('---') + hf_config: HfSpaceConfig = yaml.safe_load(readme) + + # Download the huggingface space data. This includes code and datasets, so we move the datasets + # alone to the data directory. + for lilac_hf_dataset in hf_config['datasets']: + print('Downloading dataset from HuggingFace: ', lilac_hf_dataset) + snapshot_download( + repo_id=lilac_hf_dataset, + repo_type='dataset', + token=env('HF_ACCESS_TOKEN'), + local_dir=get_datasets_dir(data_path()), + ignore_patterns=['.gitattributes', 'README.md']) + + snapshot_dir = snapshot_download(repo_id=repo_id, repo_type='space', token=env('HF_ACCESS_TOKEN')) + # Copy datasets. + spaces_data_dir = os.path.join(snapshot_dir, 'data') + + # Delete cache files from persistent storage. + cache_dir = get_lilac_cache_dir(data_path()) + if os.path.exists(cache_dir): + shutil.rmtree(cache_dir) + + # NOTE: This is temporary during the move of concepts into the pip package. Once all the demos + # have been updated, this block can be deleted. + old_lilac_concepts_data_dir = os.path.join(data_path(), CONCEPTS_DIR, 'lilac') + if os.path.exists(old_lilac_concepts_data_dir): + shutil.rmtree(old_lilac_concepts_data_dir) + + # Copy cache files from the space if they exist. + spaces_cache_dir = get_lilac_cache_dir(spaces_data_dir) + if os.path.exists(spaces_cache_dir): + shutil.copytree(spaces_cache_dir, cache_dir) + + # Copy concepts. 
+ concepts = DiskConceptDB(spaces_data_dir).list() + for concept in concepts: + # Ignore lilac concepts, they're already part of the source code. + if concept.namespace == 'lilac': + continue + spaces_concept_output_dir = get_concept_output_dir(spaces_data_dir, concept.namespace, + concept.name) + persistent_output_dir = get_concept_output_dir(data_path(), concept.namespace, concept.name) + shutil.rmtree(persistent_output_dir, ignore_errors=True) + shutil.copytree(spaces_concept_output_dir, persistent_output_dir, dirs_exist_ok=True) + shutil.rmtree(spaces_concept_output_dir, ignore_errors=True) + + +if __name__ == '__main__': + main() diff --git a/docker_start.sh b/docker_start.sh new file mode 100644 index 0000000000000000000000000000000000000000..c9f1f4007fb0989fc4c3cb179573dea3f5be49a5 --- /dev/null +++ b/docker_start.sh @@ -0,0 +1,10 @@ +#!/bin/bash + +# Fail if any of the commands below fail. +set -e + +python docker_start.py +gunicorn lilac.server:app \ + --bind 0.0.0.0:5432 \ + --preload -k uvicorn.workers.UvicornWorker \ + --timeout 120 diff --git a/lilac/__pycache__/__init__.cpython-39.pyc b/lilac/__pycache__/__init__.cpython-39.pyc deleted file mode 100644 index 42deceaff8a3dda3e1d343b9b1cff3ccdb46bd42..0000000000000000000000000000000000000000 Binary files a/lilac/__pycache__/__init__.cpython-39.pyc and /dev/null differ diff --git a/lilac/__pycache__/auth.cpython-39.pyc b/lilac/__pycache__/auth.cpython-39.pyc deleted file mode 100644 index 82fbe5f16b3007637700a97330e686c9ed92ae7d..0000000000000000000000000000000000000000 Binary files a/lilac/__pycache__/auth.cpython-39.pyc and /dev/null differ diff --git a/lilac/__pycache__/batch_utils.cpython-39.pyc b/lilac/__pycache__/batch_utils.cpython-39.pyc deleted file mode 100644 index 512aa0a4cfa5ae504c485c3a7c88082243555f4f..0000000000000000000000000000000000000000 Binary files a/lilac/__pycache__/batch_utils.cpython-39.pyc and /dev/null differ diff --git a/lilac/__pycache__/batch_utils_test.cpython-39-pytest-7.4.0.pyc b/lilac/__pycache__/batch_utils_test.cpython-39-pytest-7.4.0.pyc deleted file mode 100644 index 0eccd0dd470f530506a0cc1f5154120b655a02bd..0000000000000000000000000000000000000000 Binary files a/lilac/__pycache__/batch_utils_test.cpython-39-pytest-7.4.0.pyc and /dev/null differ diff --git a/lilac/__pycache__/cli.cpython-39.pyc b/lilac/__pycache__/cli.cpython-39.pyc deleted file mode 100644 index b2ec040f280e65b62b3aa219507a0fbd982a6a60..0000000000000000000000000000000000000000 Binary files a/lilac/__pycache__/cli.cpython-39.pyc and /dev/null differ diff --git a/lilac/__pycache__/config.cpython-39.pyc b/lilac/__pycache__/config.cpython-39.pyc deleted file mode 100644 index 021f0db9fdfd274717ea9496ed196b50be58ad49..0000000000000000000000000000000000000000 Binary files a/lilac/__pycache__/config.cpython-39.pyc and /dev/null differ diff --git a/lilac/__pycache__/conftest.cpython-39-pytest-7.4.0.pyc b/lilac/__pycache__/conftest.cpython-39-pytest-7.4.0.pyc deleted file mode 100644 index 72f3ab617a0ed9285b5972c6824ca53ca9602fb3..0000000000000000000000000000000000000000 Binary files a/lilac/__pycache__/conftest.cpython-39-pytest-7.4.0.pyc and /dev/null differ diff --git a/lilac/__pycache__/data_loader.cpython-39.pyc b/lilac/__pycache__/data_loader.cpython-39.pyc deleted file mode 100644 index 8c58b90181f3c1f0c08d365182ef2bf48659e4d0..0000000000000000000000000000000000000000 Binary files a/lilac/__pycache__/data_loader.cpython-39.pyc and /dev/null differ diff --git a/lilac/__pycache__/data_loader_test.cpython-39-pytest-7.4.0.pyc 
b/lilac/__pycache__/data_loader_test.cpython-39-pytest-7.4.0.pyc deleted file mode 100644 index db2e4ae51c9a51773fc7ce605faadcbaaf8a5f5f..0000000000000000000000000000000000000000 Binary files a/lilac/__pycache__/data_loader_test.cpython-39-pytest-7.4.0.pyc and /dev/null differ diff --git a/lilac/__pycache__/db_manager.cpython-39.pyc b/lilac/__pycache__/db_manager.cpython-39.pyc deleted file mode 100644 index 4a363b60441c6b8432aeaedabe7916ae2b98dbed..0000000000000000000000000000000000000000 Binary files a/lilac/__pycache__/db_manager.cpython-39.pyc and /dev/null differ diff --git a/lilac/__pycache__/env.cpython-39.pyc b/lilac/__pycache__/env.cpython-39.pyc deleted file mode 100644 index 736f2c712f3f1dc459ed6d679145ff50b780eec5..0000000000000000000000000000000000000000 Binary files a/lilac/__pycache__/env.cpython-39.pyc and /dev/null differ diff --git a/lilac/__pycache__/load.cpython-39.pyc b/lilac/__pycache__/load.cpython-39.pyc deleted file mode 100644 index 920948797e02662a33809552b957a7d622a249c6..0000000000000000000000000000000000000000 Binary files a/lilac/__pycache__/load.cpython-39.pyc and /dev/null differ diff --git a/lilac/__pycache__/make_openapi.cpython-39.pyc b/lilac/__pycache__/make_openapi.cpython-39.pyc deleted file mode 100644 index 758ac12f8c7ce1384a5ff5d28bcb15a21a2a8c5d..0000000000000000000000000000000000000000 Binary files a/lilac/__pycache__/make_openapi.cpython-39.pyc and /dev/null differ diff --git a/lilac/__pycache__/parquet_writer.cpython-39.pyc b/lilac/__pycache__/parquet_writer.cpython-39.pyc deleted file mode 100644 index fb19bc8a469d5020f323b70d40329f1528010413..0000000000000000000000000000000000000000 Binary files a/lilac/__pycache__/parquet_writer.cpython-39.pyc and /dev/null differ diff --git a/lilac/__pycache__/router_concept.cpython-39.pyc b/lilac/__pycache__/router_concept.cpython-39.pyc deleted file mode 100644 index 98561ae19f23e8de3eb888d481573564ca733009..0000000000000000000000000000000000000000 Binary files a/lilac/__pycache__/router_concept.cpython-39.pyc and /dev/null differ diff --git a/lilac/__pycache__/router_data_loader.cpython-39.pyc b/lilac/__pycache__/router_data_loader.cpython-39.pyc deleted file mode 100644 index 46b81407e8efad6da2cbdc0e02e2b8b009ff0897..0000000000000000000000000000000000000000 Binary files a/lilac/__pycache__/router_data_loader.cpython-39.pyc and /dev/null differ diff --git a/lilac/__pycache__/router_dataset.cpython-39.pyc b/lilac/__pycache__/router_dataset.cpython-39.pyc deleted file mode 100644 index 90020b4de5ffbc2e03ccb8803c6f6caf1db2a736..0000000000000000000000000000000000000000 Binary files a/lilac/__pycache__/router_dataset.cpython-39.pyc and /dev/null differ diff --git a/lilac/__pycache__/router_google_login.cpython-39.pyc b/lilac/__pycache__/router_google_login.cpython-39.pyc deleted file mode 100644 index f44c08e9f38c80dcc625d653fd32c654144ddf7b..0000000000000000000000000000000000000000 Binary files a/lilac/__pycache__/router_google_login.cpython-39.pyc and /dev/null differ diff --git a/lilac/__pycache__/router_signal.cpython-39.pyc b/lilac/__pycache__/router_signal.cpython-39.pyc deleted file mode 100644 index d859bdaafd10772fe7067a58fe9a572e0d516b5b..0000000000000000000000000000000000000000 Binary files a/lilac/__pycache__/router_signal.cpython-39.pyc and /dev/null differ diff --git a/lilac/__pycache__/router_tasks.cpython-39.pyc b/lilac/__pycache__/router_tasks.cpython-39.pyc deleted file mode 100644 index 50e912360e76dccedd587decd4eba5c0de079d36..0000000000000000000000000000000000000000 Binary files 
a/lilac/__pycache__/router_tasks.cpython-39.pyc and /dev/null differ diff --git a/lilac/__pycache__/router_utils.cpython-39.pyc b/lilac/__pycache__/router_utils.cpython-39.pyc deleted file mode 100644 index cbad6fe16f6a44b9af9b396ee16e7f28c490c381..0000000000000000000000000000000000000000 Binary files a/lilac/__pycache__/router_utils.cpython-39.pyc and /dev/null differ diff --git a/lilac/__pycache__/schema.cpython-39.pyc b/lilac/__pycache__/schema.cpython-39.pyc deleted file mode 100644 index 74e91e673b161f42c75f4d23b8bd9d958228be5b..0000000000000000000000000000000000000000 Binary files a/lilac/__pycache__/schema.cpython-39.pyc and /dev/null differ diff --git a/lilac/__pycache__/schema_test.cpython-39-pytest-7.4.0.pyc b/lilac/__pycache__/schema_test.cpython-39-pytest-7.4.0.pyc deleted file mode 100644 index 3127d9fb954e62c0179230ad5ce26217eda82878..0000000000000000000000000000000000000000 Binary files a/lilac/__pycache__/schema_test.cpython-39-pytest-7.4.0.pyc and /dev/null differ diff --git a/lilac/__pycache__/server.cpython-39.pyc b/lilac/__pycache__/server.cpython-39.pyc deleted file mode 100644 index 9ebccc689d0cd9bb8885e2e755bd127b0a4f7d02..0000000000000000000000000000000000000000 Binary files a/lilac/__pycache__/server.cpython-39.pyc and /dev/null differ diff --git a/lilac/__pycache__/server_concept_test.cpython-39-pytest-7.4.0.pyc b/lilac/__pycache__/server_concept_test.cpython-39-pytest-7.4.0.pyc deleted file mode 100644 index 279db96d4019c1f53ad383c8a39189bad69855ec..0000000000000000000000000000000000000000 Binary files a/lilac/__pycache__/server_concept_test.cpython-39-pytest-7.4.0.pyc and /dev/null differ diff --git a/lilac/__pycache__/server_dataset_test.cpython-39-pytest-7.4.0.pyc b/lilac/__pycache__/server_dataset_test.cpython-39-pytest-7.4.0.pyc deleted file mode 100644 index 33ec99104b6899dc8adac0a6a7cc48f48f11a538..0000000000000000000000000000000000000000 Binary files a/lilac/__pycache__/server_dataset_test.cpython-39-pytest-7.4.0.pyc and /dev/null differ diff --git a/lilac/__pycache__/server_signal_test.cpython-39-pytest-7.4.0.pyc b/lilac/__pycache__/server_signal_test.cpython-39-pytest-7.4.0.pyc deleted file mode 100644 index 304018cdb3f9faf85a4e08e56ca3f650008b52c5..0000000000000000000000000000000000000000 Binary files a/lilac/__pycache__/server_signal_test.cpython-39-pytest-7.4.0.pyc and /dev/null differ diff --git a/lilac/__pycache__/signal.cpython-39.pyc b/lilac/__pycache__/signal.cpython-39.pyc deleted file mode 100644 index b89f83b0f4c10d0a0de36641bd373b9ee25dd4ea..0000000000000000000000000000000000000000 Binary files a/lilac/__pycache__/signal.cpython-39.pyc and /dev/null differ diff --git a/lilac/__pycache__/signal_test.cpython-39-pytest-7.4.0.pyc b/lilac/__pycache__/signal_test.cpython-39-pytest-7.4.0.pyc deleted file mode 100644 index 7835457677457f0e217df063e7319c4bb4130a81..0000000000000000000000000000000000000000 Binary files a/lilac/__pycache__/signal_test.cpython-39-pytest-7.4.0.pyc and /dev/null differ diff --git a/lilac/__pycache__/tasks.cpython-39.pyc b/lilac/__pycache__/tasks.cpython-39.pyc deleted file mode 100644 index d9976b17eb92cd805e993a2628b685d13afd7cb9..0000000000000000000000000000000000000000 Binary files a/lilac/__pycache__/tasks.cpython-39.pyc and /dev/null differ diff --git a/lilac/__pycache__/test_utils.cpython-39-pytest-7.4.0.pyc b/lilac/__pycache__/test_utils.cpython-39-pytest-7.4.0.pyc deleted file mode 100644 index 2201d1d99817d64b3867fdcb424a3d394b24e305..0000000000000000000000000000000000000000 Binary files 
a/lilac/__pycache__/test_utils.cpython-39-pytest-7.4.0.pyc and /dev/null differ diff --git a/lilac/__pycache__/utils.cpython-39.pyc b/lilac/__pycache__/utils.cpython-39.pyc deleted file mode 100644 index e286f3f8ab75ef764f2a3d91636906613e57a78c..0000000000000000000000000000000000000000 Binary files a/lilac/__pycache__/utils.cpython-39.pyc and /dev/null differ diff --git a/lilac/batch_utils_test.py b/lilac/batch_utils_test.py deleted file mode 100644 index 81ff1d68f883320c14dc7e26673c20aef28e1a24..0000000000000000000000000000000000000000 --- a/lilac/batch_utils_test.py +++ /dev/null @@ -1,95 +0,0 @@ -"""Test batch_utils.py.""" -from typing import Iterable - -import numpy as np - -from .batch_utils import deep_flatten, deep_unflatten, flat_batched_compute, flatten, unflatten - - -def test_batched_compute() -> None: - input = [[1], [2, 3], [4, 5]] - batch_size = 2 # Does not evenly split any input - - def f(inputs: Iterable[int]) -> list[int]: - return [x * x for x in inputs] - - assert list(flat_batched_compute(input, f, batch_size)) == [[1], [4, 9], [16, 25]] - - -def test_batched_compute_np() -> None: - input = [[np.array([1, 1])], [np.array([2, 2]), np.array([3, 3])], - [np.array([4, 4]), np.array([5, 5])]] - batch_size = 2 # Does not evenly split any input - - def f(inputs: Iterable[np.ndarray]) -> Iterable[float]: - return [x[0] * x[0] for x in inputs] - - assert list(flat_batched_compute(input, f, batch_size)) == [[1], [4, 9], [16, 25]] - - -def test_flatten() -> None: - a = [[1, 2], [3], [4, 5, 5]] - result = list(flatten(a)) - assert result == [1, 2, 3, 4, 5, 5] - - -def test_unflatten() -> None: - a = [[1, 2], [3], [4, 5, 5]] - flat_a = list(flatten(a)) - result = list(unflatten(flat_a, a)) - assert result == [[1, 2], [3], [4, 5, 5]] - - -def test_deep_flatten() -> None: - a = [[1, 2], [[3]], [4, 5, [5]]] - result = list(deep_flatten(a)) - assert result == [1, 2, 3, 4, 5, 5] - - -def test_deep_flatten_primitive() -> None: - result = list(deep_flatten('hello')) - assert result == ['hello'] - - -def test_deep_flatten_np() -> None: - input = [ - [np.array([1, 1])], - [np.array([2, 2]), np.array([3, 3])], - ] - result = list(deep_flatten(input)) - - assert len(result) == 3 - np.testing.assert_array_equal(result[0], np.array([1, 1])) - np.testing.assert_array_equal(result[1], np.array([2, 2])) - np.testing.assert_array_equal(result[2], np.array([3, 3])) - - -def test_deep_unflatten() -> None: - a = [[1, 2], [[3]], [4, 5, 5]] - flat_a = list(deep_flatten(a)) - result = deep_unflatten(flat_a, a) - assert result == [[1, 2], [[3]], [4, 5, 5]] - - -def test_deep_unflatten_primitive() -> None: - original = 'hello' - result = deep_unflatten(['hello'], original) - assert result == 'hello' - - -def test_deep_unflatten_primitive_list() -> None: - original = ['hello', 'world'] - result = deep_unflatten(['hello', 'world'], original) - assert result == ['hello', 'world'] - - -def test_deep_unflatten_np() -> None: - input = [ - [np.array([1, 1])], - [np.array([2, 2]), np.array([3, 3])], - ] - result = list(deep_unflatten(deep_flatten(input), input)) - - assert len(result) == 2 - np.testing.assert_array_equal(result[0], [np.array([1, 1])]) - np.testing.assert_array_equal(result[1], [np.array([2, 2]), np.array([3, 3])]) diff --git a/lilac/concepts/__pycache__/__init__.cpython-39.pyc b/lilac/concepts/__pycache__/__init__.cpython-39.pyc deleted file mode 100644 index 70a7cf59bc3cb3073107e4a7f2b730e2d7952b56..0000000000000000000000000000000000000000 Binary files 
a/lilac/concepts/__pycache__/__init__.cpython-39.pyc and /dev/null differ diff --git a/lilac/concepts/__pycache__/concept.cpython-39.pyc b/lilac/concepts/__pycache__/concept.cpython-39.pyc deleted file mode 100644 index 430383650afea3915f26985546f9146f29deb910..0000000000000000000000000000000000000000 Binary files a/lilac/concepts/__pycache__/concept.cpython-39.pyc and /dev/null differ diff --git a/lilac/concepts/__pycache__/concept_test.cpython-39-pytest-7.4.0.pyc b/lilac/concepts/__pycache__/concept_test.cpython-39-pytest-7.4.0.pyc deleted file mode 100644 index 89f040e209b271db1ec8d6b0f53f4aab25ea59e4..0000000000000000000000000000000000000000 Binary files a/lilac/concepts/__pycache__/concept_test.cpython-39-pytest-7.4.0.pyc and /dev/null differ diff --git a/lilac/concepts/__pycache__/db_concept.cpython-39.pyc b/lilac/concepts/__pycache__/db_concept.cpython-39.pyc deleted file mode 100644 index 270737810baf3928230bce7405aef040fdff2512..0000000000000000000000000000000000000000 Binary files a/lilac/concepts/__pycache__/db_concept.cpython-39.pyc and /dev/null differ diff --git a/lilac/concepts/__pycache__/db_concept_test.cpython-39-pytest-7.4.0.pyc b/lilac/concepts/__pycache__/db_concept_test.cpython-39-pytest-7.4.0.pyc deleted file mode 100644 index c882a5b39f17bcde3115597ebcbb02b94421dcc1..0000000000000000000000000000000000000000 Binary files a/lilac/concepts/__pycache__/db_concept_test.cpython-39-pytest-7.4.0.pyc and /dev/null differ diff --git a/lilac/concepts/concept_test.py b/lilac/concepts/concept_test.py deleted file mode 100644 index 57b21944ffe7e28eb15f9bced2d3b6b17254bdc6..0000000000000000000000000000000000000000 --- a/lilac/concepts/concept_test.py +++ /dev/null @@ -1,84 +0,0 @@ -"""Tests for concept.""" - -from ..schema import SignalInputType -from .concept import DRAFT_MAIN, Concept, Example, draft_examples - - -def test_draft_examples_main() -> None: - concept = Concept( - namespace='test_namespace', - concept_name='test_name', - type=SignalInputType.TEXT, - data={ - '0': Example(id='0', label=True, text='hello'), - '1': Example(id='1', label=False, text='world'), - }, - version=0) - - assert draft_examples(concept, DRAFT_MAIN) == { - '0': Example(id='0', label=True, text='hello'), - '1': Example(id='1', label=False, text='world'), - } - - -def test_draft_examples_simple_draft() -> None: - concept = Concept( - namespace='test_namespace', - concept_name='test_name', - type=SignalInputType.TEXT, - data={ - '0': Example(id='0', label=True, text='hello'), - '1': Example(id='1', label=False, text='world'), - '2': Example(id='2', label=True, text='hello draft 1', draft='draft1'), - '3': Example(id='3', label=False, text='world draft 1', draft='draft1'), - '4': Example(id='4', label=True, text='hello draft 2', draft='draft2'), - '5': Example(id='5', label=False, text='world draft 2', draft='draft2'), - }, - version=0) - - assert draft_examples(concept, DRAFT_MAIN) == { - '0': Example(id='0', label=True, text='hello'), - '1': Example(id='1', label=False, text='world'), - } - - assert draft_examples(concept, 'draft1') == { - '0': Example(id='0', label=True, text='hello'), - '1': Example(id='1', label=False, text='world'), - '2': Example(id='2', label=True, text='hello draft 1', draft='draft1'), - '3': Example(id='3', label=False, text='world draft 1', draft='draft1'), - } - - assert draft_examples(concept, 'draft2') == { - '0': Example(id='0', label=True, text='hello'), - '1': Example(id='1', label=False, text='world'), - '4': Example(id='4', label=True, text='hello draft 2', 
draft='draft2'), - '5': Example(id='5', label=False, text='world draft 2', draft='draft2'), - } - - -def test_draft_examples_draft_dedupe() -> None: - concept = Concept( - namespace='test_namespace', - concept_name='test_name', - type=SignalInputType.TEXT, - data={ - '0': Example(id='0', label=True, text='hello'), - '1': Example(id='1', label=False, text='world'), - # Duplicate text. - '2': Example(id='2', label=False, text='hello', draft='draft'), - '3': Example(id='3', label=False, text='world draft', draft='draft'), - }, - version=0) - - assert draft_examples(concept, DRAFT_MAIN) == { - '0': Example(id='0', label=True, text='hello'), - '1': Example(id='1', label=False, text='world'), - } - - assert draft_examples(concept, 'draft') == { - # 0 is deduplicated with 2. - '1': Example(id='1', label=False, text='world'), - # 2 overrides 0's label. - '2': Example(id='2', label=False, text='hello', draft='draft'), - '3': Example(id='3', label=False, text='world draft', draft='draft'), - } diff --git a/lilac/concepts/db_concept_test.py b/lilac/concepts/db_concept_test.py deleted file mode 100644 index b19c73edf6dde380771788fe24a874954b081d93..0000000000000000000000000000000000000000 --- a/lilac/concepts/db_concept_test.py +++ /dev/null @@ -1,621 +0,0 @@ -"""Tests for the the database concept.""" - -import os -from pathlib import Path -from typing import Generator, Iterable, Type, cast - -import numpy as np -import pytest -from pytest_mock import MockerFixture -from sklearn.preprocessing import normalize -from typing_extensions import override - -from ..data.dataset_duckdb import DatasetDuckDB -from ..db_manager import set_default_dataset_cls -from ..schema import Item, RichData, lilac_embedding -from ..signal import TextEmbeddingSignal, clear_signal_registry, register_signal -from .concept import ( - DRAFT_MAIN, - Concept, - ConceptModel, - ConceptType, - DraftId, - Example, - ExampleIn, - LogisticEmbeddingModel, -) -from .db_concept import ( - ConceptACL, - ConceptDB, - ConceptInfo, - ConceptModelDB, - ConceptUpdate, - DiskConceptDB, - DiskConceptModelDB, -) - -ALL_CONCEPT_DBS = [DiskConceptDB] -ALL_CONCEPT_MODEL_DBS = [DiskConceptModelDB] - - -@pytest.fixture(autouse=True) -def set_data_path(tmp_path: Path, mocker: MockerFixture) -> None: - mocker.patch.dict(os.environ, {'LILAC_DATA_PATH': str(tmp_path)}) - - -EMBEDDING_MAP: dict[str, list[float]] = { - 'not in concept': [1.0, 0.0, 0.0], - 'in concept': [0.9, 0.1, 0.0], - 'a new data point': [0.1, 0.2, 0.3], - 'a true draft point': [0.4, 0.5, 0.6], - 'a false draft point': [0.7, 0.8, 0.9], -} - - -class TestEmbedding(TextEmbeddingSignal): - """A test embed function.""" - name = 'test_embedding' - - @override - def compute(self, data: Iterable[RichData]) -> Iterable[Item]: - """Embed the examples, use a hashmap to the vector for simplicity.""" - for example in data: - if example not in EMBEDDING_MAP: - raise ValueError(f'Example "{str(example)}" not in embedding map') - yield [lilac_embedding(0, len(example), np.array(EMBEDDING_MAP[cast(str, example)]))] - - -@pytest.fixture(scope='module', autouse=True) -def setup_teardown() -> Generator: - set_default_dataset_cls(DatasetDuckDB) - register_signal(TestEmbedding) - - # Unit test runs. - yield - - # Teardown. - clear_signal_registry() - - -@pytest.mark.parametrize('db_cls', ALL_CONCEPT_DBS) -class ConceptDBSuite: - - def test_list_lilac_concepts(self, db_cls: Type[ConceptDB]) -> None: - db = db_cls() - # Make sure a lilac concept exists in the default list. 
- assert filter(lambda c: c.name == 'positive-sentiment' and c.namespace == 'lilac', db.list()) - - def test_create_concept(self, db_cls: Type[ConceptDB]) -> None: - db = db_cls() - db.create(namespace='test', name='test_concept', type=ConceptType.TEXT) - - # Remove lilac concepts. - concepts = list(filter(lambda c: c.namespace != 'lilac', db.list())) - - assert concepts == [ - ConceptInfo( - namespace='test', - name='test_concept', - type=ConceptType.TEXT, - drafts=[DRAFT_MAIN], - acls=ConceptACL(read=True, write=True)) - ] - - # Make sure list with drafts relects the drafts. - train_data = [ - ExampleIn(label=False, text='not in concept', draft='test_draft'), - ExampleIn(label=True, text='in concept', draft='test_draft') - ] - db.edit('test', 'test_concept', ConceptUpdate(insert=train_data)) - - # Remove lilac concepts. - concepts = list(filter(lambda c: c.namespace != 'lilac', db.list())) - - assert concepts == [ - ConceptInfo( - namespace='test', - name='test_concept', - type=ConceptType.TEXT, - drafts=[DRAFT_MAIN, 'test_draft'], - acls=ConceptACL(read=True, write=True)) - ] - - def test_add_example(self, db_cls: Type[ConceptDB]) -> None: - db = db_cls() - namespace = 'test' - concept_name = 'test_concept' - train_data = [ - ExampleIn(label=False, text='not in concept'), - ExampleIn(label=True, text='in concept') - ] - db.create(namespace=namespace, name=concept_name, type=ConceptType.TEXT) - db.edit(namespace, concept_name, ConceptUpdate(insert=train_data)) - - concept = db.get(namespace, concept_name) - - assert concept is not None - - keys = list(concept.data.keys()) - assert concept == Concept( - namespace=namespace, - concept_name=concept_name, - type=ConceptType.TEXT, - data={ - keys[0]: Example(id=keys[0], label=False, text='not in concept'), - keys[1]: Example(id=keys[1], label=True, text='in concept') - }, - version=1) - - # Add a draft labels. - db.edit( - namespace, concept_name, - ConceptUpdate(insert=[ - ExampleIn(label=False, text='really not in concept', draft='test_draft'), - ExampleIn(label=True, text='really in concept', draft='test_draft') - ])) - - concept = db.get(namespace, concept_name) - assert concept is not None - - keys = list(concept.data.keys()) - assert concept == Concept( - namespace=namespace, - concept_name=concept_name, - type=ConceptType.TEXT, - data={ - keys[0]: Example(id=keys[0], label=False, text='not in concept'), - keys[1]: Example(id=keys[1], label=True, text='in concept'), - keys[2]: Example(id=keys[2], label=False, text='really not in concept', draft='test_draft'), - keys[3]: Example(id=keys[3], label=True, text='really in concept', draft='test_draft'), - }, - version=2) - - def test_update_concept(self, db_cls: Type[ConceptDB]) -> None: - db = db_cls() - namespace = 'test' - concept_name = 'test_concept' - train_data = [ - ExampleIn(label=False, text='not in concept'), - ExampleIn(label=True, text='in concept'), - ExampleIn(label=False, text='really not in concept', draft='test_draft'), - ExampleIn(label=True, text='really in concept', draft='test_draft') - ] - db.create(namespace=namespace, name=concept_name, type=ConceptType.TEXT) - db.edit(namespace, concept_name, ConceptUpdate(insert=train_data)) - - concept = db.get(namespace, concept_name) - assert concept is not None - - keys = list(concept.data.keys()) - # Edit the first example. 
- db.edit( - namespace, concept_name, - ConceptUpdate(update=[Example(id=keys[0], label=False, text='not in concept, updated')])) - concept = db.get(namespace, concept_name) - - assert concept == Concept( - namespace=namespace, - concept_name=concept_name, - type=ConceptType.TEXT, - data={ - # The first example should be updated alone. - keys[0]: Example(id=keys[0], label=False, text='not in concept, updated'), - keys[1]: Example(id=keys[1], label=True, text='in concept'), - # Drafts are untouched. - keys[2]: Example(id=keys[2], label=False, text='really not in concept', draft='test_draft'), - keys[3]: Example(id=keys[3], label=True, text='really in concept', draft='test_draft'), - }, - version=2) - - # Edit the second example on the draft. - db.edit( - namespace, concept_name, - ConceptUpdate(update=[ - Example(id=keys[3], label=True, text='really in concept, updated', draft='test_draft') - ])) - concept = db.get(namespace, concept_name) - - assert concept == Concept( - namespace=namespace, - concept_name=concept_name, - type=ConceptType.TEXT, - data={ - # Main remains the same. - keys[0]: Example(id=keys[0], label=False, text='not in concept, updated'), - keys[1]: Example(id=keys[1], label=True, text='in concept'), - keys[2]: Example(id=keys[2], label=False, text='really not in concept', draft='test_draft'), - keys[3]: Example( - id=keys[3], label=True, text='really in concept, updated', draft='test_draft'), - }, - version=3) - - def test_remove_concept(self, db_cls: Type[ConceptDB]) -> None: - db = db_cls() - namespace = 'test' - concept_name = 'test_concept' - db.create(namespace=namespace, name=concept_name, type=ConceptType.TEXT) - - train_data = [ - ExampleIn(label=False, text='not in concept'), - ExampleIn(label=True, text='in concept') - ] - db.edit(namespace, concept_name, ConceptUpdate(insert=train_data)) - concept = db.get(namespace, concept_name) - - db.remove(namespace, concept_name) - - concept = db.get(namespace, concept_name) - - assert concept is None - - def test_remove_concept_examples(self, db_cls: Type[ConceptDB]) -> None: - db = db_cls() - namespace = 'test' - concept_name = 'test_concept' - db.create(namespace=namespace, name=concept_name, type=ConceptType.TEXT) - - train_data = [ - ExampleIn(label=False, text='not in concept'), - ExampleIn(label=True, text='in concept') - ] - db.edit(namespace, concept_name, ConceptUpdate(insert=train_data)) - concept = db.get(namespace, concept_name) - assert concept is not None - - keys = list(concept.data.keys()) - - db.edit(namespace, concept_name, ConceptUpdate(remove=[keys[0]])) - concept = db.get(namespace, concept_name) - - assert concept == Concept( - namespace=namespace, - concept_name=concept_name, - type=ConceptType.TEXT, - data={ - # key_0 was removed. 
- keys[1]: Example(id=keys[1], label=True, text='in concept') - }, - version=2) - - def test_remove_concept_examples_draft(self, db_cls: Type[ConceptDB]) -> None: - db = db_cls() - namespace = 'test' - concept_name = 'test_concept' - train_data = [ - ExampleIn(label=False, text='not in concept'), - ExampleIn(label=True, text='in concept'), - ExampleIn(label=False, text='really not in concept', draft='test_draft'), - ExampleIn(label=True, text='really in concept', draft='test_draft') - ] - db.create(namespace=namespace, name=concept_name, type=ConceptType.TEXT) - db.edit(namespace, concept_name, ConceptUpdate(insert=train_data)) - concept = db.get(namespace, concept_name) - assert concept is not None - - keys = list(concept.data.keys()) - - db.edit(namespace, concept_name, ConceptUpdate(remove=[keys[2]])) - concept = db.get(namespace, concept_name) - - assert concept == Concept( - namespace=namespace, - concept_name=concept_name, - type=ConceptType.TEXT, - data={ - keys[0]: Example(id=keys[0], label=False, text='not in concept'), - keys[1]: Example(id=keys[1], label=True, text='in concept'), - # The first draft example is removed. - keys[3]: Example(id=keys[3], label=True, text='really in concept', draft='test_draft'), - }, - version=2) - - def test_remove_invalid_id(self, db_cls: Type[ConceptDB]) -> None: - db = db_cls() - namespace = 'test' - concept_name = 'test_concept' - db.create(namespace=namespace, name=concept_name, type=ConceptType.TEXT) - - train_data = [ - ExampleIn(label=False, text='not in concept'), - ExampleIn(label=True, text='in concept'), - ExampleIn(label=False, text='really not in concept', draft='test_draft'), - ExampleIn(label=True, text='really in concept', draft='test_draft') - ] - db.edit(namespace, concept_name, ConceptUpdate(insert=train_data)) - - with pytest.raises(ValueError, match='Example with id "invalid_id" does not exist'): - db.edit(namespace, concept_name, ConceptUpdate(remove=['invalid_id'])) - - def test_edit_before_creation(self, db_cls: Type[ConceptDB]) -> None: - db = db_cls() - namespace = 'test' - concept_name = 'test_concept' - - with pytest.raises( - ValueError, match='Concept with namespace "test" and name "test_concept" does not exist'): - db.edit(namespace, concept_name, - ConceptUpdate(insert=[ - ExampleIn(label=False, text='not in concept'), - ])) - - def test_edit_invalid_id(self, db_cls: Type[ConceptDB]) -> None: - db = db_cls() - namespace = 'test' - concept_name = 'test_concept' - db.create(namespace=namespace, name=concept_name, type=ConceptType.TEXT) - - train_data = [ - ExampleIn(label=False, text='not in concept'), - ExampleIn(label=True, text='in concept') - ] - db.edit(namespace, concept_name, ConceptUpdate(insert=train_data)) - - with pytest.raises(ValueError, match='Example with id "invalid_id" does not exist'): - db.edit(namespace, concept_name, - ConceptUpdate(update=[Example(id='invalid_id', label=False, text='not in concept')])) - - def test_merge_draft(self, db_cls: Type[ConceptDB]) -> None: - db = db_cls() - namespace = 'test' - concept_name = 'test_concept' - db.create(namespace=namespace, name=concept_name, type=ConceptType.TEXT) - - train_data = [ - ExampleIn(label=True, text='hello'), - ExampleIn(label=False, text='world'), - ExampleIn(label=True, text='hello draft 1', draft='draft1'), - ExampleIn(label=False, text='world draft 1', draft='draft1'), - # Duplicate of main. 
- ExampleIn(label=False, text='hello', draft='draft2'), - ExampleIn(label=True, text='world draft 2', draft='draft2'), - ] - db.edit(namespace, concept_name, ConceptUpdate(insert=train_data)) - - db.merge_draft(namespace, concept_name, 'draft1') - - concept = db.get(namespace, concept_name) - assert concept is not None - keys = list(concept.data.keys()) - - assert concept.dict() == Concept( - namespace='test', - concept_name='test_concept', - type=ConceptType.TEXT, - data={ - keys[0]: Example(id=keys[0], label=True, text='hello'), - keys[1]: Example(id=keys[1], label=False, text='world'), - # Draft examples are merged. - keys[2]: Example(id=keys[2], label=True, text='hello draft 1'), - keys[3]: Example(id=keys[3], label=False, text='world draft 1'), - # Draft 2 is untouched. - keys[4]: Example(id=keys[4], label=False, text='hello', draft='draft2'), - keys[5]: Example(id=keys[5], label=True, text='world draft 2', draft='draft2'), - }, - version=2).dict() - - db.merge_draft(namespace, concept_name, 'draft2') - - concept = db.get(namespace, concept_name) - assert concept is not None - - assert concept == Concept( - namespace='test', - concept_name='test_concept', - type=ConceptType.TEXT, - data={ - # The first example is a duplicate of the label from the draft, so it is removed. - keys[1]: Example(id=keys[1], label=False, text='world'), - # Draft examples are merged. - keys[2]: Example(id=keys[2], label=True, text='hello draft 1'), - keys[3]: Example(id=keys[3], label=False, text='world draft 1'), - # Draft examples are merged. - keys[4]: Example(id=keys[4], label=False, text='hello'), - keys[5]: Example(id=keys[5], label=True, text='world draft 2'), - }, - version=3) - - -def _make_test_concept_model( - concept_db: ConceptDB, - model_db: ConceptModelDB, - logistic_models: dict[DraftId, LogisticEmbeddingModel] = {}) -> ConceptModel: - namespace = 'test' - concept_name = 'test_concept' - concept_db.create(namespace=namespace, name=concept_name, type=ConceptType.TEXT) - - train_data = [ - ExampleIn(label=False, text='not in concept'), - ExampleIn(label=True, text='in concept') - ] - concept_db.edit(namespace, concept_name, ConceptUpdate(insert=train_data)) - model = model_db.create(namespace, concept_name, embedding_name='test_embedding') - model._logistic_models = logistic_models - model_db._save(model) - return model - - -class TestLogisticModel(LogisticEmbeddingModel): - - @override - def score_embeddings(self, embeddings: np.ndarray) -> np.ndarray: - """Get the scores for the provided embeddings.""" - return np.array([.1]) - - @override - def fit(self, embeddings: np.ndarray, labels: list[bool]) -> None: - pass - - -@pytest.mark.parametrize('concept_db_cls', ALL_CONCEPT_DBS) -@pytest.mark.parametrize('model_db_cls', ALL_CONCEPT_MODEL_DBS) -class ConceptModelDBSuite: - - def test_save_and_get_model(self, concept_db_cls: Type[ConceptDB], - model_db_cls: Type[ConceptModelDB]) -> None: - concept_db = concept_db_cls() - model_db = model_db_cls(concept_db) - model = _make_test_concept_model(concept_db, model_db) - model = model_db.sync(model.namespace, model.concept_name, model.embedding_name) - retrieved_model = model_db.get( - namespace='test', concept_name='test_concept', embedding_name='test_embedding') - if not retrieved_model: - retrieved_model = model_db.create( - namespace='test', concept_name='test_concept', embedding_name='test_embedding') - assert retrieved_model.namespace == model.namespace - assert retrieved_model.concept_name == model.concept_name - assert 
retrieved_model.embedding_name == model.embedding_name - assert retrieved_model.version == model.version - - def test_sync_model(self, concept_db_cls: Type[ConceptDB], model_db_cls: Type[ConceptModelDB], - mocker: MockerFixture) -> None: - - concept_db = concept_db_cls() - model_db = model_db_cls(concept_db) - logistic_model = TestLogisticModel() - score_embeddings_mock = mocker.spy(TestLogisticModel, 'score_embeddings') - fit_mock = mocker.spy(TestLogisticModel, 'fit') - - model = _make_test_concept_model( - concept_db, model_db, logistic_models={DRAFT_MAIN: logistic_model}) - - assert model_db.in_sync(model) is False - assert score_embeddings_mock.call_count == 0 - assert fit_mock.call_count == 0 - - model = model_db.sync(model.namespace, model.concept_name, model.embedding_name) - - assert model_db.in_sync(model) is True - assert score_embeddings_mock.call_count == 0 - assert fit_mock.call_count == 1 - - def test_out_of_sync_model(self, concept_db_cls: Type[ConceptDB], - model_db_cls: Type[ConceptModelDB], mocker: MockerFixture) -> None: - concept_db = concept_db_cls() - model_db = model_db_cls(concept_db) - score_embeddings_mock = mocker.spy(TestLogisticModel, 'score_embeddings') - fit_mock = mocker.spy(TestLogisticModel, 'fit') - logistic_model = TestLogisticModel() - model = _make_test_concept_model( - concept_db, model_db, logistic_models={DRAFT_MAIN: logistic_model}) - model = model_db.sync(model.namespace, model.concept_name, model.embedding_name) - assert model_db.in_sync(model) is True - assert score_embeddings_mock.call_count == 0 - assert fit_mock.call_count == 1 - - (called_model, called_embeddings, called_labels) = fit_mock.call_args_list[-1].args - assert called_model == logistic_model - np.testing.assert_array_equal( - called_embeddings, - normalize(np.array([EMBEDDING_MAP['not in concept'], EMBEDDING_MAP['in concept']]))) - assert called_labels == [False, True] - - # Edit the concept. - concept_db.edit('test', 'test_concept', - ConceptUpdate(insert=[ExampleIn(label=False, text='a new data point')])) - - # Make sure the model is out of sync. - assert model_db.in_sync(model) is False - assert score_embeddings_mock.call_count == 0 - assert fit_mock.call_count == 1 - - model = model_db.sync(model.namespace, model.concept_name, model.embedding_name) - assert model_db.in_sync(model) is True - assert score_embeddings_mock.call_count == 0 - assert fit_mock.call_count == 2 - # Fit is called again with new points on main only. 
- (called_model, called_embeddings, called_labels) = fit_mock.call_args_list[-1].args - assert called_model == logistic_model - np.testing.assert_array_equal( - called_embeddings, - normalize( - np.array([ - EMBEDDING_MAP['not in concept'], EMBEDDING_MAP['in concept'], - EMBEDDING_MAP['a new data point'] - ]))) - assert called_labels == [False, True, False] - - def test_out_of_sync_draft_model(self, concept_db_cls: Type[ConceptDB], - model_db_cls: Type[ConceptModelDB], - mocker: MockerFixture) -> None: - concept_db = concept_db_cls() - model_db = model_db_cls(concept_db) - score_embeddings_mock = mocker.spy(TestLogisticModel, 'score_embeddings') - fit_mock = mocker.spy(TestLogisticModel, 'fit') - main_model = TestLogisticModel() - draft_model = TestLogisticModel() - model = _make_test_concept_model( - concept_db, model_db, logistic_models={ - DRAFT_MAIN: main_model, - 'test_draft': draft_model - }) - model = model_db.sync(model.namespace, model.concept_name, model.embedding_name) - assert model_db.in_sync(model) is True - assert score_embeddings_mock.call_count == 0 - assert fit_mock.call_count == 1 - - # Make sure drafts cause the model to be out of sync. - concept_db.edit( - 'test', - 'test_concept', - ConceptUpdate(insert=[ - ExampleIn(label=True, text='a true draft point', draft='test_draft'), - ExampleIn(label=False, text='a false draft point', draft='test_draft'), - # This point exists in main, but we switched the label. - ExampleIn(label=False, text='in concept', draft='test_draft'), - ])) - - # Make sure the model is out of sync. - assert model_db.in_sync(model) is False - assert score_embeddings_mock.call_count == 0 - assert fit_mock.call_count == 1 - - model = model_db.sync(model.namespace, model.concept_name, model.embedding_name) - assert model_db.in_sync(model) is True - assert score_embeddings_mock.call_count == 0 - assert fit_mock.call_count == 3 # Fit is called on both the draft, and main. - - # Fit is called again with the same points. - ((called_model, called_embeddings, called_labels), - (called_draft_model, called_draft_embeddings, called_draft_labels)) = ( - c.args for c in fit_mock.call_args_list[-2:]) - - # The draft model is called with the data from main, and the data from draft. - assert called_draft_model == draft_model - np.testing.assert_array_equal( - called_draft_embeddings, - normalize( - np.array([ - EMBEDDING_MAP['a true draft point'], EMBEDDING_MAP['a false draft point'], - EMBEDDING_MAP['in concept'], EMBEDDING_MAP['not in concept'] - ]))) - assert called_draft_labels == [ - True, - False, - # This was overriden by the draft. - False, - False - ] - - # The main model was fit without the data from the draft. - assert called_model == main_model - np.testing.assert_array_equal( - called_embeddings, - normalize(np.array([EMBEDDING_MAP['not in concept'], EMBEDDING_MAP['in concept']]))) - assert called_labels == [False, True] - - def test_embedding_not_found_in_map(self, concept_db_cls: Type[ConceptDB], - model_db_cls: Type[ConceptModelDB]) -> None: - concept_db = concept_db_cls() - model_db = model_db_cls(concept_db) - model = _make_test_concept_model(concept_db, model_db) - model = model_db.sync(model.namespace, model.concept_name, model.embedding_name) - - # Edit the concept. - concept_db.edit('test', 'test_concept', - ConceptUpdate(insert=[ExampleIn(label=False, text='unknown text')])) - - # Make sure the model is out of sync. 
- assert model_db.in_sync(model) is False - - with pytest.raises(ValueError, match='Example "unknown text" not in embedding map'): - model_db.sync(model.namespace, model.concept_name, model.embedding_name) diff --git a/lilac/data/__pycache__/__init__.cpython-39.pyc b/lilac/data/__pycache__/__init__.cpython-39.pyc deleted file mode 100644 index a2482cd5825af00ea0d09990195a01cee8022e52..0000000000000000000000000000000000000000 Binary files a/lilac/data/__pycache__/__init__.cpython-39.pyc and /dev/null differ diff --git a/lilac/data/__pycache__/dataset.cpython-39.pyc b/lilac/data/__pycache__/dataset.cpython-39.pyc deleted file mode 100644 index 23a8cc240afab0b94c937507773a67c2c8ff0d1e..0000000000000000000000000000000000000000 Binary files a/lilac/data/__pycache__/dataset.cpython-39.pyc and /dev/null differ diff --git a/lilac/data/__pycache__/dataset_compute_signal_chain_test.cpython-39-pytest-7.4.0.pyc b/lilac/data/__pycache__/dataset_compute_signal_chain_test.cpython-39-pytest-7.4.0.pyc deleted file mode 100644 index 77b6e152a814291f96d53d2fe59ee0a20c08fcae..0000000000000000000000000000000000000000 Binary files a/lilac/data/__pycache__/dataset_compute_signal_chain_test.cpython-39-pytest-7.4.0.pyc and /dev/null differ diff --git a/lilac/data/__pycache__/dataset_compute_signal_test.cpython-39-pytest-7.4.0.pyc b/lilac/data/__pycache__/dataset_compute_signal_test.cpython-39-pytest-7.4.0.pyc deleted file mode 100644 index 62f6f6f2283c6bb39d74d038f3e2c35225859b8b..0000000000000000000000000000000000000000 Binary files a/lilac/data/__pycache__/dataset_compute_signal_test.cpython-39-pytest-7.4.0.pyc and /dev/null differ diff --git a/lilac/data/__pycache__/dataset_config_test.cpython-39-pytest-7.4.0.pyc b/lilac/data/__pycache__/dataset_config_test.cpython-39-pytest-7.4.0.pyc deleted file mode 100644 index cb7e9efeee76894d7d22917985e9b88624cdbced..0000000000000000000000000000000000000000 Binary files a/lilac/data/__pycache__/dataset_config_test.cpython-39-pytest-7.4.0.pyc and /dev/null differ diff --git a/lilac/data/__pycache__/dataset_duckdb.cpython-39.pyc b/lilac/data/__pycache__/dataset_duckdb.cpython-39.pyc deleted file mode 100644 index 1f335238d4c9d54c574b6a18b68d7ab08cfc26c9..0000000000000000000000000000000000000000 Binary files a/lilac/data/__pycache__/dataset_duckdb.cpython-39.pyc and /dev/null differ diff --git a/lilac/data/__pycache__/dataset_export_test.cpython-39-pytest-7.4.0.pyc b/lilac/data/__pycache__/dataset_export_test.cpython-39-pytest-7.4.0.pyc deleted file mode 100644 index 481072432490c9f5f03100dd603c0cf0c5c4e94f..0000000000000000000000000000000000000000 Binary files a/lilac/data/__pycache__/dataset_export_test.cpython-39-pytest-7.4.0.pyc and /dev/null differ diff --git a/lilac/data/__pycache__/dataset_select_groups_test.cpython-39-pytest-7.4.0.pyc b/lilac/data/__pycache__/dataset_select_groups_test.cpython-39-pytest-7.4.0.pyc deleted file mode 100644 index 97e04a06837a131317d9ee2554b6d0815027d955..0000000000000000000000000000000000000000 Binary files a/lilac/data/__pycache__/dataset_select_groups_test.cpython-39-pytest-7.4.0.pyc and /dev/null differ diff --git a/lilac/data/__pycache__/dataset_select_rows_filter_test.cpython-39-pytest-7.4.0.pyc b/lilac/data/__pycache__/dataset_select_rows_filter_test.cpython-39-pytest-7.4.0.pyc deleted file mode 100644 index 062dfb9c3dfe0d4c4b4e85af5f3f0c9cb77ff0bd..0000000000000000000000000000000000000000 Binary files a/lilac/data/__pycache__/dataset_select_rows_filter_test.cpython-39-pytest-7.4.0.pyc and /dev/null differ diff --git 
a/lilac/data/__pycache__/dataset_select_rows_schema_test.cpython-39-pytest-7.4.0.pyc b/lilac/data/__pycache__/dataset_select_rows_schema_test.cpython-39-pytest-7.4.0.pyc deleted file mode 100644 index 4da074e732cd4f55431cc60f37004a4727b462f4..0000000000000000000000000000000000000000 Binary files a/lilac/data/__pycache__/dataset_select_rows_schema_test.cpython-39-pytest-7.4.0.pyc and /dev/null differ diff --git a/lilac/data/__pycache__/dataset_select_rows_search_test.cpython-39-pytest-7.4.0.pyc b/lilac/data/__pycache__/dataset_select_rows_search_test.cpython-39-pytest-7.4.0.pyc deleted file mode 100644 index d938b35b2a64dd56c27fc9586559a109b4adc492..0000000000000000000000000000000000000000 Binary files a/lilac/data/__pycache__/dataset_select_rows_search_test.cpython-39-pytest-7.4.0.pyc and /dev/null differ diff --git a/lilac/data/__pycache__/dataset_select_rows_sort_test.cpython-39-pytest-7.4.0.pyc b/lilac/data/__pycache__/dataset_select_rows_sort_test.cpython-39-pytest-7.4.0.pyc deleted file mode 100644 index 6b7990dcc80a36c0a1c5653beb23b2507920ea75..0000000000000000000000000000000000000000 Binary files a/lilac/data/__pycache__/dataset_select_rows_sort_test.cpython-39-pytest-7.4.0.pyc and /dev/null differ diff --git a/lilac/data/__pycache__/dataset_select_rows_udf_test.cpython-39-pytest-7.4.0.pyc b/lilac/data/__pycache__/dataset_select_rows_udf_test.cpython-39-pytest-7.4.0.pyc deleted file mode 100644 index 7daa3c8cb03111767a05c04c206b179a4d7bf4fe..0000000000000000000000000000000000000000 Binary files a/lilac/data/__pycache__/dataset_select_rows_udf_test.cpython-39-pytest-7.4.0.pyc and /dev/null differ diff --git a/lilac/data/__pycache__/dataset_stats_test.cpython-39-pytest-7.4.0.pyc b/lilac/data/__pycache__/dataset_stats_test.cpython-39-pytest-7.4.0.pyc deleted file mode 100644 index aa5527bf96b542458b016df29a894a00846fc071..0000000000000000000000000000000000000000 Binary files a/lilac/data/__pycache__/dataset_stats_test.cpython-39-pytest-7.4.0.pyc and /dev/null differ diff --git a/lilac/data/__pycache__/dataset_test.cpython-39-pytest-7.4.0.pyc b/lilac/data/__pycache__/dataset_test.cpython-39-pytest-7.4.0.pyc deleted file mode 100644 index 080deea825ca2244fe6496419121fa1ae0118331..0000000000000000000000000000000000000000 Binary files a/lilac/data/__pycache__/dataset_test.cpython-39-pytest-7.4.0.pyc and /dev/null differ diff --git a/lilac/data/__pycache__/dataset_test_utils.cpython-39.pyc b/lilac/data/__pycache__/dataset_test_utils.cpython-39.pyc deleted file mode 100644 index 226727a35c8a741f9a84e4341a0047fc1d523068..0000000000000000000000000000000000000000 Binary files a/lilac/data/__pycache__/dataset_test_utils.cpython-39.pyc and /dev/null differ diff --git a/lilac/data/__pycache__/dataset_utils.cpython-39.pyc b/lilac/data/__pycache__/dataset_utils.cpython-39.pyc deleted file mode 100644 index 1ed13b5e27f15bd45ef2c872133119b49130ff16..0000000000000000000000000000000000000000 Binary files a/lilac/data/__pycache__/dataset_utils.cpython-39.pyc and /dev/null differ diff --git a/lilac/data/__pycache__/dataset_utils_test.cpython-39-pytest-7.4.0.pyc b/lilac/data/__pycache__/dataset_utils_test.cpython-39-pytest-7.4.0.pyc deleted file mode 100644 index 6dfd55b9b2a5d1df17096f15aba400363d3878bc..0000000000000000000000000000000000000000 Binary files a/lilac/data/__pycache__/dataset_utils_test.cpython-39-pytest-7.4.0.pyc and /dev/null differ diff --git a/lilac/data/dataset_compute_signal_chain_test.py b/lilac/data/dataset_compute_signal_chain_test.py deleted file mode 100644 index 
9fec9caa48c075fd2b5e521fb20581f6ece583e4..0000000000000000000000000000000000000000 --- a/lilac/data/dataset_compute_signal_chain_test.py +++ /dev/null @@ -1,205 +0,0 @@ -"""Tests for dataset.compute_signal() when signals are chained.""" - -import re -from typing import Iterable, List, Optional, cast - -import numpy as np -import pytest -from pytest_mock import MockerFixture -from typing_extensions import override - -from ..embeddings.vector_store import VectorDBIndex -from ..schema import ( - EMBEDDING_KEY, - Field, - Item, - PathKey, - RichData, - SignalInputType, - field, - lilac_embedding, - lilac_span, - schema, -) -from ..signal import ( - TextEmbeddingSignal, - TextSignal, - TextSplitterSignal, - VectorSignal, - clear_signal_registry, - register_signal, -) -from .dataset import DatasetManifest -from .dataset_test_utils import TEST_DATASET_NAME, TEST_NAMESPACE, TestDataMaker, enriched_item - -SIMPLE_ITEMS: list[Item] = [{ - 'str': 'a', - 'int': 1, - 'bool': False, - 'float': 3.0 -}, { - 'str': 'b', - 'int': 2, - 'bool': True, - 'float': 2.0 -}, { - 'str': 'b', - 'int': 2, - 'bool': True, - 'float': 1.0 -}] - -EMBEDDINGS: list[tuple[str, list[float]]] = [('hello.', [1.0, 0.0, 0.0]), - ('hello2.', [1.0, 1.0, 0.0]), - ('hello world.', [1.0, 1.0, 1.0]), - ('hello world2.', [2.0, 1.0, 1.0])] - -STR_EMBEDDINGS: dict[str, list[float]] = {text: embedding for text, embedding in EMBEDDINGS} - - -class TestSplitter(TextSplitterSignal): - """Split documents into sentence by splitting on period.""" - name = 'test_splitter' - - @override - def compute(self, data: Iterable[RichData]) -> Iterable[Item]: - for text in data: - if not isinstance(text, str): - raise ValueError(f'Expected text to be a string, got {type(text)} instead.') - sentences = [f'{sentence.strip()}.' for sentence in text.split('.') if sentence] - yield [ - lilac_span(text.index(sentence), - text.index(sentence) + len(sentence)) for sentence in sentences - ] - - -class TestEmbedding(TextEmbeddingSignal): - """A test embed function.""" - name = 'test_embedding' - - @override - def compute(self, data: Iterable[RichData]) -> Iterable[Item]: - """Call the embedding function.""" - for example in data: - yield [lilac_embedding(0, len(example), np.array(STR_EMBEDDINGS[cast(str, example)]))] - - -class TestEmbeddingSumSignal(VectorSignal): - """Sums the embeddings to return a single floating point value.""" - name = 'test_embedding_sum' - input_type = SignalInputType.TEXT - - @override - def fields(self) -> Field: - return field('float32') - - @override - def vector_compute(self, keys: Iterable[PathKey], vector_index: VectorDBIndex) -> Iterable[Item]: - # The signal just sums the values of the embedding. - all_vector_spans = vector_index.get(keys) - for vector_spans in all_vector_spans: - yield vector_spans[0]['vector'].sum() - - -@pytest.fixture(scope='module', autouse=True) -def setup_teardown() -> Iterable[None]: - # Setup. - register_signal(TestSplitter) - register_signal(TestEmbedding) - register_signal(TestEmbeddingSumSignal) - register_signal(NamedEntity) - # Unit test runs. - yield - # Teardown. 
- clear_signal_registry() - - -def test_manual_embedding_signal(make_test_data: TestDataMaker, mocker: MockerFixture) -> None: - dataset = make_test_data([{'text': 'hello.'}, {'text': 'hello2.'}]) - - embed_mock = mocker.spy(TestEmbedding, 'compute') - dataset.compute_embedding('test_embedding', 'text') - embedding_sum_signal = TestEmbeddingSumSignal(embedding='test_embedding') - dataset.compute_signal(embedding_sum_signal, 'text') - - # Make sure the embedding signal is not called twice. - assert embed_mock.call_count == 1 - - assert dataset.manifest() == DatasetManifest( - namespace=TEST_NAMESPACE, - dataset_name=TEST_DATASET_NAME, - data_schema=schema({ - 'text': field( - 'string', - fields={ - 'test_embedding_sum(embedding=test_embedding)': field( - 'float32', signal=embedding_sum_signal.dict()), - 'test_embedding': field( - signal=TestEmbedding().dict(), - fields=[field('string_span', fields={EMBEDDING_KEY: 'embedding'})]), - }), - }), - num_items=2) - - result = dataset.select_rows(combine_columns=True) - expected_result = [{ - 'text': enriched_item('hello.', {'test_embedding_sum(embedding=test_embedding)': 1.0}) - }, { - 'text': enriched_item('hello2.', {'test_embedding_sum(embedding=test_embedding)': 2.0}) - }] - assert list(result) == expected_result - - -def test_missing_embedding_signal(make_test_data: TestDataMaker, mocker: MockerFixture) -> None: - dataset = make_test_data([{ - 'text': 'hello.', - }, { - 'text': 'hello2.', - }]) - - # The embedding is missing for 'text'. - embedding_sum_signal = TestEmbeddingSumSignal(embedding=TestEmbedding.name) - with pytest.raises(ValueError, match="No embedding found for path \\('text',\\)"): - dataset.compute_signal(embedding_sum_signal, 'text') - - -ENTITY_REGEX = r'[A-Za-z]+@[A-Za-z]+' - - -class NamedEntity(TextSignal): - """Find special entities.""" - name = 'entity' - - @override - def fields(self) -> Field: - return field(fields=['string_span']) - - @override - def compute(self, data: Iterable[RichData]) -> Iterable[Optional[List[Item]]]: - for text in data: - if not isinstance(text, str): - yield None - continue - yield [lilac_span(m.start(0), m.end(0)) for m in re.finditer(ENTITY_REGEX, text)] - - -def test_entity_on_split_signal(make_test_data: TestDataMaker) -> None: - text = 'Hello nik@test. Here are some other entities like pii@gmail and all@lilac.' 
- dataset = make_test_data([{'text': text}]) - entity = NamedEntity() - dataset.compute_signal(TestSplitter(), 'text') - dataset.compute_signal(entity, ('text', 'test_splitter', '*')) - - result = dataset.select_rows(['text'], combine_columns=True) - assert list(result) == [{ - 'text': enriched_item( - text, { - 'test_splitter': [ - lilac_span(0, 15, {'entity': [lilac_span(6, 14)]}), - lilac_span(16, 74, {'entity': [ - lilac_span(50, 59), - lilac_span(64, 73), - ]}), - ] - }) - }] diff --git a/lilac/data/dataset_compute_signal_test.py b/lilac/data/dataset_compute_signal_test.py deleted file mode 100644 index ec8e4a06c4f686a134c57573c03dc79d8923ec80..0000000000000000000000000000000000000000 --- a/lilac/data/dataset_compute_signal_test.py +++ /dev/null @@ -1,572 +0,0 @@ -"""Tests for dataset.compute_signal().""" - -from typing import Iterable, Optional, Union, cast - -import numpy as np -import pytest -from typing_extensions import override - -from ..concepts.concept import ExampleIn -from ..concepts.db_concept import ConceptUpdate, DiskConceptDB -from ..schema import ( - EMBEDDING_KEY, - Field, - Item, - RichData, - SignalInputType, - field, - lilac_embedding, - lilac_span, - schema, -) -from ..signal import ( - TextEmbeddingSignal, - TextSignal, - TextSplitterSignal, - clear_signal_registry, - register_signal, -) -from ..signals.concept_scorer import ConceptSignal -from .dataset import Column, DatasetManifest, GroupsSortBy, SortOrder -from .dataset_test_utils import TEST_DATASET_NAME, TEST_NAMESPACE, TestDataMaker, enriched_item - -SIMPLE_ITEMS: list[Item] = [{ - 'str': 'a', - 'int': 1, - 'bool': False, - 'float': 3.0 -}, { - 'str': 'b', - 'int': 2, - 'bool': True, - 'float': 2.0 -}, { - 'str': 'b', - 'int': 2, - 'bool': True, - 'float': 1.0 -}] - - -class TestInvalidSignal(TextSignal): - name = 'test_invalid_signal' - - @override - def fields(self) -> Field: - return field('int32') - - @override - def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]: - # Return an invalid output that doesn't match the input length. - return [] - - -class TestSparseSignal(TextSignal): - name = 'test_sparse_signal' - - @override - def fields(self) -> Field: - return field('int32') - - @override - def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]: - for text in data: - if text == 'hello': - # Skip this input. - yield None - else: - yield len(text) - - -class TestSparseRichSignal(TextSignal): - """Find personally identifiable information (emails, phone numbers, etc).""" - name = 'test_sparse_rich_signal' - - @override - def fields(self) -> Field: - return field(fields={'emails': ['string']}) - - @override - def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]: - for text in data: - if text == 'hello': - # Skip this input. 
- yield None - else: - yield {'emails': ['test1@hello.com', 'test2@hello.com']} - - -class TestParamSignal(TextSignal): - name = 'param_signal' - param: str - - def fields(self) -> Field: - return field('string') - - def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]: - for text_content in data: - yield f'{str(text_content)}_{self.param}' - - -class TestSignal(TextSignal): - name = 'test_signal' - - @override - def fields(self) -> Field: - return field(fields={'len': 'int32', 'flen': 'float32'}) - - @override - def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]: - return [{'len': len(text_content), 'flen': float(len(text_content))} for text_content in data] - - -class TestSplitSignal(TextSplitterSignal): - """Split documents into sentence by splitting on period, generating entities.""" - name = 'test_split' - - @override - def compute(self, data: Iterable[RichData]) -> Iterable[Item]: - for text in data: - if not isinstance(text, str): - raise ValueError(f'Expected text to be a string, got {type(text)} instead.') - sentences = [f'{sentence.strip()}.' for sentence in text.split('.') if sentence] - yield [ - lilac_span(text.index(sentence), - text.index(sentence) + len(sentence)) for sentence in sentences - ] - - -EMBEDDINGS: list[tuple[str, Union[list[float], list[list[float]]]]] = [ - ('hello.', [1.0, 0.0, 0.0]), - # This embedding has an outer dimension of 1. - ('hello2.', [[1.0, 1.0, 0.0]]), - ('hello3.', [[0, 0, 1.]]) -] - -STR_EMBEDDINGS: dict[str, Union[list[float], list[list[float]]]] = { - text: embedding for text, embedding in EMBEDDINGS -} - - -class TestEmbedding(TextEmbeddingSignal): - """A test embed function.""" - name = 'test_embedding' - - @override - def compute(self, data: Iterable[RichData]) -> Iterable[Item]: - """Call the embedding function.""" - for example in data: - example = cast(str, example) - yield [lilac_embedding(0, len(example), np.array(STR_EMBEDDINGS[example]))] - - -class ComputedKeySignal(TextSignal): - name = 'computed_key' - - @override - def fields(self) -> Field: - return field('int64') - - @override - def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]: - for text in data: - yield 1 - - def key(self, is_computed_signal: Optional[bool] = False) -> str: - return f'key_{is_computed_signal}' - - -@pytest.fixture(scope='module', autouse=True) -def setup_teardown() -> Iterable[None]: - # Setup. - clear_signal_registry() - register_signal(TestSparseSignal) - register_signal(TestSparseRichSignal) - register_signal(TestParamSignal) - register_signal(TestSignal) - register_signal(TestSplitSignal) - register_signal(TestEmbedding) - register_signal(ComputedKeySignal) - register_signal(ConceptSignal) - - # Unit test runs. - yield - # Teardown. 
- clear_signal_registry() - - -def test_signal_output_validation(make_test_data: TestDataMaker) -> None: - signal = TestInvalidSignal() - - dataset = make_test_data([{ - 'text': 'hello', - }, { - 'text': 'hello world', - }]) - - with pytest.raises( - ValueError, match='The signal generated 0 values but the input data had 2 values.'): - dataset.compute_signal(signal, 'text') - - -def test_sparse_signal(make_test_data: TestDataMaker) -> None: - dataset = make_test_data([{ - 'text': 'hello', - }, { - 'text': 'hello world', - }]) - - dataset.compute_signal(TestSparseSignal(), 'text') - - result = dataset.select_rows(['text'], combine_columns=True) - assert list(result) == [{ - 'text': enriched_item('hello', {'test_sparse_signal': None}) - }, { - 'text': enriched_item('hello world', {'test_sparse_signal': 11}) - }] - - -def test_sparse_rich_signal(make_test_data: TestDataMaker) -> None: - dataset = make_test_data([{ - 'text': 'hello', - }, { - 'text': 'hello world', - }]) - - dataset.compute_signal(TestSparseRichSignal(), 'text') - - result = dataset.select_rows(['text'], combine_columns=True) - assert list(result) == [{ - 'text': enriched_item('hello', {'test_sparse_rich_signal': None}) - }, { - 'text': enriched_item( - 'hello world', - {'test_sparse_rich_signal': { - 'emails': ['test1@hello.com', 'test2@hello.com'] - }}) - }] - - -def test_source_joined_with_signal(make_test_data: TestDataMaker) -> None: - dataset = make_test_data(SIMPLE_ITEMS) - assert dataset.manifest() == DatasetManifest( - namespace=TEST_NAMESPACE, - dataset_name=TEST_DATASET_NAME, - data_schema=schema({ - 'str': 'string', - 'int': 'int32', - 'bool': 'boolean', - 'float': 'float32', - }), - num_items=3) - - test_signal = TestSignal() - dataset.compute_signal(test_signal, 'str') - - # Check the enriched dataset manifest has 'text' enriched. - assert dataset.manifest() == DatasetManifest( - namespace=TEST_NAMESPACE, - dataset_name=TEST_DATASET_NAME, - data_schema=schema({ - 'str': field( - 'string', - fields={ - 'test_signal': field( - signal=test_signal.dict(), fields={ - 'len': 'int32', - 'flen': 'float32' - }), - }), - 'int': 'int32', - 'bool': 'boolean', - 'float': 'float32', - }), - num_items=3) - - result = dataset.select_rows(['str'], combine_columns=True) - assert list(result) == [{ - 'str': enriched_item('a', {'test_signal': { - 'len': 1, - 'flen': 1.0 - }}), - }, { - 'str': enriched_item('b', {'test_signal': { - 'len': 1, - 'flen': 1.0 - }}), - }, { - 'str': enriched_item('b', {'test_signal': { - 'len': 1, - 'flen': 1.0 - }}), - }] - - # Select a specific signal leaf test_signal.flen with 'str'. - result = dataset.select_rows(['str', ('str', 'test_signal', 'flen')]) - - assert list(result) == [{ - 'str': 'a', - 'str.test_signal.flen': 1.0 - }, { - 'str': 'b', - 'str.test_signal.flen': 1.0 - }, { - 'str': 'b', - 'str.test_signal.flen': 1.0 - }] - - # Select multiple signal leafs with aliasing. 
- result = dataset.select_rows([ - 'str', - Column(('str', 'test_signal', 'flen'), alias='flen'), - Column(('str', 'test_signal', 'len'), alias='len') - ]) - - assert list(result) == [{ - 'str': 'a', - 'flen': 1.0, - 'len': 1 - }, { - 'str': 'b', - 'flen': 1.0, - 'len': 1 - }, { - 'str': 'b', - 'flen': 1.0, - 'len': 1 - }] - - -def test_parameterized_signal(make_test_data: TestDataMaker) -> None: - dataset = make_test_data([{'text': 'hello'}, {'text': 'everybody'}]) - test_signal_a = TestParamSignal(param='a') - test_signal_b = TestParamSignal(param='b') - dataset.compute_signal(test_signal_a, 'text') - dataset.compute_signal(test_signal_b, 'text') - - assert dataset.manifest() == DatasetManifest( - namespace=TEST_NAMESPACE, - dataset_name=TEST_DATASET_NAME, - data_schema=schema({ - 'text': field( - 'string', - fields={ - 'param_signal(param=a)': field('string', test_signal_a.dict()), - 'param_signal(param=b)': field('string', test_signal_b.dict()), - }), - }), - num_items=2) - - result = dataset.select_rows(['text'], combine_columns=True) - assert list(result) == [{ - 'text': enriched_item('hello', { - 'param_signal(param=a)': 'hello_a', - 'param_signal(param=b)': 'hello_b', - }) - }, { - 'text': enriched_item('everybody', { - 'param_signal(param=a)': 'everybody_a', - 'param_signal(param=b)': 'everybody_b', - }) - }] - - -def test_split_signal(make_test_data: TestDataMaker) -> None: - dataset = make_test_data([{ - 'text': '[1, 1] first sentence. [1, 1] second sentence.', - }, { - 'text': 'b2 [2, 1] first sentence. [2, 1] second sentence.', - }]) - - signal = TestSplitSignal() - dataset.compute_signal(signal, 'text') - - assert dataset.manifest() == DatasetManifest( - namespace=TEST_NAMESPACE, - dataset_name=TEST_DATASET_NAME, - data_schema=schema({ - 'text': field( - 'string', fields={'test_split': field(signal=signal.dict(), fields=[field('string_span')])}) - }), - num_items=2) - - result = dataset.select_rows(['text'], combine_columns=True) - expected_result = [{ - 'text': enriched_item('[1, 1] first sentence. [1, 1] second sentence.', - {'test_split': [lilac_span(0, 22), lilac_span(23, 46)]}) - }, { - 'text': enriched_item('b2 [2, 1] first sentence. [2, 1] second sentence.', - {'test_split': [ - lilac_span(0, 25), - lilac_span(26, 49), - ]}) - }] - assert list(result) == expected_result - - -def test_signal_on_repeated_field(make_test_data: TestDataMaker) -> None: - dataset = make_test_data([{ - 'text': ['hello', 'everybody'], - }, { - 'text': ['hello2', 'everybody2'], - }]) - test_signal = TestSignal() - # Run the signal on the repeated field. - dataset.compute_signal(test_signal, ('text', '*')) - - # Check the enriched dataset manifest has 'text' enriched. 
- assert dataset.manifest() == DatasetManifest( - namespace=TEST_NAMESPACE, - dataset_name=TEST_DATASET_NAME, - data_schema=schema({ - 'text': field(fields=[ - field( - 'string', - fields={ - 'test_signal': field( - signal=test_signal.dict(), fields={ - 'len': 'int32', - 'flen': 'float32' - }) - }) - ]) - }), - num_items=2) - - result = dataset.select_rows([('text', '*')], combine_columns=True) - - assert list(result) == [{ - 'text': [ - enriched_item('hello', {'test_signal': { - 'len': 5, - 'flen': 5.0 - }}), - enriched_item('everybody', {'test_signal': { - 'len': 9, - 'flen': 9.0 - }}) - ] - }, { - 'text': [ - enriched_item('hello2', {'test_signal': { - 'len': 6, - 'flen': 6.0 - }}), - enriched_item('everybody2', {'test_signal': { - 'len': 10, - 'flen': 10.0 - }}) - ] - }] - - -def test_text_splitter(make_test_data: TestDataMaker) -> None: - dataset = make_test_data([{ - 'text': '[1, 1] first sentence. [1, 1] second sentence.', - }, { - 'text': 'b2 [2, 1] first sentence. [2, 1] second sentence.', - }]) - - dataset.compute_signal(TestSplitSignal(), 'text') - - result = dataset.select_rows(['text'], combine_columns=True) - expected_result = [{ - 'text': enriched_item('[1, 1] first sentence. [1, 1] second sentence.', - {'test_split': [ - lilac_span(0, 22), - lilac_span(23, 46), - ]}), - }, { - 'text': enriched_item('b2 [2, 1] first sentence. [2, 1] second sentence.', - {'test_split': [ - lilac_span(0, 25), - lilac_span(26, 49), - ]}), - }] - assert list(result) == expected_result - - -def test_embedding_signal(make_test_data: TestDataMaker) -> None: - dataset = make_test_data([{'text': 'hello.'}, {'text': 'hello2.'}]) - - embedding_signal = TestEmbedding() - dataset.compute_signal(embedding_signal, 'text') - - assert dataset.manifest() == DatasetManifest( - namespace=TEST_NAMESPACE, - dataset_name=TEST_DATASET_NAME, - data_schema=schema({ - 'text': field( - 'string', - fields={ - 'test_embedding': field( - signal=embedding_signal.dict(), - fields=[field('string_span', fields={EMBEDDING_KEY: 'embedding'})]) - }), - }), - num_items=2) - - result = dataset.select_rows(combine_columns=True) - expected_result = [{'text': 'hello.'}, {'text': 'hello2.'}] - assert list(result) == expected_result - - -def test_is_computed_signal_key(make_test_data: TestDataMaker) -> None: - dataset = make_test_data([{'text': 'hello.'}, {'text': 'hello2.'}]) - - signal = ComputedKeySignal() - dataset.compute_signal(signal, 'text') - - assert dataset.manifest() == DatasetManifest( - namespace=TEST_NAMESPACE, - dataset_name=TEST_DATASET_NAME, - data_schema=schema({ - 'text': field('string', fields={'key_True': field('int64', signal=signal.dict())}), - }), - num_items=2) - - result = dataset.select_rows(combine_columns=True) - - expected_result = [{ - 'text': enriched_item('hello.', {'key_True': 1}) - }, { - 'text': enriched_item('hello2.', {'key_True': 1}) - }] - assert list(result) == expected_result - - -def test_concept_signal_with_select_groups(make_test_data: TestDataMaker) -> None: - dataset = make_test_data([{ - 'text': 'hello.', - }, { - 'text': 'hello2.', - }, { - 'text': 'hello3.', - }]) - - embedding_signal = TestEmbedding() - dataset.compute_signal(embedding_signal, 'text') - - concept_db = DiskConceptDB() - concept_db.create(namespace='test_namespace', name='test_concept', type=SignalInputType.TEXT) - concept_db.edit( - 'test_namespace', 'test_concept', - ConceptUpdate(insert=[ - ExampleIn(label=False, text='hello.'), - ExampleIn(label=True, text='hello2.'), - ExampleIn(label=False, text='hello3.') - ])) - 
- dataset.compute_concept( - namespace='test_namespace', - concept_name='test_concept', - embedding='test_embedding', - path='text') - - concept_key = 'test_namespace/test_concept/test_embedding/v1' - result = dataset.select_groups(f'text.{concept_key}.*.score') - assert result.counts == [('Not in concept', 2), ('In concept', 1)] - - result = dataset.select_groups( - f'text.{concept_key}.*.score', sort_by=GroupsSortBy.COUNT, sort_order=SortOrder.ASC) - assert result.counts == [('In concept', 1), ('Not in concept', 2)] diff --git a/lilac/data/dataset_config_test.py b/lilac/data/dataset_config_test.py deleted file mode 100644 index b4bcf420c195d312ff55d5c97d3320bc6a5c2124..0000000000000000000000000000000000000000 --- a/lilac/data/dataset_config_test.py +++ /dev/null @@ -1,216 +0,0 @@ -"""Tests for dataset.config().""" - -from typing import Iterable, Optional - -import numpy as np -import pytest -from typing_extensions import override - -from ..config import ( - DatasetConfig, - DatasetSettings, - DatasetUISettings, - EmbeddingConfig, - SignalConfig, -) -from ..schema import Field, Item, RichData, field, lilac_embedding -from ..signal import TextEmbeddingSignal, TextSignal, clear_signal_registry, register_signal -from .dataset_test_utils import TestDataMaker, TestSource - - -class TestSignal(TextSignal): - name = 'test_signal' - - @override - def fields(self) -> Field: - return field(fields={'len': 'int32'}) - - @override - def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]: - return [{'len': len(text_content)} for text_content in data] - - -class TestSignal2(TextSignal): - name = 'test_signal2' - - @override - def fields(self) -> Field: - return field(fields={'len2': 'int32'}) - - @override - def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]: - return [{'len2': len(text_content)} for text_content in data] - - -class TestEmbedding(TextEmbeddingSignal): - """A test embed function.""" - name = 'test_embedding' - - @override - def compute(self, data: Iterable[RichData]) -> Iterable[Item]: - """Call the embedding function.""" - for example in data: - yield [lilac_embedding(0, len(example), np.array([1.]))] - - -class TestEmbedding2(TextEmbeddingSignal): - """A test embed function.""" - name = 'test_embedding2' - - @override - def compute(self, data: Iterable[RichData]) -> Iterable[Item]: - """Call the embedding function.""" - for example in data: - yield [lilac_embedding(0, len(example), np.array([2.]))] - - -@pytest.fixture(scope='module', autouse=True) -def setup_teardown() -> Iterable[None]: - # Setup. - clear_signal_registry() - register_signal(TestEmbedding) - register_signal(TestEmbedding2) - - # Unit test runs. - yield - # Teardown. - clear_signal_registry() - - -def test_config_compute_signal(make_test_data: TestDataMaker) -> None: - dataset = make_test_data([{ - 'text': 'hello', - }, { - 'text': 'hello world' - }]) - - assert dataset.config() == DatasetConfig( - namespace='test_namespace', - name='test_dataset', - source=TestSource(), - # 'text' is the longest path, so should be set as the default setting. 
- settings=DatasetSettings(ui=DatasetUISettings(media_paths=[('text',)]))).dict() - - dataset.compute_signal(TestSignal(), 'text') - - assert dataset.config() == DatasetConfig( - namespace='test_namespace', - name='test_dataset', - source=TestSource(), - signals=[SignalConfig( - path=('text',), - signal=TestSignal(), - )], - settings=DatasetSettings(ui=DatasetUISettings(media_paths=[('text',)]))).dict() - - # Computing the same signal again should not change the config. - dataset.compute_signal(TestSignal(), 'text') - - assert dataset.config() == DatasetConfig( - namespace='test_namespace', - name='test_dataset', - source=TestSource(), - signals=[SignalConfig( - path=('text',), - signal=TestSignal(), - )], - settings=DatasetSettings(ui=DatasetUISettings(media_paths=[('text',)]))).dict() - - # Computing another signal should add another config. - dataset.compute_signal(TestSignal2(), 'text') - - assert dataset.config() == DatasetConfig( - namespace='test_namespace', - name='test_dataset', - source=TestSource(), - signals=[ - SignalConfig( - path=('text',), - signal=TestSignal(), - ), - SignalConfig( - path=('text',), - signal=TestSignal2(), - ) - ], - settings=DatasetSettings(ui=DatasetUISettings(media_paths=[('text',)]))).dict() - - -def test_config_compute_embedding(make_test_data: TestDataMaker) -> None: - dataset = make_test_data([{'text': 'hello'}, {'text': 'hello world'}]) - - assert dataset.config() == DatasetConfig( - namespace='test_namespace', - name='test_dataset', - source=TestSource(), - # 'text' is the longest path, so should be set as the default setting. - settings=DatasetSettings(ui=DatasetUISettings(media_paths=[('text',)]))).dict() - - dataset.compute_embedding('test_embedding', 'text') - - assert dataset.config() == DatasetConfig( - namespace='test_namespace', - name='test_dataset', - source=TestSource(), - embeddings=[EmbeddingConfig( - path=('text',), - embedding='test_embedding', - )], - settings=DatasetSettings(ui=DatasetUISettings(media_paths=[('text',)]))).dict() - - # Computing the same embedding again should not change the config. - dataset.compute_embedding('test_embedding', 'text') - - assert dataset.config() == DatasetConfig( - namespace='test_namespace', - name='test_dataset', - source=TestSource(), - embeddings=[EmbeddingConfig( - path=('text',), - embedding='test_embedding', - )], - settings=DatasetSettings(ui=DatasetUISettings(media_paths=[('text',)]))).dict() - - # Computing another embedding should add another config. - dataset.compute_embedding('test_embedding2', 'text') - - assert dataset.config() == DatasetConfig( - namespace='test_namespace', - name='test_dataset', - source=TestSource(), - embeddings=[ - EmbeddingConfig( - path=('text',), - embedding='test_embedding', - ), - EmbeddingConfig( - path=('text',), - embedding='test_embedding2', - ) - ], - settings=DatasetSettings(ui=DatasetUISettings(media_paths=[('text',)]))).dict() - - -def test_settings(make_test_data: TestDataMaker) -> None: - dataset = make_test_data([{'text': 'hello'}, {'text': 'hello world'}]) - expected_settings = DatasetSettings(ui=DatasetUISettings(media_paths=[('text',)])) - - # Settings is reflected in the config and the public settings method. - assert dataset.config() == DatasetConfig( - namespace='test_namespace', - name='test_dataset', - source=TestSource(), - settings=expected_settings).dict() - - assert dataset.settings() == expected_settings - - # Settings can only be updated through the public method for updating settings. 
- dataset.update_settings(DatasetSettings(ui=DatasetUISettings(media_paths=[('str',)]))) - - expected_settings = DatasetSettings(ui=DatasetUISettings(media_paths=[('str',)])) - assert dataset.settings() == expected_settings - assert dataset.config() == DatasetConfig( - namespace='test_namespace', - name='test_dataset', - source=TestSource(), - settings=expected_settings).dict() diff --git a/lilac/data/dataset_export_test.py b/lilac/data/dataset_export_test.py deleted file mode 100644 index 8ed5ea3782f19f38e650092f775a1187c307edea..0000000000000000000000000000000000000000 --- a/lilac/data/dataset_export_test.py +++ /dev/null @@ -1,145 +0,0 @@ -"""Implementation-agnostic tests for exporting a dataset.""" - -import csv -import json -import pathlib -from typing import Iterable, Optional - -import numpy as np -import pandas as pd -import pytest -from typing_extensions import override - -from ..schema import ROWID, Field, Item, RichData, field -from ..signal import TextSignal, clear_signal_registry, register_signal -from .dataset_test_utils import TestDataMaker - - -class TestSignal(TextSignal): - name = 'test_signal' - - @override - def fields(self) -> Field: - return field(fields={'len': 'int32', 'flen': 'float32'}) - - @override - def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]: - return [{'len': len(text_content), 'flen': float(len(text_content))} for text_content in data] - - -@pytest.fixture(scope='module', autouse=True) -def setup_teardown() -> Iterable[None]: - # Setup. - clear_signal_registry() - register_signal(TestSignal) - - yield # Unit test runs. - clear_signal_registry() # Teardown. - - -def test_export_to_json(make_test_data: TestDataMaker, tmp_path: pathlib.Path) -> None: - dataset = make_test_data([{'text': 'hello'}, {'text': 'everybody'}]) - dataset.compute_signal(TestSignal(), 'text') - - # Download all columns. - filepath = tmp_path / 'dataset.json' - dataset.to_json(filepath) - - with open(filepath) as f: - parsed_items = [json.loads(line) for line in f.readlines()] - - assert parsed_items == [{ - 'text': 'hello', - 'text.test_signal': { - 'len': 5, - 'flen': 5.0 - } - }, { - 'text': 'everybody', - 'text.test_signal': { - 'len': 9, - 'flen': 9.0 - } - }] - - -def test_export_to_csv(make_test_data: TestDataMaker, tmp_path: pathlib.Path) -> None: - dataset = make_test_data([{'text': 'hello'}, {'text': 'everybody'}]) - dataset.compute_signal(TestSignal(), 'text') - - # Download all columns. - filepath = tmp_path / 'dataset.csv' - dataset.to_csv(filepath) - - with open(filepath) as f: - rows = list(csv.reader(f)) - - assert rows == [ - ['text', 'text.test_signal'], - ['hello', "{'len': 5, 'flen': 5.0}"], - ['everybody', "{'len': 9, 'flen': 9.0}"], - ] - - -def test_export_to_parquet(make_test_data: TestDataMaker, tmp_path: pathlib.Path) -> None: - dataset = make_test_data([{'text': 'hello'}, {'text': 'everybody'}]) - dataset.compute_signal(TestSignal(), 'text') - - # Download all columns. - filepath = tmp_path / 'dataset.parquet' - dataset.to_parquet(filepath) - - df = pd.read_parquet(filepath) - expected_df = pd.DataFrame([{ - 'text': 'hello', - 'text.test_signal': { - 'len': 5, - 'flen': 5.0 - } - }, { - 'text': 'everybody', - 'text.test_signal': { - 'len': 9, - 'flen': 9.0 - } - }]) - pd.testing.assert_frame_equal(df, expected_df) - - -def test_export_to_pandas(make_test_data: TestDataMaker) -> None: - dataset = make_test_data([{'text': 'hello'}, {'text': 'everybody'}]) - dataset.compute_signal(TestSignal(), 'text') - - # Download all columns. 
- df = dataset.to_pandas() - expected_df = pd.DataFrame([{ - 'text': 'hello', - 'text.test_signal': { - 'len': 5, - 'flen': 5.0 - } - }, { - 'text': 'everybody', - 'text.test_signal': { - 'len': 9, - 'flen': 9.0 - } - }]) - pd.testing.assert_frame_equal(df, expected_df) - - # Select only some columns, including pseudocolumn rowid. - df = dataset.to_pandas([ROWID, 'text', 'text.test_signal.flen']) - expected_df = pd.DataFrame([{ - ROWID: '1', - 'text': 'hello', - 'text.test_signal.flen': np.float32(5.0) - }, { - ROWID: '2', - 'text': 'everybody', - 'text.test_signal.flen': np.float32(9.0) - }]) - pd.testing.assert_frame_equal(df, expected_df) - - # Invalid columns. - with pytest.raises(ValueError, match="Unable to select path \\('text', 'test_signal2'\\)"): - dataset.to_pandas(['text', 'text.test_signal2']) diff --git a/lilac/data/dataset_select_groups_test.py b/lilac/data/dataset_select_groups_test.py deleted file mode 100644 index d466f06a4595747e26f8aa58276ec5cb13e42849..0000000000000000000000000000000000000000 --- a/lilac/data/dataset_select_groups_test.py +++ /dev/null @@ -1,345 +0,0 @@ -"""Tests for dataset.select_groups().""" - -import re -from datetime import datetime - -import pytest -from pytest_mock import MockerFixture - -from ..schema import Item, field, schema -from . import dataset as dataset_module -from .dataset_test_utils import TestDataMaker - - -def test_flat_data(make_test_data: TestDataMaker) -> None: - items: list[Item] = [ - { - 'name': 'Name1', - 'age': 34, - 'active': False - }, - { - 'name': 'Name2', - 'age': 45, - 'active': True - }, - { - 'age': 17, - 'active': True - }, # Missing "name". - { - 'name': 'Name3', - 'active': True - }, # Missing "age". - { - 'name': 'Name4', - 'age': 55 - } # Missing "active". - ] - dataset = make_test_data(items) - - result = dataset.select_groups(leaf_path='name') - assert result.counts == [('Name1', 1), ('Name2', 1), (None, 1), ('Name3', 1), ('Name4', 1)] - - result = dataset.select_groups(leaf_path='age', bins=[20, 50, 60]) - assert result.counts == [('1', 2), ('0', 1), (None, 1), ('2', 1)] - - result = dataset.select_groups(leaf_path='active') - assert result.counts == [ - (True, 3), - (False, 1), - (None, 1), # Missing "active". - ] - - -def test_result_counts(make_test_data: TestDataMaker) -> None: - items: list[Item] = [ - { - 'active': False - }, - { - 'active': True - }, - { - 'active': True - }, - { - 'active': True - }, - {} # Missing "active". 
- ] - dataset = make_test_data(items, schema=schema({'active': 'boolean'})) - - result = dataset.select_groups(leaf_path='active') - assert result.counts == [(True, 3), (False, 1), (None, 1)] - - -def test_list_of_structs(make_test_data: TestDataMaker) -> None: - items: list[Item] = [{ - 'list_of_structs': [{ - 'name': 'a' - }, { - 'name': 'b' - }] - }, { - 'list_of_structs': [{ - 'name': 'c' - }, { - 'name': 'a' - }, { - 'name': 'd' - }] - }, { - 'list_of_structs': [{ - 'name': 'd' - }] - }] - dataset = make_test_data(items) - - result = dataset.select_groups(leaf_path='list_of_structs.*.name') - assert result.counts == [('a', 2), ('d', 2), ('b', 1), ('c', 1)] - - -def test_nested_lists(make_test_data: TestDataMaker) -> None: - items: list[Item] = [{ - 'nested_list': [[{ - 'name': 'a' - }], [{ - 'name': 'b' - }]] - }, { - 'nested_list': [[{ - 'name': 'c' - }, { - 'name': 'a' - }], [{ - 'name': 'd' - }]] - }, { - 'nested_list': [[{ - 'name': 'd' - }]] - }] - dataset = make_test_data(items) - - result = dataset.select_groups(leaf_path='nested_list.*.*.name') - assert result.counts == [('a', 2), ('d', 2), ('b', 1), ('c', 1)] - - -def test_nested_struct(make_test_data: TestDataMaker) -> None: - items: list[Item] = [ - { - 'nested_struct': { - 'struct': { - 'name': 'c' - } - } - }, - { - 'nested_struct': { - 'struct': { - 'name': 'b' - } - } - }, - { - 'nested_struct': { - 'struct': { - 'name': 'a' - } - } - }, - ] - dataset = make_test_data(items) - - result = dataset.select_groups(leaf_path='nested_struct.struct.name') - assert result.counts == [('c', 1), ('b', 1), ('a', 1)] - - -def test_named_bins(make_test_data: TestDataMaker) -> None: - items: list[Item] = [{ - 'age': 34, - }, { - 'age': 45, - }, { - 'age': 17, - }, { - 'age': 80 - }, { - 'age': 55 - }, { - 'age': float('nan') - }] - dataset = make_test_data(items) - - result = dataset.select_groups( - leaf_path='age', - bins=[ - ('young', None, 20), - ('adult', 20, 50), - ('middle-aged', 50, 65), - ('senior', 65, None), - ]) - assert result.counts == [('adult', 2), ('young', 1), ('senior', 1), ('middle-aged', 1), (None, 1)] - - -def test_schema_with_bins(make_test_data: TestDataMaker) -> None: - items: list[Item] = [{ - 'age': 34, - }, { - 'age': 45, - }, { - 'age': 17, - }, { - 'age': 80 - }, { - 'age': 55 - }, { - 'age': float('nan') - }] - data_schema = schema({ - 'age': field( - 'float32', - bins=[ - ('young', None, 20), - ('adult', 20, 50), - ('middle-aged', 50, 65), - ('senior', 65, None), - ]) - }) - dataset = make_test_data(items, data_schema) - - result = dataset.select_groups(leaf_path='age') - assert result.counts == [('adult', 2), ('young', 1), ('senior', 1), ('middle-aged', 1), (None, 1)] - - -def test_filters(make_test_data: TestDataMaker) -> None: - items: list[Item] = [ - { - 'name': 'Name1', - 'age': 34, - 'active': False - }, - { - 'name': 'Name2', - 'age': 45, - 'active': True - }, - { - 'age': 17, - 'active': True - }, # Missing "name". - { - 'name': 'Name3', - 'active': True - }, # Missing "age". - { - 'name': 'Name4', - 'age': 55 - } # Missing "active". - ] - dataset = make_test_data(items) - - # active = True. - result = dataset.select_groups(leaf_path='name', filters=[('active', 'equals', True)]) - assert result.counts == [('Name2', 1), (None, 1), ('Name3', 1)] - - # age < 35. - result = dataset.select_groups(leaf_path='name', filters=[('age', 'less', 35)]) - assert result.counts == [('Name1', 1), (None, 1)] - - # age < 35 and active = True. 
- result = dataset.select_groups( - leaf_path='name', filters=[('age', 'less', 35), ('active', 'equals', True)]) - assert result.counts == [(None, 1)] - - -def test_datetime(make_test_data: TestDataMaker) -> None: - items: list[Item] = [ - { - 'id': 1, - 'date': datetime(2023, 1, 1) - }, - { - 'id': 2, - 'date': datetime(2023, 1, 15) - }, - { - 'id': 3, - 'date': datetime(2023, 2, 1) - }, - { - 'id': 4, - 'date': datetime(2023, 3, 1) - }, - { - 'id': 5, - # Missing datetime. - } - ] - dataset = make_test_data(items) - result = dataset.select_groups('date') - assert result.counts == [(datetime(2023, 1, 1), 1), (datetime(2023, 1, 15), 1), - (datetime(2023, 2, 1), 1), (datetime(2023, 3, 1), 1), (None, 1)] - - -def test_invalid_leaf(make_test_data: TestDataMaker) -> None: - items: list[Item] = [ - { - 'nested_struct': { - 'struct': { - 'name': 'c' - } - } - }, - { - 'nested_struct': { - 'struct': { - 'name': 'b' - } - } - }, - { - 'nested_struct': { - 'struct': { - 'name': 'a' - } - } - }, - ] - dataset = make_test_data(items) - - with pytest.raises( - ValueError, match=re.escape("Leaf \"('nested_struct',)\" not found in dataset")): - dataset.select_groups(leaf_path='nested_struct') - - with pytest.raises( - ValueError, match=re.escape("Leaf \"('nested_struct', 'struct')\" not found in dataset")): - dataset.select_groups(leaf_path='nested_struct.struct') - - with pytest.raises( - ValueError, - match=re.escape("Path ('nested_struct', 'struct', 'wrong_name') not found in schema")): - dataset.select_groups(leaf_path='nested_struct.struct.wrong_name') - - -def test_too_many_distinct(make_test_data: TestDataMaker, mocker: MockerFixture) -> None: - too_many_distinct = 5 - mocker.patch(f'{dataset_module.__name__}.TOO_MANY_DISTINCT', too_many_distinct) - - items: list[Item] = [{'feature': str(i)} for i in range(too_many_distinct + 10)] - dataset = make_test_data(items) - - res = dataset.select_groups('feature') - assert res.too_many_distinct is True - assert res.counts == [] - - -def test_auto_bins_for_float(make_test_data: TestDataMaker) -> None: - items: list[Item] = [{'feature': float(i)} for i in range(5)] + [{'feature': float('nan')}] - dataset = make_test_data(items) - - res = dataset.select_groups('feature') - assert res.counts == [('0', 1), ('3', 1), ('7', 1), ('11', 1), ('14', 1), (None, 1)] - assert res.too_many_distinct is False - assert res.bins diff --git a/lilac/data/dataset_select_rows_filter_test.py b/lilac/data/dataset_select_rows_filter_test.py deleted file mode 100644 index 1369dd18074cb13ef49aa8471acb010cc9e06130..0000000000000000000000000000000000000000 --- a/lilac/data/dataset_select_rows_filter_test.py +++ /dev/null @@ -1,185 +0,0 @@ -"""Tests for dataset.select_rows(filters=[...]).""" - -import pytest - -from ..schema import ROWID, Item, schema -from .dataset import BinaryFilterTuple, ListFilterTuple, UnaryFilterTuple -from .dataset_test_utils import TestDataMaker - -TEST_DATA: list[Item] = [{ - 'str': 'a', - 'int': 1, - 'bool': False, - 'float': 3.0 -}, { - 'str': 'b', - 'int': 2, - 'bool': True, - 'float': 2.0 -}, { - 'str': 'b', - 'int': 2, - 'bool': True, - 'float': 1.0 -}, { - 'float': float('nan') -}] - - -def test_filter_by_ids(make_test_data: TestDataMaker) -> None: - dataset = make_test_data(TEST_DATA) - - id_filter: BinaryFilterTuple = (ROWID, 'equals', '1') - result = dataset.select_rows(filters=[id_filter]) - - assert list(result) == [{'str': 'a', 'int': 1, 'bool': False, 'float': 3.0}] - - id_filter = (ROWID, 'equals', '2') - result = 
dataset.select_rows(filters=[id_filter]) - - assert list(result) == [{'str': 'b', 'int': 2, 'bool': True, 'float': 2.0}] - - id_filter = (ROWID, 'equals', b'f') - result = dataset.select_rows(filters=[id_filter]) - - assert list(result) == [] - - -def test_filter_greater(make_test_data: TestDataMaker) -> None: - dataset = make_test_data(TEST_DATA) - - filter: BinaryFilterTuple = ('float', 'greater', 2.0) - result = dataset.select_rows(filters=[filter]) - - assert list(result) == [{'str': 'a', 'int': 1, 'bool': False, 'float': 3.0}] - - -def test_filter_greater_equal(make_test_data: TestDataMaker) -> None: - dataset = make_test_data(TEST_DATA) - - filter: BinaryFilterTuple = ('float', 'greater_equal', 2.0) - result = dataset.select_rows(filters=[filter]) - - assert list(result) == [{ - 'str': 'a', - 'int': 1, - 'bool': False, - 'float': 3.0 - }, { - 'str': 'b', - 'int': 2, - 'bool': True, - 'float': 2.0 - }] - - -def test_filter_less(make_test_data: TestDataMaker) -> None: - dataset = make_test_data(TEST_DATA) - - filter: BinaryFilterTuple = ('float', 'less', 2.0) - result = dataset.select_rows(['*'], filters=[filter]) - - assert list(result) == [{'str': 'b', 'int': 2, 'bool': True, 'float': 1.0}] - - -def test_filter_less_equal(make_test_data: TestDataMaker) -> None: - dataset = make_test_data(TEST_DATA) - - filter: BinaryFilterTuple = ('float', 'less_equal', 2.0) - result = dataset.select_rows(filters=[filter]) - - assert list(result) == [{ - 'str': 'b', - 'int': 2, - 'bool': True, - 'float': 2.0 - }, { - 'str': 'b', - 'int': 2, - 'bool': True, - 'float': 1.0 - }] - - -def test_filter_not_equal(make_test_data: TestDataMaker) -> None: - dataset = make_test_data(TEST_DATA) - - filter: BinaryFilterTuple = ('float', 'not_equal', 2.0) - result = dataset.select_rows(filters=[filter]) - - assert list(result) == [ - { - 'str': 'a', - 'int': 1, - 'bool': False, - 'float': 3.0 - }, - { - 'str': 'b', - 'int': 2, - 'bool': True, - 'float': 1.0 - }, - # NaNs are not counted when we are filtering a field. 
- ] - - -def test_filter_by_list_of_ids(make_test_data: TestDataMaker) -> None: - dataset = make_test_data(TEST_DATA) - - filter: ListFilterTuple = (ROWID, 'in', ['1', '2']) - result = dataset.select_rows(['*', ROWID], filters=[filter]) - - assert list(result) == [{ - ROWID: '1', - 'str': 'a', - 'int': 1, - 'bool': False, - 'float': 3.0 - }, { - ROWID: '2', - 'str': 'b', - 'int': 2, - 'bool': True, - 'float': 2.0 - }] - - -def test_filter_by_exists(make_test_data: TestDataMaker) -> None: - items: list[Item] = [{ - 'name': 'A', - 'info': { - 'lang': 'en' - }, - 'ages': [] - }, { - 'info': { - 'lang': 'fr' - }, - }, { - 'name': 'C', - 'ages': [[1, 2], [3, 4]] - }] - dataset = make_test_data( - items, schema=schema({ - 'name': 'string', - 'info': { - 'lang': 'string' - }, - 'ages': [['int32']] - })) - - exists_filter: UnaryFilterTuple = ('name', 'exists') - result = dataset.select_rows(['name'], filters=[exists_filter]) - assert list(result) == [{'name': 'A'}, {'name': 'C'}] - - exists_filter = ('info.lang', 'exists') - result = dataset.select_rows(['name'], filters=[exists_filter]) - assert list(result) == [{'name': 'A'}, {'name': None}] - - exists_filter = ('ages.*.*', 'exists') - result = dataset.select_rows(['name'], filters=[exists_filter]) - assert list(result) == [{'name': 'C'}] - - with pytest.raises(ValueError, match='Unable to filter on path'): - dataset.select_rows(['name'], filters=[('info', 'exists')]) diff --git a/lilac/data/dataset_select_rows_schema_test.py b/lilac/data/dataset_select_rows_schema_test.py deleted file mode 100644 index e816236cd02208f3c73457c94d1a8169f48630c1..0000000000000000000000000000000000000000 --- a/lilac/data/dataset_select_rows_schema_test.py +++ /dev/null @@ -1,530 +0,0 @@ -"""Tests for `db.select_rows_schema()`.""" - -from typing import Iterable, Optional, cast - -import numpy as np -import pytest -from typing_extensions import override - -from ..embeddings.vector_store import VectorDBIndex -from ..schema import ( - EMBEDDING_KEY, - PATH_WILDCARD, - Field, - Item, - RichData, - SignalInputType, - VectorKey, - field, - lilac_embedding, - lilac_span, - schema, -) -from ..signal import ( - TextEmbeddingSignal, - TextSignal, - TextSplitterSignal, - VectorSignal, - clear_signal_registry, - register_signal, -) -from ..signals.concept_labels import ConceptLabelsSignal -from ..signals.concept_scorer import ConceptSignal -from ..signals.semantic_similarity import SemanticSimilaritySignal -from ..signals.substring_search import SubstringSignal -from .dataset import ( - Column, - ConceptSearch, - KeywordSearch, - SearchResultInfo, - SelectRowsSchemaResult, - SelectRowsSchemaUDF, - SemanticSearch, - SortOrder, - SortResult, -) -from .dataset_test_utils import TestDataMaker - -TEST_DATA: list[Item] = [{ - 'erased': False, - 'people': [{ - 'name': 'A', - 'zipcode': 0, - 'locations': [{ - 'city': 'city1', - 'state': 'state1' - }, { - 'city': 'city2', - 'state': 'state2' - }] - }] -}, { - 'erased': True, - 'people': [{ - 'name': 'B', - 'zipcode': 1, - 'locations': [{ - 'city': 'city3', - 'state': 'state3' - }, { - 'city': 'city4' - }, { - 'city': 'city5' - }] - }, { - 'name': 'C', - 'zipcode': 2, - 'locations': [{ - 'city': 'city1', - 'state': 'state1' - }] - }] -}] - - -class TestSplitter(TextSplitterSignal): - """Split documents into sentence by splitting on period.""" - name = 'test_splitter' - - @override - def compute(self, data: Iterable[RichData]) -> Iterable[Item]: - for text in data: - if not isinstance(text, str): - raise ValueError(f'Expected text to be a 
string, got {type(text)} instead.') - sentences = [f'{sentence.strip()}.' for sentence in text.split('.') if sentence] - yield [ - lilac_span(text.index(sentence), - text.index(sentence) + len(sentence)) for sentence in sentences - ] - - -EMBEDDINGS: list[tuple[str, list[float]]] = [('hello.', [1.0, 0.0, 0.0]), - ('hello2.', [1.0, 1.0, 0.0]), - ('hello world.', [1.0, 1.0, 1.0]), - ('hello world2.', [2.0, 1.0, 1.0])] - -STR_EMBEDDINGS: dict[str, list[float]] = {text: embedding for text, embedding in EMBEDDINGS} - - -class TestEmbedding(TextEmbeddingSignal): - """A test embed function.""" - name = 'test_embedding' - - @override - def compute(self, data: Iterable[RichData]) -> Iterable[Item]: - """Call the embedding function.""" - for example in data: - yield [lilac_embedding(0, len(example), np.array(STR_EMBEDDINGS[cast(str, example)]))] - - -class TestEmbeddingSumSignal(VectorSignal): - """Sums the embeddings to return a single floating point value.""" - name = 'test_embedding_sum' - input_type = SignalInputType.TEXT - - @override - def fields(self) -> Field: - return field('float32') - - @override - def vector_compute(self, keys: Iterable[VectorKey], - vector_index: VectorDBIndex) -> Iterable[Item]: - # The signal just sums the values of the embedding. - all_vector_spans = vector_index.get(keys) - for vector_spans in all_vector_spans: - yield vector_spans[0]['vector'].sum() - - -@pytest.fixture(scope='module', autouse=True) -def setup_teardown() -> Iterable[None]: - # Setup. - register_signal(LengthSignal) - register_signal(AddSpaceSignal) - register_signal(TestSplitter) - register_signal(TestEmbedding) - register_signal(TestEmbeddingSumSignal) - - # Unit test runs. - yield - - # Teardown. - clear_signal_registry() - - -class LengthSignal(TextSignal): - name = 'length_signal' - - def fields(self) -> Field: - return field('int32') - - def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]: - for text_content in data: - yield len(text_content) - - -class AddSpaceSignal(TextSignal): - name = 'add_space_signal' - - def fields(self) -> Field: - return field('string') - - def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]: - for text_content in data: - yield cast(str, text_content) + ' ' - - -def test_simple_schema(make_test_data: TestDataMaker) -> None: - dataset = make_test_data(TEST_DATA) - result = dataset.select_rows_schema(combine_columns=True) - assert result == SelectRowsSchemaResult( - data_schema=schema({ - 'erased': 'boolean', - 'people': [{ - 'name': 'string', - 'zipcode': 'int32', - 'locations': [{ - 'city': 'string', - 'state': 'string' - }] - }] - })) - - -def test_subselection_with_combine_cols(make_test_data: TestDataMaker) -> None: - dataset = make_test_data(TEST_DATA) - - result = dataset.select_rows_schema([('people', '*', 'zipcode'), - ('people', '*', 'locations', '*', 'city')], - combine_columns=True) - assert result == SelectRowsSchemaResult( - data_schema=schema({'people': [{ - 'zipcode': 'int32', - 'locations': [{ - 'city': 'string' - }] - }]})) - - result = dataset.select_rows_schema([('people', '*', 'name'), ('people', '*', 'locations')], - combine_columns=True) - assert result == SelectRowsSchemaResult( - data_schema=schema( - {'people': [{ - 'name': 'string', - 'locations': [{ - 'city': 'string', - 'state': 'string' - }] - }]})) - - result = dataset.select_rows_schema([('people', '*')], combine_columns=True) - assert result == SelectRowsSchemaResult( - data_schema=schema({ - 'people': [{ - 'name': 'string', - 'zipcode': 'int32', 
- 'locations': [{ - 'city': 'string', - 'state': 'string' - }] - }] - })) - - -def test_udf_with_combine_cols(make_test_data: TestDataMaker) -> None: - dataset = make_test_data(TEST_DATA) - - length_signal = LengthSignal() - result = dataset.select_rows_schema([('people', '*', 'locations', '*', 'city'), - Column(('people', '*', 'name'), signal_udf=length_signal)], - combine_columns=True) - assert result == SelectRowsSchemaResult( - data_schema=schema({ - 'people': [{ - 'name': { - 'length_signal': field('int32', length_signal.dict()) - }, - 'locations': [{ - 'city': 'string' - }] - }], - }), - udfs=[ - SelectRowsSchemaUDF(path=('people', '*', 'name', length_signal.key())), - ], - ) - - -def test_embedding_udf_with_combine_cols(make_test_data: TestDataMaker) -> None: - dataset = make_test_data(TEST_DATA) - - add_space_signal = AddSpaceSignal() - path = ('people', '*', 'name') - dataset.compute_signal(add_space_signal, path) - result = dataset.select_rows_schema([path, Column(path, signal_udf=add_space_signal)], - combine_columns=True) - assert result == SelectRowsSchemaResult( - data_schema=schema({ - 'people': [{ - 'name': field( - 'string', fields={'add_space_signal': field('string', signal=add_space_signal.dict())}) - }], - }), - udfs=[ - SelectRowsSchemaUDF(path=(*path, add_space_signal.key())), - ], - ) - - -def test_udf_chained_with_combine_cols(make_test_data: TestDataMaker) -> None: - dataset = make_test_data([{ - 'text': 'hello. hello2.', - }, { - 'text': 'hello world. hello world2.', - }]) - - test_splitter = TestSplitter() - dataset.compute_signal(test_splitter, ('text')) - add_space_signal = AddSpaceSignal() - result = dataset.select_rows_schema( - [('text'), Column(('text'), signal_udf=add_space_signal)], combine_columns=True) - - assert result == SelectRowsSchemaResult( - data_schema=schema({ - 'text': field( - 'string', - fields={ - 'add_space_signal': field('string', add_space_signal.dict()), - 'test_splitter': field(signal=test_splitter.dict(), fields=['string_span']) - }) - }), - udfs=[ - SelectRowsSchemaUDF(path=('text', add_space_signal.key())), - ], - ) - - -def test_udf_embedding_chained_with_combine_cols(make_test_data: TestDataMaker) -> None: - dataset = make_test_data([{ - 'text': 'hello. hello2.', - }, { - 'text': 'hello world. hello world2.', - }]) - - test_splitter = TestSplitter() - dataset.compute_signal(test_splitter, 'text') - test_embedding = TestEmbedding() - dataset.compute_signal(test_embedding, ('text', 'test_splitter', '*')) - - embedding_sum_signal = TestEmbeddingSumSignal(embedding='test_embedding') - udf_col = Column(('text', 'test_splitter', '*'), signal_udf=embedding_sum_signal) - result = dataset.select_rows_schema([('text'), udf_col], combine_columns=True) - - expected_schema = schema({ - 'text': field( - 'string', - fields={ - 'test_splitter': field( - signal=test_splitter.dict(), - fields=[ - field( - 'string_span', - fields={ - 'test_embedding': field( - signal=test_embedding.dict(), - fields=[field('string_span', fields={EMBEDDING_KEY: 'embedding'})]), - embedding_sum_signal.key(): field('float32', signal=embedding_sum_signal.dict()) - }) - ]) - }) - }) - output_path = ('text', 'test_splitter', '*', embedding_sum_signal.key()) - assert result == SelectRowsSchemaResult( - data_schema=expected_schema, - udfs=[SelectRowsSchemaUDF(path=output_path)], - ) - - # Alias the udf. 
- udf_col.alias = 'udf1' - result = dataset.select_rows_schema([('text'), udf_col], combine_columns=True) - assert result == SelectRowsSchemaResult( - data_schema=expected_schema, - udfs=[SelectRowsSchemaUDF(path=output_path, alias='udf1')], - ) - - -def test_search_keyword_schema(make_test_data: TestDataMaker) -> None: - dataset = make_test_data([{ - 'text': 'hello world', - 'text2': 'hello world2', - }]) - query_world = 'world' - query_hello = 'hello' - - result = dataset.select_rows_schema( - searches=[ - KeywordSearch(path='text', query=query_world), - KeywordSearch(path='text2', query=query_hello), - ], - combine_columns=True) - - expected_world_signal = SubstringSignal(query=query_world) - expected_hello_signal = SubstringSignal(query=query_hello) - - assert result == SelectRowsSchemaResult( - data_schema=schema({ - 'text': field( - 'string', - fields={ - expected_world_signal.key(): field( - signal=expected_world_signal.dict(), fields=['string_span']) - }), - 'text2': field( - 'string', - fields={ - expected_hello_signal.key(): field( - signal=expected_hello_signal.dict(), fields=['string_span']) - }) - }), - search_results=[ - SearchResultInfo( - search_path=('text',), - result_path=('text', expected_world_signal.key(), PATH_WILDCARD), - ), - SearchResultInfo( - search_path=('text2',), - result_path=('text2', expected_hello_signal.key(), PATH_WILDCARD), - ) - ], - udfs=[ - SelectRowsSchemaUDF(path=('text', expected_world_signal.key())), - SelectRowsSchemaUDF(path=('text2', expected_hello_signal.key())), - ], - ) - - -def test_search_semantic_schema(make_test_data: TestDataMaker) -> None: - dataset = make_test_data([{ - 'text': 'hello world.', - }]) - query_world = 'world' - - test_embedding = TestEmbedding() - dataset.compute_signal(test_embedding, ('text')) - - result = dataset.select_rows_schema( - searches=[ - SemanticSearch(path='text', query=query_world, embedding='test_embedding'), - ], - combine_columns=True) - - test_embedding = TestEmbedding() - expected_world_signal = SemanticSimilaritySignal(query=query_world, embedding='test_embedding') - - similarity_score_path = ('text', expected_world_signal.key()) - assert result == SelectRowsSchemaResult( - data_schema=schema({ - 'text': field( - 'string', - fields={ - 'test_embedding': field( - signal=test_embedding.dict(), - fields=[field('string_span', fields={EMBEDDING_KEY: 'embedding'})]), - expected_world_signal.key(): field( - signal=expected_world_signal.dict(), - fields=[field('string_span', fields={'score': 'float32'})]) - }) - }), - udfs=[SelectRowsSchemaUDF(path=similarity_score_path)], - search_results=[SearchResultInfo(search_path=('text',), result_path=similarity_score_path)], - sorts=[ - SortResult( - path=(*similarity_score_path, PATH_WILDCARD, 'score'), order=SortOrder.DESC, search_index=0) - ]) - - -def test_search_concept_schema(make_test_data: TestDataMaker) -> None: - dataset = make_test_data([{ - 'text': 'hello world.', - }]) - - test_embedding = TestEmbedding() - dataset.compute_signal(test_embedding, ('text')) - - result = dataset.select_rows_schema( - searches=[ - ConceptSearch( - path='text', - concept_namespace='test_namespace', - concept_name='test_concept', - embedding='test_embedding'), - ], - combine_columns=True) - - expected_world_signal = ConceptSignal( - namespace='test_namespace', concept_name='test_concept', embedding='test_embedding') - expected_labels_signal = ConceptLabelsSignal( - namespace='test_namespace', concept_name='test_concept') - - concept_score_path = ('text', 
expected_world_signal.key()) - concept_labels_path = ('text', expected_labels_signal.key()) - assert result == SelectRowsSchemaResult( - data_schema=schema({ - 'text': field( - 'string', - fields={ - 'test_embedding': field( - signal=test_embedding.dict(), - fields=[field('string_span', fields={EMBEDDING_KEY: 'embedding'})]), - expected_world_signal.key(): field( - signal=expected_world_signal.dict(), - fields=[ - field( - dtype='string_span', - fields={ - 'score': field( - 'float32', - bins=[('Not in concept', None, 0.5), ('In concept', 0.5, None)], - ) - }) - ]), - 'test_namespace/test_concept/labels': field( - fields=[field('string_span', fields={ - 'label': 'boolean', - 'draft': 'string' - })], - signal=expected_labels_signal.dict()) - }) - }), - udfs=[ - SelectRowsSchemaUDF(path=concept_labels_path), - SelectRowsSchemaUDF(path=concept_score_path) - ], - search_results=[ - SearchResultInfo(search_path=('text',), result_path=concept_labels_path), - SearchResultInfo(search_path=('text',), result_path=concept_score_path) - ], - sorts=[ - SortResult( - path=(*concept_score_path, PATH_WILDCARD, 'score'), order=SortOrder.DESC, search_index=0) - ]) - - -def test_search_sort_override(make_test_data: TestDataMaker) -> None: - dataset = make_test_data([{ - 'text': 'hello world.', - }]) - query_world = 'world' - - test_embedding = TestEmbedding() - dataset.compute_signal(test_embedding, ('text')) - - result = dataset.select_rows_schema( - searches=[ - SemanticSearch(path='text', query=query_world, embedding='test_embedding'), - ], - # Explicit sort by overrides the semantic search. - sort_by=[('text',)], - sort_order=SortOrder.DESC, - combine_columns=True) - - assert result.sorts == [SortResult(path=('text',), order=SortOrder.DESC)] diff --git a/lilac/data/dataset_select_rows_search_test.py b/lilac/data/dataset_select_rows_search_test.py deleted file mode 100644 index 45bfec102d6fa979aaeb3dbf4b564195a1f9c7fb..0000000000000000000000000000000000000000 --- a/lilac/data/dataset_select_rows_search_test.py +++ /dev/null @@ -1,466 +0,0 @@ -"""Tests for dataset.select_rows(searches=[...]).""" - -from typing import Iterable, cast - -import numpy as np -import pytest -from pytest import approx -from pytest_mock import MockerFixture -from typing_extensions import override - -from ..concepts.concept import ExampleIn, LogisticEmbeddingModel -from ..concepts.db_concept import ConceptUpdate, DiskConceptDB -from ..db_manager import set_default_dataset_cls -from ..schema import ROWID, Item, RichData, SignalInputType, lilac_embedding, lilac_span -from ..signal import TextEmbeddingSignal, clear_signal_registry, register_signal -from ..signals.concept_scorer import ConceptSignal -from ..signals.semantic_similarity import SemanticSimilaritySignal -from ..signals.substring_search import SubstringSignal -from .dataset import ConceptSearch, KeywordSearch, SemanticSearch, SortOrder -from .dataset_duckdb import DatasetDuckDB -from .dataset_test_utils import TestDataMaker, enriched_item - -TEST_DATA: list[Item] = [{ - 'text': 'hello world', - 'text2': 'again hello world', -}, { - 'text': 'looking for world in text', - 'text2': 'again looking for world in text', -}, { - 'text': 'unrelated text', - 'text2': 'again unrelated text' -}] - -EMBEDDINGS: list[tuple[str, list[float]]] = [ - ('hello.', [1.0, 0.0, 0.0]), - ('hello2.', [1.0, 1.0, 0.0]), - ('hello world.', [1.0, 1.0, 1.0]), - ('hello world2.', [2.0, 1.0, 1.0]), - ('random negative 1', [0, 0, 0.3]), - ('random negative 2', [0, 0, 0.4]), - ('random negative 3', [0, 0.1, 
0.5]), - ('random negative 4', [0.1, 0, 0.4]), -] - -STR_EMBEDDINGS: dict[str, list[float]] = {text: embedding for text, embedding in EMBEDDINGS} - - -@pytest.fixture(scope='module', autouse=True) -def setup_teardown() -> Iterable[None]: - # Setup. - set_default_dataset_cls(DatasetDuckDB) - register_signal(TestEmbedding) - - # Unit test runs. - yield - - # Teardown. - clear_signal_registry() - - -def test_search_keyword(make_test_data: TestDataMaker) -> None: - dataset = make_test_data(TEST_DATA) - - query = 'world' - result = dataset.select_rows( - searches=[KeywordSearch(path='text', query=query)], combine_columns=True) - - expected_signal_udf = SubstringSignal(query=query) - assert list(result) == [{ - 'text': enriched_item('hello world', {expected_signal_udf.key(): [lilac_span(6, 11)]}), - 'text2': 'again hello world' - }, { - 'text': enriched_item('looking for world in text', - {expected_signal_udf.key(): [lilac_span(12, 17)]}), - 'text2': 'again looking for world in text', - }] - - -def test_search_keyword_special_chars(make_test_data: TestDataMaker) -> None: - dataset = make_test_data([{ - 'text': 'This is 100%', - }, { - 'text': 'This has _underscore_', - }]) - - query = '100%' - result = dataset.select_rows( - searches=[KeywordSearch(path='text', query=query)], combine_columns=True) - - expected_signal_udf = SubstringSignal(query=query) - assert list(result) == [{ - 'text': enriched_item('This is 100%', {expected_signal_udf.key(): [lilac_span(8, 12)]}), - }] - - query = '_underscore_' - result = dataset.select_rows( - searches=[KeywordSearch(path='text', query=query)], combine_columns=True) - - expected_signal_udf = SubstringSignal(query=query) - assert list(result) == [{ - 'text': enriched_item('This has _underscore_', - {expected_signal_udf.key(): [lilac_span(9, 21)]}), - }] - - -def test_search_keyword_multiple(make_test_data: TestDataMaker) -> None: - dataset = make_test_data(TEST_DATA) - - query_world = 'world' - query_looking_world = 'looking for world' - expected_world_udf = SubstringSignal(query=query_world) - expected_again_looking_udf = SubstringSignal(query=query_looking_world) - - result = dataset.select_rows( - searches=[ - KeywordSearch(path='text', query=query_world), - KeywordSearch(path='text2', query=query_looking_world), - ], - combine_columns=True) - - assert list(result) == [{ - 'text': enriched_item('looking for world in text', { - expected_world_udf.key(): [lilac_span(12, 17)], - }), - 'text2': enriched_item('again looking for world in text', - {expected_again_looking_udf.key(): [lilac_span(6, 23)]}) - }] - - -def test_search_keyword_with_filters(make_test_data: TestDataMaker) -> None: - dataset = make_test_data(TEST_DATA) - - query = 'world' - result = dataset.select_rows(['*'], - filters=[(ROWID, 'in', ['1', '3'])], - searches=[KeywordSearch(path='text', query=query)], - combine_columns=True) - - expected_signal_udf = SubstringSignal(query=query) - assert list(result) == [ - { - 'text': enriched_item('hello world', {expected_signal_udf.key(): [lilac_span(6, 11)]}), - 'text2': 'again hello world' - }, - # The second row doesn't match the rowid filter. 
- ] - - -class TestEmbedding(TextEmbeddingSignal): - """A test embed function.""" - name = 'test_embedding' - - @override - def compute(self, data: Iterable[RichData]) -> Iterable[Item]: - """Call the embedding function.""" - for example in data: - embedding = np.array(STR_EMBEDDINGS[cast(str, example)]) - yield [lilac_embedding(0, len(example), embedding)] - - -def test_semantic_search(make_test_data: TestDataMaker) -> None: - dataset = make_test_data([{ - 'text': 'hello world.', - }, { - 'text': 'hello world2.', - }]) - - test_embedding = TestEmbedding() - dataset.compute_signal(test_embedding, ('text')) - - query = 'hello2.' - result = dataset.select_rows( - searches=[SemanticSearch(path='text', query=query, embedding='test_embedding')], - combine_columns=True) - expected_signal_udf = SemanticSimilaritySignal(query=query, embedding='test_embedding') - assert list(result) == [ - # Results are sorted by score desc. - { - 'text': enriched_item('hello world2.', - {expected_signal_udf.key(): [lilac_span(0, 13, {'score': 3})]}) - }, - { - 'text': enriched_item('hello world.', - {expected_signal_udf.key(): [lilac_span(0, 12, {'score': 2})]}) - }, - ] - - -def test_concept_search(make_test_data: TestDataMaker, mocker: MockerFixture) -> None: - concept_model_mock = mocker.spy(LogisticEmbeddingModel, 'fit') - - dataset = make_test_data([{ - 'text': 'hello world.', - }, { - 'text': 'hello world2.', - }, { - 'text': 'random negative 1', - }, { - 'text': 'random negative 2', - }, { - 'text': 'random negative 3', - }, { - 'text': 'random negative 4', - }]) - - test_embedding = TestEmbedding() - dataset.compute_signal(test_embedding, ('text')) - - concept_db = DiskConceptDB() - concept_db.create(namespace='test_namespace', name='test_concept', type=SignalInputType.TEXT) - concept_db.edit( - 'test_namespace', 'test_concept', - ConceptUpdate(insert=[ - ExampleIn(label=False, text='hello world.'), - ExampleIn(label=True, text='hello world2.') - ])) - - result = dataset.select_rows( - columns=[ROWID, '*'], - searches=[ - ConceptSearch( - path='text', - concept_namespace='test_namespace', - concept_name='test_concept', - embedding='test_embedding') - ], - filters=[(ROWID, 'in', ['1', '2'])], - combine_columns=True) - expected_signal_udf = ConceptSignal( - namespace='test_namespace', concept_name='test_concept', embedding='test_embedding') - - assert list(result) == [ - # Results are sorted by score desc. - { - ROWID: '2', - 'text': enriched_item( - 'hello world2.', { - expected_signal_udf.key(): [lilac_span(0, 13, {'score': approx(0.75, abs=0.25)})], - 'test_namespace/test_concept/labels': [lilac_span(0, 13, {'label': True})] - }) - }, - { - ROWID: '1', - 'text': enriched_item( - 'hello world.', { - expected_signal_udf.key(): [lilac_span(0, 12, {'score': approx(0.25, abs=0.25)})], - 'test_namespace/test_concept/labels': [lilac_span(0, 12, {'label': False})] - }) - }, - ] - - (_, embeddings, labels) = concept_model_mock.call_args_list[-1].args - assert embeddings.shape == (2, 3) - assert labels == [ - # Explicit labels. 
- False, - True - ] - - -def test_concept_search_without_rowid(make_test_data: TestDataMaker) -> None: - dataset = make_test_data([{ - 'text': 'hello world.', - }, { - 'text': 'hello world2.', - }]) - - test_embedding = TestEmbedding() - dataset.compute_signal(test_embedding, ('text')) - - concept_db = DiskConceptDB() - concept_db.create(namespace='test_namespace', name='test_concept', type=SignalInputType.TEXT) - concept_db.edit( - 'test_namespace', 'test_concept', - ConceptUpdate(insert=[ - ExampleIn(label=False, text='hello world.'), - ExampleIn(label=True, text='hello world2.') - ])) - - result = dataset.select_rows( - columns=['text'], - searches=[ - ConceptSearch( - path='text', - concept_namespace='test_namespace', - concept_name='test_concept', - embedding='test_embedding') - ]) - - assert list(result) == [ - # Results are sorted by score desc. - { - 'text': 'hello world2.', - 'text.test_namespace/test_concept/test_embedding': [ - lilac_span(0, 13, {'score': approx(0.75, abs=0.25)}) - ], - 'text.test_namespace/test_concept/labels': [lilac_span(0, 13, {'label': True})] - }, - { - 'text': 'hello world.', - 'text.test_namespace/test_concept/test_embedding': [ - lilac_span(0, 12, {'score': approx(0.25, abs=0.25)}) - ], - 'text.test_namespace/test_concept/labels': [lilac_span(0, 12, {'label': False})] - }, - ] - - result = dataset.select_rows( - columns=['text'], - searches=[ - ConceptSearch( - path='text', - concept_namespace='test_namespace', - concept_name='test_concept', - embedding='test_embedding') - ], - combine_columns=True) - - assert list(result) == [ - # Results are sorted by score desc. - { - 'text': enriched_item( - 'hello world2.', { - 'test_namespace/test_concept/test_embedding': - [lilac_span(0, 13, {'score': approx(0.75, abs=0.25)})], - 'test_namespace/test_concept/labels': [lilac_span(0, 13, {'label': True})] - }) - }, - { - 'text': enriched_item( - 'hello world.', { - 'test_namespace/test_concept/test_embedding': - [lilac_span(0, 12, {'score': approx(0.25, abs=0.25)})], - 'test_namespace/test_concept/labels': [lilac_span(0, 12, {'label': False})] - }) - }, - ] - - -def test_concept_search_sort_by_rowid(make_test_data: TestDataMaker) -> None: - dataset = make_test_data([{ - 'text': 'hello world.', - }, { - 'text': 'hello world2.', - }]) - - test_embedding = TestEmbedding() - dataset.compute_signal(test_embedding, ('text')) - - concept_db = DiskConceptDB() - concept_db.create(namespace='test_namespace', name='test_concept', type=SignalInputType.TEXT) - concept_db.edit( - 'test_namespace', 'test_concept', - ConceptUpdate(insert=[ - ExampleIn(label=False, text='hello world.'), - ExampleIn(label=True, text='hello world2.') - ])) - - result = dataset.select_rows( - columns=['text'], - searches=[ - ConceptSearch( - path='text', - concept_namespace='test_namespace', - concept_name='test_concept', - embedding='test_embedding') - ], - sort_by=[ROWID], - sort_order=SortOrder.ASC) - - assert list(result) == [ - # Results are sorted by rowid. 
- { - 'text': 'hello world.', - 'text.test_namespace/test_concept/test_embedding': [ - lilac_span(0, 12, {'score': approx(0.25, abs=0.25)}) - ], - 'text.test_namespace/test_concept/labels': [lilac_span(0, 12, {'label': False})] - }, - { - 'text': 'hello world2.', - 'text.test_namespace/test_concept/test_embedding': [ - lilac_span(0, 13, {'score': approx(0.75, abs=0.25)}) - ], - 'text.test_namespace/test_concept/labels': [lilac_span(0, 13, {'label': True})] - } - ] - - -def test_sort_override_search(make_test_data: TestDataMaker) -> None: - dataset = make_test_data([{ - 'text': 'hello world.', - 'value': 10 - }, { - 'text': 'hello world2.', - 'value': 20 - }]) - - test_embedding = TestEmbedding() - dataset.compute_signal(test_embedding, ('text')) - - query = 'hello2.' - search = SemanticSearch(path='text', query=query, embedding='test_embedding') - - expected_signal_udf = SemanticSimilaritySignal(query=query, embedding='test_embedding') - expected_item_1 = { - 'text': enriched_item('hello world.', - {expected_signal_udf.key(): [lilac_span(0, 12, {'score': 2.0})]}), - 'value': 10 - } - expected_item_2 = { - 'text': enriched_item('hello world2.', - {expected_signal_udf.key(): [lilac_span(0, 13, {'score': 3.0})]}), - 'value': 20 - } - - sort_order = SortOrder.ASC - result = dataset.select_rows( - searches=[search], sort_by=[('value',)], sort_order=sort_order, combine_columns=True) - assert list(result) == [ - # Results are sorted by score ascending. - expected_item_1, - expected_item_2 - ] - - sort_order = SortOrder.DESC - result = dataset.select_rows( - searches=[search], sort_by=[('text',)], sort_order=sort_order, combine_columns=True) - assert list(result) == [ - # Results are sorted by score descending. - expected_item_2, - expected_item_1 - ] - - -def test_search_keyword_and_semantic(make_test_data: TestDataMaker) -> None: - dataset = make_test_data([{ - 'text': 'hello world.', - }, { - 'text': 'hello world2.', - }]) - - test_embedding = TestEmbedding() - dataset.compute_signal(test_embedding, ('text')) - - query = 'hello2.' - keyword_query = 'rld2' - result = dataset.select_rows( - searches=[ - SemanticSearch(path='text', query=query, embedding='test_embedding'), - KeywordSearch(path='text', query=keyword_query) - ], - combine_columns=True) - expected_semantic_signal = SemanticSimilaritySignal(query=query, embedding='test_embedding') - expected_keyword_signal = SubstringSignal(query=keyword_query) - assert list(result) == [ - # Results are sorted by score desc. - { - 'text': enriched_item( - 'hello world2.', { - expected_semantic_signal.key(): [lilac_span(0, 13, {'score': 3})], - expected_keyword_signal.key(): [lilac_span(8, 12)], - }) - }, - # rowid '1' is not returned because it does not match the keyword query. 
- ] diff --git a/lilac/data/dataset_select_rows_sort_test.py b/lilac/data/dataset_select_rows_sort_test.py deleted file mode 100644 index ff0d65186f3ba6a3f9b6451d73ce931379083d25..0000000000000000000000000000000000000000 --- a/lilac/data/dataset_select_rows_sort_test.py +++ /dev/null @@ -1,771 +0,0 @@ -"""Tests for dataset.select_rows(sort_by=...).""" - -from typing import Iterable, Optional, Sequence, cast - -import numpy as np -import pytest -from typing_extensions import override - -from ..embeddings.vector_store import VectorDBIndex -from ..schema import ( - ROWID, - Field, - Item, - PathKey, - RichData, - SignalInputType, - VectorKey, - field, - lilac_embedding, - lilac_span, -) -from ..signal import ( - TextEmbeddingSignal, - TextSignal, - VectorSignal, - clear_signal_registry, - register_signal, -) -from .dataset import Column, SortOrder -from .dataset_test_utils import TestDataMaker, enriched_item - - -class TestSignal(TextSignal): - name = 'test_signal' - - def fields(self) -> Field: - return field(fields={'len': 'int32', 'is_all_cap': 'boolean'}) - - def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]: - for text_content in data: - yield {'len': len(text_content), 'is_all_cap': text_content.isupper()} - - -class TestPrimitiveSignal(TextSignal): - name = 'primitive_signal' - - def fields(self) -> Field: - return field('int32') - - def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]: - for text_content in data: - yield len(text_content) + 1 - - -class NestedArraySignal(TextSignal): - name = 'nested_array' - - def fields(self) -> Field: - return field(fields=[['int32']]) - - def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]: - for text_content in data: - yield [[len(text_content) + 1], [len(text_content)]] - - -@pytest.fixture(scope='module', autouse=True) -def setup_teardown() -> Iterable[None]: - # Setup. - register_signal(TestSignal) - register_signal(TestPrimitiveSignal) - register_signal(NestedArraySignal) - register_signal(TopKEmbedding) - # Unit test runs. - yield - # Teardown. - clear_signal_registry() - - -def test_sort_by_source_no_alias_no_repeated(make_test_data: TestDataMaker) -> None: - dataset = make_test_data([{ - 'erased': True, - 'score': 4.1, - 'document': { - 'num_pages': 4, - 'header': { - 'title': 'c' - } - } - }, { - 'erased': False, - 'score': 3.5, - 'document': { - 'num_pages': 5, - 'header': { - 'title': 'b' - } - }, - }, { - 'erased': True, - 'score': 3.7, - 'document': { - 'num_pages': 3, - 'header': { - 'title': 'a' - } - }, - }]) - - # Sort by bool. - result = dataset.select_rows(columns=[ROWID], sort_by=['erased'], sort_order=SortOrder.ASC) - assert list(result) == [{ROWID: '2'}, {ROWID: '1'}, {ROWID: '3'}] - result = dataset.select_rows(columns=[ROWID], sort_by=['erased'], sort_order=SortOrder.DESC) - assert list(result) == [{ROWID: '1'}, {ROWID: '3'}, {ROWID: '2'}] - - # Sort by float. - result = dataset.select_rows(columns=[ROWID], sort_by=['score'], sort_order=SortOrder.ASC) - assert list(result) == [{ROWID: '2'}, {ROWID: '3'}, {ROWID: '1'}] - result = dataset.select_rows(columns=[ROWID], sort_by=['score'], sort_order=SortOrder.DESC) - assert list(result) == [{ROWID: '1'}, {ROWID: '3'}, {ROWID: '2'}] - - # Sort by nested int. 
- result = dataset.select_rows( - columns=[ROWID], sort_by=['document.num_pages'], sort_order=SortOrder.ASC) - assert list(result) == [{ROWID: '3'}, {ROWID: '1'}, {ROWID: '2'}] - result = dataset.select_rows( - columns=[ROWID], sort_by=['document.num_pages'], sort_order=SortOrder.DESC) - assert list(result) == [{ROWID: '2'}, {ROWID: '1'}, {ROWID: '3'}] - - # Sort by double nested string. - result = dataset.select_rows( - columns=[ROWID], sort_by=['document.header.title'], sort_order=SortOrder.ASC) - assert list(result) == [{ROWID: '3'}, {ROWID: '2'}, {ROWID: '1'}] - result = dataset.select_rows( - columns=[ROWID], sort_by=['document.header.title'], sort_order=SortOrder.DESC) - assert list(result) == [{ROWID: '1'}, {ROWID: '2'}, {ROWID: '3'}] - - -def test_sort_by_signal_no_alias_no_repeated(make_test_data: TestDataMaker) -> None: - dataset = make_test_data([{'text': 'HEY'}, {'text': 'everyone'}, {'text': 'HI'}]) - - dataset.compute_signal(TestSignal(), 'text') - - # Sort by `signal.len`. - result = dataset.select_rows( - columns=[ROWID], sort_by=['text.test_signal.len'], sort_order=SortOrder.ASC) - assert list(result) == [{ROWID: '3'}, {ROWID: '1'}, {ROWID: '2'}] - result = dataset.select_rows( - columns=[ROWID], sort_by=['text.test_signal.len'], sort_order=SortOrder.DESC) - assert list(result) == [{ROWID: '2'}, {ROWID: '1'}, {ROWID: '3'}] - - # Sort by `signal.is_all_cap`. - result = dataset.select_rows( - columns=[ROWID], sort_by=['text.test_signal.is_all_cap'], sort_order=SortOrder.ASC) - assert list(result) == [{ROWID: '2'}, {ROWID: '1'}, {ROWID: '3'}] - result = dataset.select_rows( - columns=[ROWID], sort_by=['text.test_signal.is_all_cap'], sort_order=SortOrder.DESC) - assert list(result) == [{ROWID: '1'}, {ROWID: '3'}, {ROWID: '2'}] - - -def test_sort_by_signal_alias_no_repeated(make_test_data: TestDataMaker) -> None: - dataset = make_test_data([{'text': 'HEY'}, {'text': 'everyone'}, {'text': 'HI'}]) - - dataset.compute_signal(TestSignal(), 'text') - - # Sort by `signal.len`. - signal_alias = Column('text.test_signal', alias='signal') - result = dataset.select_rows( - columns=[signal_alias], sort_by=['signal.len'], sort_order=SortOrder.ASC) - assert list(result) == [{ - 'signal': { - 'len': 2, - 'is_all_cap': True - } - }, { - 'signal': { - 'len': 3, - 'is_all_cap': True - } - }, { - 'signal': { - 'len': 8, - 'is_all_cap': False - } - }] - result = dataset.select_rows( - columns=[signal_alias], sort_by=['signal.len'], sort_order=SortOrder.DESC) - assert list(result) == [{ - 'signal': { - 'len': 8, - 'is_all_cap': False - } - }, { - 'signal': { - 'len': 3, - 'is_all_cap': True - } - }, { - 'signal': { - 'len': 2, - 'is_all_cap': True - } - }] - - -def test_sort_by_enriched_alias_no_repeated(make_test_data: TestDataMaker) -> None: - dataset = make_test_data([{'text': 'HEY'}, {'text': 'everyone'}, {'text': 'HI'}]) - - dataset.compute_signal(TestSignal(), 'text') - - # Sort by `document.test_signal.is_all_cap` where 'document' is an alias to 'text'. 
- text_alias = Column('text', alias='document') - result = dataset.select_rows( - columns=[text_alias], - sort_by=['document.test_signal.is_all_cap'], - sort_order=SortOrder.ASC, - combine_columns=True) - assert list(result) == [{ - 'text': enriched_item('everyone', {'test_signal': { - 'len': 8, - 'is_all_cap': False - }}) - }, { - 'text': enriched_item('HEY', {'test_signal': { - 'len': 3, - 'is_all_cap': True - }}) - }, { - 'text': enriched_item('HI', {'test_signal': { - 'len': 2, - 'is_all_cap': True - }}) - }] - - result = dataset.select_rows( - columns=[text_alias], - sort_by=['document.test_signal.is_all_cap'], - sort_order=SortOrder.DESC, - combine_columns=True) - # Aliases are ignored when combining columns. - assert list(result) == [{ - 'text': enriched_item('HEY', {'test_signal': { - 'len': 3, - 'is_all_cap': True - }}) - }, { - 'text': enriched_item('HI', {'test_signal': { - 'len': 2, - 'is_all_cap': True - }}) - }, { - 'text': enriched_item('everyone', {'test_signal': { - 'len': 8, - 'is_all_cap': False - }}) - }] - - -def test_sort_by_udf_alias_no_repeated(make_test_data: TestDataMaker) -> None: - dataset = make_test_data([{'text': 'HEY'}, {'text': 'everyone'}, {'text': 'HI'}]) - - # Equivalent to: SELECT `TestSignal(text) AS udf`. - text_udf = Column('text', signal_udf=TestSignal(), alias='udf') - # Sort by `udf.len`, where `udf` is an alias to `TestSignal(text)`. - result = dataset.select_rows(['*', text_udf], sort_by=['udf.len'], sort_order=SortOrder.ASC) - assert list(result) == [{ - 'text': 'HI', - 'udf': { - 'len': 2, - 'is_all_cap': True - } - }, { - 'text': 'HEY', - 'udf': { - 'len': 3, - 'is_all_cap': True - } - }, { - 'text': 'everyone', - 'udf': { - 'len': 8, - 'is_all_cap': False - } - }] - - -def test_sort_by_udf_no_alias_no_repeated(make_test_data: TestDataMaker) -> None: - dataset = make_test_data([{'text': 'HEY'}, {'text': 'everyone'}, {'text': 'HI'}]) - - text_udf = Column('text', signal_udf=TestSignal()) - # Sort by `text.test_signal.len`, produced by executing the udf `TestSignal(text)`. - result = dataset.select_rows(['*', text_udf], - sort_by=[('text', 'test_signal', 'len')], - sort_order=SortOrder.ASC, - combine_columns=True) - assert list(result) == [{ - 'text': enriched_item('HI', {'test_signal': { - 'len': 2, - 'is_all_cap': True - }}), - }, { - 'text': enriched_item('HEY', {'test_signal': { - 'len': 3, - 'is_all_cap': True - }}), - }, { - 'text': enriched_item('everyone', {'test_signal': { - 'len': 8, - 'is_all_cap': False - }}), - }] - - # Sort descending. - result = dataset.select_rows(['*', text_udf], - sort_by=[('text', 'test_signal', 'len')], - sort_order=SortOrder.DESC, - combine_columns=True) - assert list(result) == [{ - 'text': enriched_item('everyone', {'test_signal': { - 'len': 8, - 'is_all_cap': False - }}), - }, { - 'text': enriched_item('HEY', {'test_signal': { - 'len': 3, - 'is_all_cap': True - }}), - }, { - 'text': enriched_item('HI', {'test_signal': { - 'len': 2, - 'is_all_cap': True - }}), - }] - - -def test_sort_by_primitive_udf_alias_no_repeated(make_test_data: TestDataMaker) -> None: - dataset = make_test_data([{'text': 'HEY'}, {'text': 'everyone'}, {'text': 'HI'}]) - - # Equivalent to: SELECT `TestPrimitiveSignal(text) AS udf`. - text_udf = Column('text', signal_udf=TestPrimitiveSignal(), alias='udf') - # Sort by the primitive value returned by the udf. 
- result = dataset.select_rows(['*', text_udf], sort_by=['udf'], sort_order=SortOrder.ASC) - assert list(result) == [{ - 'text': 'HI', - 'udf': 3 - }, { - 'text': 'HEY', - 'udf': 4 - }, { - 'text': 'everyone', - 'udf': 9 - }] - - -def test_sort_by_source_non_leaf_errors(make_test_data: TestDataMaker) -> None: - dataset = make_test_data([{'vals': [7, 1]}, {'vals': [3, 4]}, {'vals': [9, 0]}]) - - # Sort by repeated. - with pytest.raises(ValueError, match='Unable to sort by path'): - dataset.select_rows(columns=[ROWID], sort_by=['vals'], sort_order=SortOrder.ASC) - - -def test_sort_by_source_no_alias_repeated(make_test_data: TestDataMaker) -> None: - dataset = make_test_data([{ - 'vals': [[{ - 'score': 7 - }, { - 'score': 1 - }], [{ - 'score': 1 - }, { - 'score': 7 - }]] - }, { - 'vals': [[{ - 'score': 3 - }, { - 'score': 4 - }]] - }, { - 'vals': [[{ - 'score': 9 - }, { - 'score': 0 - }]] - }]) - - # Sort by repeated 'vals'. - result = dataset.select_rows( - columns=['vals'], sort_by=['vals.*.*.score'], sort_order=SortOrder.ASC) - assert list(result) == [{ - 'vals': [[{ - 'score': 9 - }, { - 'score': 0 - }]] - }, { - 'vals': [[{ - 'score': 7 - }, { - 'score': 1 - }], [{ - 'score': 1 - }, { - 'score': 7 - }]] - }, { - 'vals': [[{ - 'score': 3 - }, { - 'score': 4 - }]] - }] - - result = dataset.select_rows( - columns=['vals'], sort_by=['vals.*.*.score'], sort_order=SortOrder.DESC) - assert list(result) == [{ - 'vals': [[{ - 'score': 9 - }, { - 'score': 0 - }]] - }, { - 'vals': [[{ - 'score': 7 - }, { - 'score': 1 - }], [{ - 'score': 1 - }, { - 'score': 7 - }]] - }, { - 'vals': [[{ - 'score': 3 - }, { - 'score': 4 - }]] - }] - - -def test_sort_by_source_alias_repeated(make_test_data: TestDataMaker) -> None: - dataset = make_test_data([{'vals': [[7, 1], [1, 7]]}, {'vals': [[3], [11]]}, {'vals': [[9, 0]]}]) - - # Sort by repeated 'vals'. - result = dataset.select_rows( - columns=[Column('vals', alias='scores')], sort_by=['scores.*.*'], sort_order=SortOrder.ASC) - assert list(result) == [{ - 'scores': [[9, 0]] - }, { - 'scores': [[7, 1], [1, 7]] - }, { - 'scores': [[3], [11]] - }] - - result = dataset.select_rows( - columns=[Column('vals', alias='scores')], sort_by=['scores.*.*'], sort_order=SortOrder.DESC) - assert list(result) == [{ - 'scores': [[3], [11]] - }, { - 'scores': [[9, 0]] - }, { - 'scores': [[7, 1], [1, 7]] - }] - - -def test_sort_by_udf_alias_repeated(make_test_data: TestDataMaker) -> None: - dataset = make_test_data([{'text': 'HEY'}, {'text': 'everyone'}, {'text': 'HI'}]) - - # Equivalent to: SELECT `NestedArraySignal(text) AS udf`. - text_udf = Column('text', signal_udf=NestedArraySignal(), alias='udf') - # Sort by `udf.*.*`, where `udf` is an alias to `NestedArraySignal(text)`. 
- result = dataset.select_rows(['*', text_udf], sort_by=['udf.*.*'], sort_order=SortOrder.ASC) - assert list(result) == [{ - 'text': 'HI', - 'udf': [[3], [2]] - }, { - 'text': 'HEY', - 'udf': [[4], [3]] - }, { - 'text': 'everyone', - 'udf': [[9], [8]] - }] - result = dataset.select_rows(['*', text_udf], sort_by=['udf.*.*'], sort_order=SortOrder.DESC) - assert list(result) == [{ - 'text': 'everyone', - 'udf': [[9], [8]] - }, { - 'text': 'HEY', - 'udf': [[4], [3]] - }, { - 'text': 'HI', - 'udf': [[3], [2]] - }] - - -def test_sort_by_complex_signal_udf_alias_called_on_repeated(make_test_data: TestDataMaker) -> None: - dataset = make_test_data([{ - 'texts': [{ - 'text': 'eardrop' - }, { - 'text': 'I' - }] - }, { - 'texts': [{ - 'text': 'hey' - }, { - 'text': 'CARS' - }] - }, { - 'texts': [{ - 'text': 'everyone' - }, { - 'text': '' - }] - }]) - - # Equivalent to: SELECT `TestSignal(texts.*.text) AS udf`. - texts_udf = Column('texts.*.text', signal_udf=TestSignal(), alias='udf') - # Sort by `udf.len`, where `udf` is an alias to `TestSignal(texts.*.text)`. - result = dataset.select_rows(['*', texts_udf], - sort_by=['udf.len'], - sort_order=SortOrder.ASC, - combine_columns=True) - assert list(result) == [{ - 'texts': [{ - 'text': enriched_item('everyone', {'test_signal': { - 'len': 8, - 'is_all_cap': False - }}) - }, { - 'text': enriched_item('', {'test_signal': { - 'len': 0, - 'is_all_cap': False - }}) - }] - }, { - 'texts': [{ - 'text': enriched_item('eardrop', {'test_signal': { - 'len': 7, - 'is_all_cap': False - }}) - }, { - 'text': enriched_item('I', {'test_signal': { - 'len': 1, - 'is_all_cap': True - }}) - }] - }, { - 'texts': [{ - 'text': enriched_item('hey', {'test_signal': { - 'len': 3, - 'is_all_cap': False - }}) - }, { - 'text': enriched_item('CARS', {'test_signal': { - 'len': 4, - 'is_all_cap': True - }}) - }] - }] - - -def test_sort_by_primitive_signal_udf_alias_called_on_repeated( - make_test_data: TestDataMaker) -> None: - dataset = make_test_data([{ - 'texts': [{ - 'text': 'eardrop' - }, { - 'text': 'I' - }] - }, { - 'texts': [{ - 'text': 'hey' - }, { - 'text': 'CARS' - }] - }, { - 'texts': [{ - 'text': 'everyone' - }, { - 'text': '' - }] - }]) - - # Equivalent to: SELECT `TestPrimitiveSignal(texts.*.text) AS udf`. - texts_udf = Column('texts.*.text', signal_udf=TestPrimitiveSignal(), alias='udf') - # Sort by `udf`, where `udf` is an alias to `TestPrimitiveSignal(texts.*.text)`. 
- result = dataset.select_rows(['*', texts_udf], - sort_by=['udf'], - sort_order=SortOrder.ASC, - combine_columns=True) - assert list(result) == [{ - 'texts': [{ - 'text': enriched_item('everyone', {'primitive_signal': 9}) - }, { - 'text': enriched_item('', {'primitive_signal': 1}) - }] - }, { - 'texts': [{ - 'text': enriched_item('eardrop', {'primitive_signal': 8}) - }, { - 'text': enriched_item('I', {'primitive_signal': 2}) - }] - }, { - 'texts': [{ - 'text': enriched_item('hey', {'primitive_signal': 4}) - }, { - 'text': enriched_item('CARS', {'primitive_signal': 5}) - }] - }] - result = dataset.select_rows(['*', texts_udf], - sort_by=['udf'], - sort_order=SortOrder.DESC, - combine_columns=True) - assert list(result) == [{ - 'texts': [{ - 'text': enriched_item('everyone', {'primitive_signal': 9}) - }, { - 'text': enriched_item('', {'primitive_signal': 1}) - }] - }, { - 'texts': [{ - 'text': enriched_item('eardrop', {'primitive_signal': 8}) - }, { - 'text': enriched_item('I', {'primitive_signal': 2}) - }] - }, { - 'texts': [{ - 'text': enriched_item('hey', {'primitive_signal': 4}) - }, { - 'text': enriched_item('CARS', {'primitive_signal': 5}) - }] - }] - - -class TopKEmbedding(TextEmbeddingSignal): - """A test embed function.""" - name = 'topk_embedding' - - def compute(self, data: Iterable[RichData]) -> Iterable[Item]: - """Call the embedding function.""" - for example in data: - example = cast(str, example) - emb_spans: list[Item] = [] - for i, score in enumerate(example.split('_')): - start, end = i * 2, i * 2 + 1 - vector = np.array([int(score)]) - emb_spans.append(lilac_embedding(start, end, vector)) - yield emb_spans - - -class TopKSignal(VectorSignal): - """Compute scores along a given concept for documents.""" - name = 'topk_signal' - input_type = SignalInputType.TEXT - - _query = np.array([1]) - - def fields(self) -> Field: - return field(fields=[field('string_span', {'score': 'float32'})]) - - @override - def vector_compute(self, keys: Iterable[PathKey], - vector_index: VectorDBIndex) -> Iterable[Optional[Item]]: - all_vector_spans = vector_index.get(keys) - for vector_spans in all_vector_spans: - embeddings = np.array([vector_span['vector'] for vector_span in vector_spans]) - scores = embeddings.dot(self._query).reshape(-1) - res: Item = [] - for vector_span, score in zip(vector_spans, scores): - start, end = vector_span['span'] - res.append(lilac_span(start, end, {'score': score})) - yield res - - @override - def vector_compute_topk( - self, - topk: int, - vector_index: VectorDBIndex, - keys: Optional[Iterable[VectorKey]] = None) -> Sequence[tuple[VectorKey, Optional[Item]]]: - return vector_index.topk(self._query, topk, keys) - - -def test_sort_by_topk_embedding_udf(make_test_data: TestDataMaker) -> None: - dataset = make_test_data([{ - 'scores': '8_1', - }, { - 'scores': '3_5' - }, { - 'scores': '9_7' - }]) - - dataset.compute_signal(TopKEmbedding(), 'scores') - - # Equivalent to: SELECT `TopKSignal(scores, embedding='...') AS udf`. - signal = TopKSignal(embedding='topk_embedding') - text_udf = Column('scores', signal_udf=signal, alias='udf') - # Sort by `udf`, where `udf` is an alias to `TopKSignal(scores, embedding='...')`. 
- result = dataset.select_rows(['*', text_udf], - sort_by=['udf'], - sort_order=SortOrder.DESC, - limit=2, - combine_columns=True) - assert list(result) == [{ - 'scores': enriched_item( - '9_7', {signal.key(): [lilac_span(0, 1, {'score': 9.0}), - lilac_span(2, 3, {'score': 7.0})]}), - }, { - 'scores': enriched_item( - '8_1', {signal.key(): [lilac_span(0, 1, {'score': 8.0}), - lilac_span(2, 3, {'score': 1.0})]}), - }] - - # Same but set limit to 3. - result = dataset.select_rows(['*', text_udf], - sort_by=['udf'], - sort_order=SortOrder.DESC, - limit=3, - combine_columns=True) - assert list(result) == [{ - 'scores': enriched_item( - '9_7', {signal.key(): [lilac_span(0, 1, {'score': 9.0}), - lilac_span(2, 3, {'score': 7.0})]}), - }, { - 'scores': enriched_item( - '8_1', {signal.key(): [lilac_span(0, 1, {'score': 8.0}), - lilac_span(2, 3, {'score': 1.0})]}), - }, { - 'scores': enriched_item( - '3_5', {signal.key(): [lilac_span(0, 1, {'score': 3.0}), - lilac_span(2, 3, {'score': 5.0})]}), - }] - - -def test_sort_by_topk_udf_with_filter(make_test_data: TestDataMaker) -> None: - dataset = make_test_data([{ - 'scores': '8_1', - 'active': True - }, { - 'scores': '3_5', - 'active': True - }, { - 'scores': '9_7', - 'active': False - }]) - - dataset.compute_signal(TopKEmbedding(), 'scores') - - # Equivalent to: SELECT `TopKSignal(scores, embedding='...') AS udf`. - signal = TopKSignal(embedding='topk_embedding') - text_udf = Column('scores', signal_udf=signal, alias='udf') - # Sort by `udf`, where `udf` is an alias to `TopKSignal(scores, embedding='...')`. - result = dataset.select_rows(['*', text_udf], - sort_by=['udf'], - filters=[('active', 'equals', True)], - sort_order=SortOrder.DESC, - limit=2, - combine_columns=True) - # We make sure that '3' is not in the result, because it is not active, even though it has the - # highest topk score. 
- assert list(result) == [{ - 'active': True, - 'scores': enriched_item( - '8_1', {signal.key(): [lilac_span(0, 1, {'score': 8.0}), - lilac_span(2, 3, {'score': 1.0})]}) - }, { - 'active': True, - 'scores': enriched_item( - '3_5', {signal.key(): [lilac_span(0, 1, {'score': 3.0}), - lilac_span(2, 3, {'score': 5.0})]}) - }] diff --git a/lilac/data/dataset_select_rows_udf_test.py b/lilac/data/dataset_select_rows_udf_test.py deleted file mode 100644 index 137175b42d85e59e74116bc7e5bc4c0c18ea3a54..0000000000000000000000000000000000000000 --- a/lilac/data/dataset_select_rows_udf_test.py +++ /dev/null @@ -1,346 +0,0 @@ -"""Tests for dataset.select_rows(udf_col).""" - -from typing import Iterable, Optional, cast - -import numpy as np -import pytest -from typing_extensions import override - -from ..embeddings.vector_store import VectorDBIndex -from ..schema import ( - ROWID, - Field, - Item, - RichData, - SignalInputType, - VectorKey, - field, - lilac_embedding, - lilac_span, -) -from ..signal import ( - TextEmbeddingSignal, - TextSignal, - TextSplitterSignal, - VectorSignal, - clear_signal_registry, - register_signal, -) -from .dataset import BinaryFilterTuple, Column -from .dataset_test_utils import TestDataMaker, enriched_item - -EMBEDDINGS: list[tuple[str, list[float]]] = [('hello.', [1.0, 0.0, 0.0]), - ('hello2.', [1.0, 1.0, 0.0]), - ('hello world.', [1.0, 1.0, 1.0]), - ('hello world2.', [2.0, 1.0, 1.0])] - -STR_EMBEDDINGS: dict[str, list[float]] = {text: embedding for text, embedding in EMBEDDINGS} - - -class TestEmbedding(TextEmbeddingSignal): - """A test embed function.""" - name = 'test_embedding' - - @override - def compute(self, data: Iterable[RichData]) -> Iterable[Item]: - """Call the embedding function.""" - for example in data: - yield [lilac_embedding(0, len(example), np.array(STR_EMBEDDINGS[cast(str, example)]))] - - -class LengthSignal(TextSignal): - name = 'length_signal' - - _call_count: int = 0 - - def fields(self) -> Field: - return field('int32') - - def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]: - for text_content in data: - self._call_count += 1 - yield len(text_content) - - -class TestSignal(TextSignal): - name = 'test_signal' - - @override - def fields(self) -> Field: - return field(fields={'len': 'int32', 'flen': 'float32'}) - - @override - def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]: - return [{'len': len(text_content), 'flen': float(len(text_content))} for text_content in data] - - -class TestEmbeddingSumSignal(VectorSignal): - """Sums the embeddings to return a single floating point value.""" - name = 'test_embedding_sum' - input_type = SignalInputType.TEXT - - @override - def fields(self) -> Field: - return field('float32') - - @override - def vector_compute(self, keys: Iterable[VectorKey], - vector_index: VectorDBIndex) -> Iterable[Item]: - # The signal just sums the values of the embedding. - all_vector_spans = vector_index.get(keys) - for vector_spans in all_vector_spans: - yield vector_spans[0]['vector'].sum() - - -class ComputedKeySignal(TextSignal): - name = 'computed_key' - - @override - def fields(self) -> Field: - return field('int64') - - @override - def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]: - for text in data: - yield 1 - - def key(self, is_computed_signal: Optional[bool] = False) -> str: - return f'key_{is_computed_signal}' - - -@pytest.fixture(scope='module', autouse=True) -def setup_teardown() -> Iterable[None]: - # Setup. 
- register_signal(LengthSignal) - register_signal(TestSplitter) - register_signal(TestEmbedding) - register_signal(TestSignal) - register_signal(TestEmbeddingSumSignal) - register_signal(ComputedKeySignal) - - # Unit test runs. - yield - # Teardown. - clear_signal_registry() - - -def test_udf(make_test_data: TestDataMaker) -> None: - dataset = make_test_data([{'text': 'hello'}, {'text': 'everybody'}]) - - signal_col = Column('text', signal_udf=TestSignal()) - result = dataset.select_rows(['text', signal_col]) - - assert list(result) == [{ - 'text': 'hello', - 'text.test_signal': { - 'len': 5, - 'flen': 5.0 - } - }, { - 'text': 'everybody', - 'text.test_signal': { - 'len': 9, - 'flen': 9.0 - } - }] - - -def test_udf_with_filters(make_test_data: TestDataMaker) -> None: - dataset = make_test_data([{'text': 'hello'}, {'text': 'everybody'}]) - - signal_col = Column('text', signal_udf=TestSignal()) - # Filter by source feature. - filters: list[BinaryFilterTuple] = [('text', 'equals', 'everybody')] - result = dataset.select_rows(['text', signal_col], filters=filters) - assert list(result) == [{'text': 'everybody', 'text.test_signal': {'len': 9, 'flen': 9.0}}] - - -def test_udf_with_rowid_filter(make_test_data: TestDataMaker) -> None: - - dataset = make_test_data([{'text': 'hello'}, {'text': 'everybody'}]) - - # Filter by a specific rowid. - filters: list[BinaryFilterTuple] = [(ROWID, 'equals', '1')] - udf_col = Column('text', signal_udf=LengthSignal()) - result = dataset.select_rows([ROWID, 'text', udf_col], filters=filters) - assert list(result) == [{ROWID: '1', 'text': 'hello', 'text.length_signal': 5}] - assert cast(LengthSignal, udf_col.signal_udf)._call_count == 1 - - filters = [(ROWID, 'equals', '2')] - result = dataset.select_rows([ROWID, 'text', udf_col], filters=filters) - assert list(result) == [{ROWID: '2', 'text': 'everybody', 'text.length_signal': 9}] - assert cast(LengthSignal, udf_col.signal_udf)._call_count == 1 + 1 - - # No filters. - result = dataset.select_rows([ROWID, 'text', udf_col]) - assert list(result) == [{ - ROWID: '1', - 'text': 'hello', - 'text.length_signal': 5 - }, { - ROWID: '2', - 'text': 'everybody', - 'text.length_signal': 9 - }] - assert cast(LengthSignal, udf_col.signal_udf)._call_count == 2 + 2 - - -def test_udf_with_rowid_filter_repeated(make_test_data: TestDataMaker) -> None: - - dataset = make_test_data([{'text': ['hello', 'hi']}, {'text': ['everybody', 'bye', 'test']}]) - - # Filter by a specific rowid. - filters: list[BinaryFilterTuple] = [(ROWID, 'equals', '1')] - udf_col = Column(('text', '*'), signal_udf=LengthSignal()) - result = dataset.select_rows([ROWID, 'text', udf_col], filters=filters) - assert list(result) == [{ROWID: '1', 'text': ['hello', 'hi'], 'text.length_signal': [5, 2]}] - assert cast(LengthSignal, udf_col.signal_udf)._call_count == 2 - - # Filter by a specific rowid. 
- filters = [(ROWID, 'equals', '2')] - result = dataset.select_rows([ROWID, 'text', udf_col], filters=filters) - assert list(result) == [{ - ROWID: '2', - 'text': ['everybody', 'bye', 'test'], - 'text.length_signal': [9, 3, 4] - }] - assert cast(LengthSignal, udf_col.signal_udf)._call_count == 2 + 3 - - -def test_udf_deeply_nested(make_test_data: TestDataMaker) -> None: - dataset = make_test_data([{ - 'text': [['hello'], ['hi', 'bye']] - }, { - 'text': [['everybody', 'bye'], ['test']] - }]) - - udf_col = Column(('text', '*', '*'), signal_udf=LengthSignal()) - result = dataset.select_rows([udf_col]) - assert list(result) == [{ - 'text.length_signal': [[5], [2, 3]] - }, { - 'text.length_signal': [[9, 3], [4]] - }] - assert cast(LengthSignal, udf_col.signal_udf)._call_count == 6 - - -def test_udf_with_embedding(make_test_data: TestDataMaker) -> None: - dataset = make_test_data([{ - 'text': 'hello.', - }, { - 'text': 'hello2.', - }]) - - dataset.compute_signal(TestEmbedding(), 'text') - - signal_col = Column('text', signal_udf=TestEmbeddingSumSignal(embedding='test_embedding')) - result = dataset.select_rows(['text', signal_col]) - - expected_result: list[Item] = [{ - 'text': 'hello.', - 'text.test_embedding_sum(embedding=test_embedding)': 1.0 - }, { - 'text': 'hello2.', - 'text.test_embedding_sum(embedding=test_embedding)': 2.0 - }] - assert list(result) == expected_result - - # Select rows with alias. - signal_col = Column( - 'text', signal_udf=TestEmbeddingSumSignal(embedding='test_embedding'), alias='emb_sum') - result = dataset.select_rows(['text', signal_col]) - expected_result = [{'text': 'hello.', 'emb_sum': 1.0}, {'text': 'hello2.', 'emb_sum': 2.0}] - assert list(result) == expected_result - - -def test_udf_with_nested_embedding(make_test_data: TestDataMaker) -> None: - dataset = make_test_data([{ - 'text': ['hello.', 'hello world.'], - }, { - 'text': ['hello world2.', 'hello2.'], - }]) - - dataset.compute_signal(TestEmbedding(), ('text', '*')) - - signal_col = Column(('text', '*'), signal_udf=TestEmbeddingSumSignal(embedding='test_embedding')) - result = dataset.select_rows([('text', '*'), signal_col]) - expected_result = [{ - 'text.*': ['hello.', 'hello world.'], - 'text.test_embedding_sum(embedding=test_embedding)': [1.0, 3.0] - }, { - 'text.*': ['hello world2.', 'hello2.'], - 'text.test_embedding_sum(embedding=test_embedding)': [4.0, 2.0] - }] - assert list(result) == expected_result - - -def test_udf_throws_without_precomputing(make_test_data: TestDataMaker) -> None: - dataset = make_test_data([{ - 'text': 'hello.', - }, { - 'text': 'hello2.', - }]) - - # Embedding is not precomputed, yet we ask for the embedding. - - signal_col = Column('text', signal_udf=TestEmbeddingSumSignal(embedding='test_embedding')) - - with pytest.raises(ValueError, match="No embedding found for path \\('text',\\)"): - dataset.select_rows(['text', signal_col]) - - -class TestSplitter(TextSplitterSignal): - """Split documents into sentence by splitting on period.""" - name = 'test_splitter' - - @override - def compute(self, data: Iterable[RichData]) -> Iterable[Item]: - for text in data: - if not isinstance(text, str): - raise ValueError(f'Expected text to be a string, got {type(text)} instead.') - result: list[Item] = [] - for sentence in text.split('.'): - start = text.index(sentence) - end = start + len(sentence) - result.append(lilac_span(start, end)) - yield result - - -def test_udf_after_precomputed_split(make_test_data: TestDataMaker) -> None: - dataset = make_test_data([{ - 'text': 'sentence 1. 
sentence 2 is longer', - }, { - 'text': 'sentence 1 is longer. sent2 is short', - }]) - dataset.compute_signal(TestSplitter(), 'text') - udf = Column('text', signal_udf=LengthSignal()) - result = dataset.select_rows(['*', udf], combine_columns=True) - assert list(result) == [{ - 'text': enriched_item('sentence 1. sentence 2 is longer', { - 'length_signal': 32, - 'test_splitter': [lilac_span(0, 10), lilac_span(11, 32)] - }) - }, { - 'text': enriched_item('sentence 1 is longer. sent2 is short', { - 'length_signal': 36, - 'test_splitter': [lilac_span(0, 20), lilac_span(21, 36)] - }) - }] - - -def test_is_computed_signal_key(make_test_data: TestDataMaker) -> None: - dataset = make_test_data([{ - 'text': 'hello.', - }, { - 'text': 'hello2.', - }]) - - signal_col = Column('text', signal_udf=ComputedKeySignal()) - result = dataset.select_rows(['text', signal_col]) - assert list(result) == [{ - 'text': 'hello.', - 'text.key_False': 1 - }, { - 'text': 'hello2.', - 'text.key_False': 1 - }] diff --git a/lilac/data/dataset_stats_test.py b/lilac/data/dataset_stats_test.py deleted file mode 100644 index 427528044e253cabe2a10a5813017865f4d774fb..0000000000000000000000000000000000000000 --- a/lilac/data/dataset_stats_test.py +++ /dev/null @@ -1,149 +0,0 @@ -"""Tests for dataset.stats().""" - -from datetime import datetime -from typing import Any, cast - -import pytest -from pytest_mock import MockerFixture - -from ..schema import Item, schema -from . import dataset as dataset_module -from .dataset import StatsResult -from .dataset_test_utils import TestDataMaker - -SIMPLE_ITEMS: list[Item] = [{ - 'str': 'a', - 'int': 1, - 'bool': False, - 'float': 3.0, -}, { - 'str': 'b', - 'int': 2, - 'bool': True, - 'float': 2.0 -}, { - 'str': 'b', - 'int': 2, - 'bool': True, - 'float': 1.0 -}, { - 'float': float('nan') -}] - - -def test_simple_stats(make_test_data: TestDataMaker) -> None: - dataset = make_test_data(SIMPLE_ITEMS) - - result = dataset.stats(leaf_path='str') - assert result == StatsResult( - path=('str',), total_count=3, approx_count_distinct=2, avg_text_length=1) - - result = dataset.stats(leaf_path='float') - assert result == StatsResult( - path=('float',), total_count=4, approx_count_distinct=4, min_val=1.0, max_val=3.0) - - result = dataset.stats(leaf_path='bool') - assert result == StatsResult(path=('bool',), total_count=3, approx_count_distinct=2) - - result = dataset.stats(leaf_path='int') - assert result == StatsResult( - path=('int',), total_count=3, approx_count_distinct=2, min_val=1, max_val=2) - - -def test_nested_stats(make_test_data: TestDataMaker) -> None: - nested_items: list[Item] = [ - { - 'name': 'Name1', - 'addresses': [{ - 'zips': [5, 8] - }] - }, - { - 'name': 'Name2', - 'addresses': [{ - 'zips': [3] - }, { - 'zips': [11, 8] - }] - }, - { - 'name': 'Name2', - 'addresses': [] - }, # No addresses. - { - 'name': 'Name2', - 'addresses': [{ - 'zips': [] - }] - } # No zips in the first address. 
- ] - nested_schema = schema({'name': 'string', 'addresses': [{'zips': ['int32']}]}) - dataset = make_test_data(nested_items, schema=nested_schema) - - result = dataset.stats(leaf_path='name') - assert result == StatsResult( - path=('name',), total_count=4, approx_count_distinct=2, avg_text_length=5) - - result = dataset.stats(leaf_path='addresses.*.zips.*') - assert result == StatsResult( - path=('addresses', '*', 'zips', '*'), - total_count=5, - approx_count_distinct=4, - min_val=3, - max_val=11) - - -def test_stats_approximation(make_test_data: TestDataMaker, mocker: MockerFixture) -> None: - sample_size = 5 - mocker.patch(f'{dataset_module.__name__}.TOO_MANY_DISTINCT', sample_size) - - nested_items: list[Item] = [{'feature': str(i)} for i in range(sample_size * 10)] - nested_schema = schema({'feature': 'string'}) - dataset = make_test_data(nested_items, schema=nested_schema) - - result = dataset.stats(leaf_path='feature') - assert result == StatsResult( - path=('feature',), total_count=50, approx_count_distinct=50, avg_text_length=1) - - -def test_error_handling(make_test_data: TestDataMaker) -> None: - dataset = make_test_data(SIMPLE_ITEMS) - - with pytest.raises(ValueError, match='leaf_path must be provided'): - dataset.stats(cast(Any, None)) - - with pytest.raises(ValueError, match="Path \\('unknown',\\) not found in schema"): - dataset.stats(leaf_path='unknown') - - -def test_datetime(make_test_data: TestDataMaker) -> None: - items: list[Item] = [ - { - 'id': '1', - 'date': datetime(2023, 1, 1) - }, - { - 'id': '2', - 'date': datetime(2023, 1, 15) - }, - { - 'id': '2', - 'date': datetime(2023, 2, 1) - }, - { - 'id': '4', - 'date': datetime(2023, 3, 1) - }, - { - 'id': '5', - # Missing datetime. - } - ] - dataset = make_test_data(items) - result = dataset.stats('date') - assert result == StatsResult( - path=('date',), - total_count=4, - approx_count_distinct=4, - min_val=datetime(2023, 1, 1), - max_val=datetime(2023, 3, 1)) diff --git a/lilac/data/dataset_test.py b/lilac/data/dataset_test.py deleted file mode 100644 index b24146b6202a5f8ed34e2ce8f2700554e8e49486..0000000000000000000000000000000000000000 --- a/lilac/data/dataset_test.py +++ /dev/null @@ -1,759 +0,0 @@ -"""Implementation-agnostic tests of the Dataset DB API.""" - -from typing import Iterable, Optional, cast - -import numpy as np -import pytest -from typing_extensions import override - -from ..schema import ROWID, Field, Item, RichData, field, lilac_embedding, schema -from ..signal import TextEmbeddingSignal, TextSignal, clear_signal_registry, register_signal -from .dataset import Column, DatasetManifest -from .dataset_test_utils import TEST_DATASET_NAME, TEST_NAMESPACE, TestDataMaker, enriched_item - -SIMPLE_ITEMS: list[Item] = [{ - 'str': 'a', - 'int': 1, - 'bool': False, - 'float': 3.0 -}, { - 'str': 'b', - 'int': 2, - 'bool': True, - 'float': 2.0 -}, { - 'str': 'b', - 'int': 2, - 'bool': True, - 'float': 1.0 -}] - -EMBEDDINGS: list[tuple[str, list[float]]] = [('hello.', [1.0, 0.0, 0.0]), - ('hello2.', [1.0, 1.0, 0.0]), - ('hello world.', [1.0, 1.0, 1.0]), - ('hello world2.', [2.0, 1.0, 1.0])] - -STR_EMBEDDINGS: dict[str, list[float]] = {text: embedding for text, embedding in EMBEDDINGS} - - -class TestEmbedding(TextEmbeddingSignal): - """A test embed function.""" - name = 'test_embedding' - - @override - def compute(self, data: Iterable[RichData]) -> Iterable[Item]: - """Call the embedding function.""" - for example in data: - yield [lilac_embedding(0, len(example), np.array(STR_EMBEDDINGS[cast(str, example)]))] - 
- -class LengthSignal(TextSignal): - name = 'length_signal' - - _call_count: int = 0 - - def fields(self) -> Field: - return field('int32') - - def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]: - for text_content in data: - self._call_count += 1 - yield len(text_content) - - -class TestSignal(TextSignal): - name = 'test_signal' - - @override - def fields(self) -> Field: - return field(fields={'len': 'int32', 'flen': 'float32'}) - - @override - def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]: - return [{'len': len(text_content), 'flen': float(len(text_content))} for text_content in data] - - -@pytest.fixture(scope='module', autouse=True) -def setup_teardown() -> Iterable[None]: - # Setup. - register_signal(TestSignal) - register_signal(LengthSignal) - register_signal(SignalWithQuoteInIt) - register_signal(SignalWithDoubleQuoteInIt) - - # Unit test runs. - yield - - # Teardown. - clear_signal_registry() - - -def test_select_all_columns(make_test_data: TestDataMaker) -> None: - dataset = make_test_data(SIMPLE_ITEMS) - - result = dataset.select_rows() - assert list(result) == SIMPLE_ITEMS - - -def test_select_subcols_with_dot_seperator(make_test_data: TestDataMaker) -> None: - items: list[Item] = [{ - 'people': [{ - 'name': 'A', - 'address': { - 'zip': 1 - } - }, { - 'name': 'B', - 'address': { - 'zip': 2 - } - }] - }, { - 'people': [{ - 'name': 'C', - 'address': { - 'zip': 3 - } - }] - }] - dataset = make_test_data(items) - - result = dataset.select_rows(['people.*.name', 'people.*.address.zip']) - assert list(result) == [{ - 'people.*.name': ['A', 'B'], - 'people.*.address.zip': [1, 2] - }, { - 'people.*.name': ['C'], - 'people.*.address.zip': [3] - }] - - result = dataset.select_rows(['people.*.address.zip'], combine_columns=True) - assert list(result) == [{ - 'people': [{ - 'address': { - 'zip': 1 - } - }, { - 'address': { - 'zip': 2 - } - }] - }, { - 'people': [{ - 'address': { - 'zip': 3 - } - }] - }] - - result = dataset.select_rows(['people']) - assert list(result) == items - - -def test_select_subcols_with_escaped_dot(make_test_data: TestDataMaker) -> None: - items: list[Item] = [{ - 'people.new': [{ - 'name': 'A' - }, { - 'name': 'B' - }] - }, { - 'people.new': [{ - 'name': 'C' - }] - }] - dataset = make_test_data(items) - - result = dataset.select_rows(['"people.new".*.name']) - assert list(result) == [{ - 'people.new.*.name': ['A', 'B'], - }, { - 'people.new.*.name': ['C'], - }] - - # Escape name even though it does not need to be. - result = dataset.select_rows(['"people.new".*."name"']) - assert list(result) == [{ - 'people.new.*.name': ['A', 'B'], - }, { - 'people.new.*.name': ['C'], - }] - - -def test_select_star(make_test_data: TestDataMaker) -> None: - items: list[Item] = [{'name': 'A', 'info': {'age': 40}}, {'name': 'B', 'info': {'age': 42}}] - dataset = make_test_data(items) - - # Select *. - result = dataset.select_rows(['*']) - assert list(result) == items - - # Select (*,). - result = dataset.select_rows([('*',)]) - assert list(result) == items - - # Select *, plus a redundant `info` column. - result = dataset.select_rows(['*', 'info']) - assert list(result) == [{ - 'name': 'A', - 'info': { - 'age': 40 - }, - 'info_2': { - 'age': 40 - }, - }, { - 'name': 'B', - 'info': { - 'age': 42 - }, - 'info_2': { - 'age': 42 - }, - }] - - # Select * plus an inner `info.age` column. 
- result = dataset.select_rows(['*', ('info', 'age')]) - assert list(result) == [{ - 'name': 'A', - 'info': { - 'age': 40 - }, - 'info.age': 40 - }, { - 'name': 'B', - 'info': { - 'age': 42 - }, - 'info.age': 42 - }] - - -def test_select_star_with_combine_cols(make_test_data: TestDataMaker) -> None: - items: list[Item] = [{'name': 'A', 'info': {'age': 40}}, {'name': 'B', 'info': {'age': 42}}] - dataset = make_test_data(items) - - # Select *. - result = dataset.select_rows(['*'], combine_columns=True) - assert list(result) == items - - # Select *, plus a redundant `info` column. - result = dataset.select_rows(['*', 'info'], combine_columns=True) - assert list(result) == items - - # Select * plus an inner `info.age` column. - result = dataset.select_rows(['*', ('info', 'age')], combine_columns=True) - assert list(result) == items - - # Select *, plus redundant `name`, plus a udf. - udf = Column('name', signal_udf=TestSignal()) - result = dataset.select_rows(['*', 'name', udf], combine_columns=True) - - assert list(result) == [{ - 'name': enriched_item('A', {'test_signal': { - 'len': 1, - 'flen': 1.0 - }}), - 'info': { - 'age': 40 - } - }, { - 'name': enriched_item('B', {'test_signal': { - 'len': 1, - 'flen': 1.0 - }}), - 'info': { - 'age': 42 - } - }] - - -def test_select_ids(make_test_data: TestDataMaker) -> None: - dataset = make_test_data(SIMPLE_ITEMS) - - result = dataset.select_rows([ROWID]) - - assert list(result) == [{ROWID: '1'}, {ROWID: '2'}, {ROWID: '3'}] - - -def test_select_ids_with_limit_and_offset(make_test_data: TestDataMaker) -> None: - items: list[Item] = [{i: i} for i in range(10)] - dataset = make_test_data(items) - - result = dataset.select_rows([ROWID], offset=1, limit=3) - assert list(result) == [{ROWID: '2'}, {ROWID: '3'}, {ROWID: '4'}] - - result = dataset.select_rows([ROWID], offset=7, limit=2) - assert list(result) == [{ROWID: '8'}, {ROWID: '9'}] - - result = dataset.select_rows([ROWID], offset=9, limit=200) - assert list(result) == [{ROWID: '10'}] - - result = dataset.select_rows([ROWID], offset=10, limit=200) - assert list(result) == [] - - -def test_columns(make_test_data: TestDataMaker) -> None: - dataset = make_test_data(SIMPLE_ITEMS) - - result = dataset.select_rows([ROWID, 'str', 'float']) - - assert list(result) == [{ - ROWID: '1', - 'str': 'a', - 'float': 3.0 - }, { - ROWID: '2', - 'str': 'b', - 'float': 2.0 - }, { - ROWID: '3', - 'str': 'b', - 'float': 1.0 - }] - - -def test_merge_values(make_test_data: TestDataMaker) -> None: - dataset = make_test_data([{'text': 'hello'}, {'text': 'everybody'}]) - test_signal = TestSignal() - dataset.compute_signal(test_signal, 'text') - length_signal = LengthSignal() - dataset.compute_signal(length_signal, 'text') - - result = dataset.select_rows(['text'], combine_columns=True) - assert list(result) == [{ - 'text': enriched_item('hello', { - 'length_signal': 5, - 'test_signal': { - 'len': 5, - 'flen': 5.0 - } - }) - }, { - 'text': enriched_item('everybody', { - 'length_signal': 9, - 'test_signal': { - 'len': 9, - 'flen': 9.0 - } - }), - }] - - # Test subselection. - result = dataset.select_rows( - ['text', ('text', 'test_signal', 'flen'), ('text', 'test_signal', 'len')]) - assert list(result) == [{ - 'text': 'hello', - 'text.test_signal.flen': 5.0, - 'text.test_signal.len': 5 - }, { - 'text': 'everybody', - 'text.test_signal.flen': 9.0, - 'text.test_signal.len': 9 - }] - - # Test subselection with combine_columns=True. 
- result = dataset.select_rows( - ['text', ('text', 'test_signal', 'flen'), ('text', 'test_signal', 'len')], combine_columns=True) - assert list(result) == [{ - 'text': enriched_item('hello', { - 'length_signal': 5, - 'test_signal': { - 'len': 5, - 'flen': 5.0 - } - }) - }, { - 'text': enriched_item('everybody', { - 'length_signal': 9, - 'test_signal': { - 'len': 9, - 'flen': 9.0 - } - }), - }] - - # Test subselection with aliasing. - result = dataset.select_rows( - columns=['text', Column(('text', 'test_signal', 'len'), alias='metadata')]) - assert list(result) == [{'text': 'hello', 'metadata': 5}, {'text': 'everybody', 'metadata': 9}] - - result = dataset.select_rows(columns=['text'], combine_columns=True) - assert list(result) == [{ - 'text': enriched_item('hello', { - 'length_signal': 5, - 'test_signal': { - 'len': 5, - 'flen': 5.0 - } - }) - }, { - 'text': enriched_item('everybody', { - 'length_signal': 9, - 'test_signal': { - 'len': 9, - 'flen': 9.0 - } - }) - }] - - -def test_enriched_select_all(make_test_data: TestDataMaker) -> None: - dataset = make_test_data([{'text': 'hello'}, {'text': 'everybody'}]) - test_signal = TestSignal() - dataset.compute_signal(test_signal, 'text') - length_signal = LengthSignal() - dataset.compute_signal(length_signal, 'text') - - result = dataset.select_rows() - assert list(result) == [{ - 'text': 'hello', - 'text.length_signal': 5, - 'text.test_signal': { - 'len': 5, - 'flen': 5.0 - } - }, { - 'text': 'everybody', - 'text.length_signal': 9, - 'text.test_signal': { - 'len': 9, - 'flen': 9.0 - } - }] - - -def test_merge_array_values(make_test_data: TestDataMaker) -> None: - dataset = make_test_data([{'texts': ['hello', 'everybody']}, {'texts': ['a', 'bc', 'def']}]) - - test_signal = TestSignal() - dataset.compute_signal(test_signal, ('texts', '*')) - length_signal = LengthSignal() - dataset.compute_signal(length_signal, ('texts', '*')) - - assert dataset.manifest() == DatasetManifest( - namespace=TEST_NAMESPACE, - dataset_name=TEST_DATASET_NAME, - data_schema=schema({ - 'texts': [ - field( - 'string', - fields={ - 'length_signal': field('int32', length_signal.dict()), - 'test_signal': field( - signal=test_signal.dict(), fields={ - 'len': 'int32', - 'flen': 'float32' - }) - }) - ], - }), - num_items=2) - - result = dataset.select_rows(['texts'], combine_columns=True) - assert list(result) == [{ - 'texts': [ - enriched_item('hello', { - 'length_signal': 5, - 'test_signal': { - 'len': 5, - 'flen': 5.0 - } - }), - enriched_item('everybody', { - 'length_signal': 9, - 'test_signal': { - 'len': 9, - 'flen': 9.0 - } - }) - ], - }, { - 'texts': [ - enriched_item('a', { - 'length_signal': 1, - 'test_signal': { - 'len': 1, - 'flen': 1.0 - } - }), - enriched_item('bc', { - 'length_signal': 2, - 'test_signal': { - 'len': 2, - 'flen': 2.0 - } - }), - enriched_item('def', { - 'length_signal': 3, - 'test_signal': { - 'len': 3, - 'flen': 3.0 - } - }) - ], - }] - - # Test subselection. 
- result = dataset.select_rows([('texts', '*'), ('texts', '*', 'length_signal'), - ('texts', '*', 'test_signal', 'flen')]) - assert list(result) == [{ - 'texts.*': ['hello', 'everybody'], - 'texts.*.test_signal.flen': [5.0, 9.0], - 'texts.*.length_signal': [5, 9] - }, { - 'texts.*': ['a', 'bc', 'def'], - 'texts.*.test_signal.flen': [1.0, 2.0, 3.0], - 'texts.*.length_signal': [1, 2, 3] - }] - - -def test_combining_columns(make_test_data: TestDataMaker) -> None: - dataset = make_test_data([{ - 'text': 'hello', - 'extra': { - 'text': { - 'length_signal': 5, - 'test_signal': { - 'len': 5, - 'flen': 5.0 - } - } - } - }, { - 'text': 'everybody', - 'extra': { - 'text': { - 'length_signal': 9, - 'test_signal': { - 'len': 9, - 'flen': 9.0 - } - } - } - }]) - - # Sub-select text and test_signal. - result = dataset.select_rows(['text', ('extra', 'text', 'test_signal')], combine_columns=True) - assert list(result) == [{ - 'text': 'hello', - 'extra': { - 'text': { - 'test_signal': { - 'len': 5, - 'flen': 5.0 - } - } - } - }, { - 'text': 'everybody', - 'extra': { - 'text': { - 'test_signal': { - 'len': 9, - 'flen': 9.0 - } - } - } - }] - - # Sub-select text and length_signal. - result = dataset.select_rows(['text', ('extra', 'text', 'length_signal')], combine_columns=True) - assert list(result) == [{ - 'text': 'hello', - 'extra': { - 'text': { - 'length_signal': 5 - } - } - }, { - 'text': 'everybody', - 'extra': { - 'text': { - 'length_signal': 9 - } - } - }] - - # Sub-select length_signal only. - result = dataset.select_rows([('extra', 'text', 'length_signal')], combine_columns=True) - assert list(result) == [{ - 'extra': { - 'text': { - 'length_signal': 5 - } - } - }, { - 'extra': { - 'text': { - 'length_signal': 9 - } - } - }] - - # Aliases are ignored when combing columns. - len_col = Column(('extra', 'text', 'length_signal'), alias='hello') - result = dataset.select_rows([len_col], combine_columns=True) - assert list(result) == [{ - 'extra': { - 'text': { - 'length_signal': 5 - } - } - }, { - 'extra': { - 'text': { - 'length_signal': 9 - } - } - }] - - # Works with UDFs and aliases are ignored. - udf_col = Column('text', alias='ignored', signal_udf=LengthSignal()) - result = dataset.select_rows(['text', udf_col], combine_columns=True) - assert list(result) == [{ - 'text': enriched_item('hello', {'length_signal': 5}) - }, { - 'text': enriched_item('everybody', {'length_signal': 9}) - }] - - -def test_source_joined_with_named_signal(make_test_data: TestDataMaker) -> None: - dataset = make_test_data(SIMPLE_ITEMS) - assert dataset.manifest() == DatasetManifest( - namespace=TEST_NAMESPACE, - dataset_name=TEST_DATASET_NAME, - data_schema=schema({ - 'str': 'string', - 'int': 'int32', - 'bool': 'boolean', - 'float': 'float32', - }), - num_items=3) - - test_signal = TestSignal() - dataset.compute_signal(test_signal, 'str') - - # Check the enriched dataset manifest has 'text' enriched. 
- assert dataset.manifest() == DatasetManifest( - namespace=TEST_NAMESPACE, - dataset_name=TEST_DATASET_NAME, - data_schema=schema({ - 'str': field( - 'string', - fields={ - 'test_signal': field( - signal=test_signal.dict(), fields={ - 'len': 'int32', - 'flen': 'float32' - }) - }), - 'int': 'int32', - 'bool': 'boolean', - 'float': 'float32', - }), - num_items=3) - - result = dataset.select_rows(['str', Column(('str', 'test_signal'), alias='test_signal_on_str')]) - - assert list(result) == [{ - 'str': 'a', - 'test_signal_on_str': { - 'len': 1, - 'flen': 1.0 - } - }, { - 'str': 'b', - 'test_signal_on_str': { - 'len': 1, - 'flen': 1.0 - } - }, { - 'str': 'b', - 'test_signal_on_str': { - 'len': 1, - 'flen': 1.0 - } - }] - - -def test_invalid_column_paths(make_test_data: TestDataMaker) -> None: - dataset = make_test_data([{ - 'text': enriched_item('hello', {'test_signal': { - 'len': 5 - }}), - 'text2': [ - enriched_item('hello', {'test_signal': { - 'len': 5 - }}), - enriched_item('hi', {'test_signal': { - 'len': 2 - }}) - ], - }]) - - with pytest.raises(ValueError, match='Path part "invalid" not found in the dataset'): - dataset.select_rows([('text', 'test_signal', 'invalid')]) - - with pytest.raises(ValueError, match='Selecting a specific index of a repeated field'): - dataset.select_rows([('text2', '4', 'test_signal')]) - - -def test_signal_with_quote(make_test_data: TestDataMaker) -> None: - dataset = make_test_data([{ - 'text': 'hello', - }, { - 'text': 'world', - }]) - dataset.compute_signal(SignalWithQuoteInIt(), 'text') - dataset.compute_signal(SignalWithDoubleQuoteInIt(), 'text') - result = dataset.select_rows(['text'], combine_columns=True) - assert list(result) == [{ - 'text': enriched_item('hello', { - "test'signal": True, - 'test"signal': True - }) - }, { - 'text': enriched_item('world', { - "test'signal": True, - 'test"signal': True - }), - }] - - result = dataset.select_rows(['text', "text.test'signal", 'text.test"signal'], - combine_columns=False) - assert list(result) == [{ - 'text': 'hello', - "text.test'signal": True, - 'text.test"signal': True - }, { - 'text': 'world', - "text.test'signal": True, - 'text.test"signal': True - }] - - -class SignalWithQuoteInIt(TextSignal): - name = "test'signal" - - @override - def fields(self) -> Field: - return field('boolean') - - @override - def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]: - for d in data: - yield True - - -class SignalWithDoubleQuoteInIt(TextSignal): - name = 'test"signal' - - @override - def fields(self) -> Field: - return field('boolean') - - @override - def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]: - for d in data: - yield True diff --git a/lilac/data/dataset_utils_test.py b/lilac/data/dataset_utils_test.py deleted file mode 100644 index 90bc64055a270e125bec9d79914fe1f4a2f1ede8..0000000000000000000000000000000000000000 --- a/lilac/data/dataset_utils_test.py +++ /dev/null @@ -1,145 +0,0 @@ -"""Tests for dataset utils.""" - -from typing import Iterable, Iterator - -from ..schema import PathTuple -from ..utils import chunks -from .dataset_utils import count_primitives, sparse_to_dense_compute, wrap_in_dicts - - -def test_count_nested() -> None: - a = [[1, 2], [[3]], [4, 5, 6]] - assert 6 == count_primitives(a) - - -def test_wrap_in_dicts_with_spec_of_one_repeated() -> None: - a = [[1, 2], [3], [4, 5, 5]] - spec: list[PathTuple] = [('a', 'b', 'c'), ('d',)] # Corresponds to a.b.c.*.d. 
- result = wrap_in_dicts(a, spec) - assert result == [{ - 'a': { - 'b': { - 'c': [{ - 'd': 1 - }, { - 'd': 2 - }] - } - } - }, { - 'a': { - 'b': { - 'c': [{ - 'd': 3 - }] - } - } - }, { - 'a': { - 'b': { - 'c': [{ - 'd': 4 - }, { - 'd': 5 - }, { - 'd': 5 - }] - } - } - }] - - -def test_wrap_in_dicts_with_spec_of_double_repeated() -> None: - a = [[[1, 2], [3, 4, 5]], [[6]], [[7], [8], [9, 10]]] - spec: list[PathTuple] = [('a', 'b'), tuple(), ('c',)] # Corresponds to a.b.*.*.c. - result = wrap_in_dicts(a, spec) - assert result == [{ - 'a': { - 'b': [[{ - 'c': 1 - }, { - 'c': 2 - }], [{ - 'c': 3 - }, { - 'c': 4 - }, { - 'c': 5 - }]] - } - }, { - 'a': { - 'b': [[{ - 'c': 6 - }]] - } - }, { - 'a': { - 'b': [[{ - 'c': 7 - }], [{ - 'c': 8 - }], [{ - 'c': 9 - }, { - 'c': 10 - }]] - } - }] - - -def test_sparse_to_dense_compute() -> None: - sparse_input = iter([None, 1, 7, None, None, 3, None, 5, None, None]) - - def func(xs: Iterable[int]) -> Iterable[int]: - for x in xs: - yield x + 1 - - out = sparse_to_dense_compute(sparse_input, func) - assert list(out) == [None, 2, 8, None, None, 4, None, 6, None, None] - - -def test_sparse_to_dense_compute_batching() -> None: - sparse_input = iter([None, 1, 7, None, None, 3, None, 5, None, None]) - - def func(xs: Iterable[int]) -> Iterable[int]: - for batch in chunks(xs, 2): - yield batch[0] + 1 - if len(batch) > 1: - yield batch[1] + 1 - - out = sparse_to_dense_compute(sparse_input, func) - assert list(out) == [None, 2, 8, None, None, 4, None, 6, None, None] - - -def test_fully_dense() -> None: - sparse_input = iter([1, 7, 3, 5]) - - def func(xs: Iterable[int]) -> Iterable[int]: - for x in xs: - yield x + 1 - - out = sparse_to_dense_compute(sparse_input, func) - assert list(out) == [2, 8, 4, 6] - - -def test_sparse_to_dense_compute_fully_sparse() -> None: - sparse_input = iter([None, None, None]) - - def func(xs: Iterable[int]) -> Iterable[int]: - for x in xs: - yield x + 1 - - out = sparse_to_dense_compute(sparse_input, func) - assert list(out) == [None, None, None] - - -def test_sparse_to_dense_compute_empty() -> None: - sparse_input: Iterator[int] = iter([]) - - def func(xs: Iterable[int]) -> Iterable[int]: - for x in xs: - yield x + 1 - - out = sparse_to_dense_compute(sparse_input, func) - assert list(out) == [] diff --git a/lilac/data_loader_test.py b/lilac/data_loader_test.py deleted file mode 100644 index 0718b2845fcb76fb5d4464d449e981f7107f7d73..0000000000000000000000000000000000000000 --- a/lilac/data_loader_test.py +++ /dev/null @@ -1,91 +0,0 @@ -"""Tests for data_loader.py.""" - -import os -import pathlib -import uuid -from typing import Iterable - -import yaml -from pytest_mock import MockerFixture -from typing_extensions import override - -from .config import CONFIG_FILENAME, DatasetConfig, DatasetSettings, DatasetUISettings -from .data.dataset_duckdb import read_source_manifest -from .data.dataset_utils import parquet_filename -from .data_loader import process_source -from .schema import PARQUET_FILENAME_PREFIX, ROWID, Item, SourceManifest, schema -from .sources.source import Source, SourceSchema -from .test_utils import fake_uuid, read_items -from .utils import DATASETS_DIR_NAME - - -class TestSource(Source): - """A test source.""" - name = 'test_source' - - @override - def setup(self) -> None: - pass - - @override - def source_schema(self) -> SourceSchema: - """Return the source schema.""" - return SourceSchema(fields=schema({'x': 'int64', 'y': 'string'}).fields, num_items=2) - - @override - def process(self) -> Iterable[Item]: - return 
[{'x': 1, 'y': 'ten'}, {'x': 2, 'y': 'twenty'}] - - -def test_data_loader(tmp_path: pathlib.Path, mocker: MockerFixture) -> None: - mocker.patch.dict(os.environ, {'LILAC_DATA_PATH': str(tmp_path)}) - - mock_uuid = mocker.patch.object(uuid, 'uuid4', autospec=True) - mock_uuid.side_effect = [fake_uuid(b'1'), fake_uuid(b'2')] - - source = TestSource() - setup_mock = mocker.spy(TestSource, 'setup') - - output_dir, num_items = process_source( - tmp_path, DatasetConfig(namespace='test_namespace', name='test_dataset', source=source)) - - assert setup_mock.call_count == 1 - - assert output_dir == os.path.join(tmp_path, DATASETS_DIR_NAME, 'test_namespace', 'test_dataset') - assert num_items == 2 - - source_manifest = read_source_manifest(output_dir) - - assert source_manifest == SourceManifest( - files=[parquet_filename(PARQUET_FILENAME_PREFIX, 0, 1)], - data_schema=schema({ - 'x': 'int64', - 'y': 'string' - }), - ) - - items = read_items(output_dir, source_manifest.files, source_manifest.data_schema) - - assert items == [{ - ROWID: fake_uuid(b'1').hex, - 'x': 1, - 'y': 'ten' - }, { - ROWID: fake_uuid(b'2').hex, - 'x': 2, - 'y': 'twenty' - }] - - # Make sure the config yml file was written. - config_filepath = os.path.join(output_dir, CONFIG_FILENAME) - assert os.path.exists(config_filepath) - - with open(config_filepath) as f: - config = DatasetConfig(**yaml.safe_load(f)) - - assert config.dict() == DatasetConfig( - namespace='test_namespace', - name='test_dataset', - source=source, - # 'y' is the longest path, so should be set as the default setting. - settings=DatasetSettings(ui=DatasetUISettings(media_paths=[('y',)]))).dict() diff --git a/lilac/embeddings/__pycache__/__init__.cpython-39.pyc b/lilac/embeddings/__pycache__/__init__.cpython-39.pyc deleted file mode 100644 index 3bbf3ce76faf97d57dd8de90c3b64a83a95e1883..0000000000000000000000000000000000000000 Binary files a/lilac/embeddings/__pycache__/__init__.cpython-39.pyc and /dev/null differ diff --git a/lilac/embeddings/__pycache__/cohere.cpython-39.pyc b/lilac/embeddings/__pycache__/cohere.cpython-39.pyc deleted file mode 100644 index daf6239484f56ff9f9b475cc44d84a155af33f9e..0000000000000000000000000000000000000000 Binary files a/lilac/embeddings/__pycache__/cohere.cpython-39.pyc and /dev/null differ diff --git a/lilac/embeddings/__pycache__/default_vector_stores.cpython-39.pyc b/lilac/embeddings/__pycache__/default_vector_stores.cpython-39.pyc deleted file mode 100644 index e801de01479f5c2a70c7fffe55a8cbeca6b6e32f..0000000000000000000000000000000000000000 Binary files a/lilac/embeddings/__pycache__/default_vector_stores.cpython-39.pyc and /dev/null differ diff --git a/lilac/embeddings/__pycache__/embedding.cpython-39.pyc b/lilac/embeddings/__pycache__/embedding.cpython-39.pyc deleted file mode 100644 index c341aa4cf8eab423b199f8359cffdb2a74040be3..0000000000000000000000000000000000000000 Binary files a/lilac/embeddings/__pycache__/embedding.cpython-39.pyc and /dev/null differ diff --git a/lilac/embeddings/__pycache__/embedding_test.cpython-39-pytest-7.4.0.pyc b/lilac/embeddings/__pycache__/embedding_test.cpython-39-pytest-7.4.0.pyc deleted file mode 100644 index 733c8af26b16cbb04c2aabea8e80207448b88037..0000000000000000000000000000000000000000 Binary files a/lilac/embeddings/__pycache__/embedding_test.cpython-39-pytest-7.4.0.pyc and /dev/null differ diff --git a/lilac/embeddings/__pycache__/gte.cpython-39.pyc b/lilac/embeddings/__pycache__/gte.cpython-39.pyc deleted file mode 100644 index 
02253a378e3f0ae4feab6e3f5a6b5b8e5eec81fe..0000000000000000000000000000000000000000 Binary files a/lilac/embeddings/__pycache__/gte.cpython-39.pyc and /dev/null differ diff --git a/lilac/embeddings/__pycache__/openai.cpython-39.pyc b/lilac/embeddings/__pycache__/openai.cpython-39.pyc deleted file mode 100644 index 59e1eb0b9fe5723261e10ec099cdfd09fe8d0c16..0000000000000000000000000000000000000000 Binary files a/lilac/embeddings/__pycache__/openai.cpython-39.pyc and /dev/null differ diff --git a/lilac/embeddings/__pycache__/palm.cpython-39.pyc b/lilac/embeddings/__pycache__/palm.cpython-39.pyc deleted file mode 100644 index 86b328784764840659439db477920e6b7588bd6c..0000000000000000000000000000000000000000 Binary files a/lilac/embeddings/__pycache__/palm.cpython-39.pyc and /dev/null differ diff --git a/lilac/embeddings/__pycache__/sbert.cpython-39.pyc b/lilac/embeddings/__pycache__/sbert.cpython-39.pyc deleted file mode 100644 index 803fa4f1e46f193968723973fe7a70af81611d69..0000000000000000000000000000000000000000 Binary files a/lilac/embeddings/__pycache__/sbert.cpython-39.pyc and /dev/null differ diff --git a/lilac/embeddings/__pycache__/transformer_utils.cpython-39.pyc b/lilac/embeddings/__pycache__/transformer_utils.cpython-39.pyc deleted file mode 100644 index dafbcd1ccb176ce1a8e6b14397657f22aa5a94df..0000000000000000000000000000000000000000 Binary files a/lilac/embeddings/__pycache__/transformer_utils.cpython-39.pyc and /dev/null differ diff --git a/lilac/embeddings/__pycache__/vector_store.cpython-39.pyc b/lilac/embeddings/__pycache__/vector_store.cpython-39.pyc deleted file mode 100644 index 4e3fa5afd6cec15272dfff298c7287e55d5cb743..0000000000000000000000000000000000000000 Binary files a/lilac/embeddings/__pycache__/vector_store.cpython-39.pyc and /dev/null differ diff --git a/lilac/embeddings/__pycache__/vector_store_hnsw.cpython-39.pyc b/lilac/embeddings/__pycache__/vector_store_hnsw.cpython-39.pyc deleted file mode 100644 index 3a8d4db4310ab9c1d3981f4b794a9e02363a9f51..0000000000000000000000000000000000000000 Binary files a/lilac/embeddings/__pycache__/vector_store_hnsw.cpython-39.pyc and /dev/null differ diff --git a/lilac/embeddings/__pycache__/vector_store_numpy.cpython-39.pyc b/lilac/embeddings/__pycache__/vector_store_numpy.cpython-39.pyc deleted file mode 100644 index b627b2a5657620a9018b1bb21a3c4f291168f218..0000000000000000000000000000000000000000 Binary files a/lilac/embeddings/__pycache__/vector_store_numpy.cpython-39.pyc and /dev/null differ diff --git a/lilac/embeddings/__pycache__/vector_store_test.cpython-39-pytest-7.4.0.pyc b/lilac/embeddings/__pycache__/vector_store_test.cpython-39-pytest-7.4.0.pyc deleted file mode 100644 index 18a6c574a5e5bd05509184559f406850efad1baf..0000000000000000000000000000000000000000 Binary files a/lilac/embeddings/__pycache__/vector_store_test.cpython-39-pytest-7.4.0.pyc and /dev/null differ diff --git a/lilac/embeddings/embedding_test.py b/lilac/embeddings/embedding_test.py deleted file mode 100644 index eada89f014ac9f97c9c4560c8efd07f816b7772c..0000000000000000000000000000000000000000 --- a/lilac/embeddings/embedding_test.py +++ /dev/null @@ -1,112 +0,0 @@ -"""Tests for embedding.py.""" - -import numpy as np - -from ..schema import lilac_embedding -from ..splitters.chunk_splitter import TextChunk -from .embedding import compute_split_embeddings - - -def char_splitter(text: str) -> list[TextChunk]: - return [(letter, (i, i + 1)) for i, letter in enumerate(text)] - - -def test_split_and_combine_text_embeddings_batch_across_two_docs() -> 
None: - docs = ['This is', '123'] - batch_size = 3 - - embed_fn_inputs: list[list[str]] = [] - - def embed_fn(batch: list[str]) -> list[np.ndarray]: - embed_fn_inputs.append(batch) - return [np.ones(1) for _ in batch] - - result = list(compute_split_embeddings(docs, batch_size, embed_fn, char_splitter)) - - # Each input to embed_fn is a batch of at most 3 letters. - assert embed_fn_inputs == [ - ['T', 'h', 'i'], - ['s', ' ', 'i'], - ['s', '1', '2'], - ['3'], - ] - - assert result == [ - [ - lilac_embedding(0, 1, np.array(1)), # T - lilac_embedding(1, 2, np.array(1)), # h - lilac_embedding(2, 3, np.array(1)), # i - lilac_embedding(3, 4, np.array(1)), # s - lilac_embedding(4, 5, np.array(1)), # ' ' - lilac_embedding(5, 6, np.array(1)), # i - lilac_embedding(6, 7, np.array(1)), # s - ], - [ - lilac_embedding(0, 1, np.array(1)), # 1 - lilac_embedding(1, 2, np.array(1)), # 2 - lilac_embedding(2, 3, np.array(1)), # 3 - ], - ] - - -def test_split_and_combine_text_embeddings_no_docs() -> None: - docs: list[str] = [] - batch_size = 3 - - embed_fn_inputs: list[list[str]] = [] - - def embed_fn(batch: list[str]) -> list[np.ndarray]: - embed_fn_inputs.append(batch) - return [np.ones(1) for _ in batch] - - result = list(compute_split_embeddings(docs, batch_size, embed_fn, char_splitter)) - assert embed_fn_inputs == [] - assert result == [] - - -def test_split_and_combine_text_embeddings_empty_docs() -> None: - docs: list[str] = ['', '', '123'] - batch_size = 3 - - embed_fn_inputs: list[list[str]] = [] - - def embed_fn(batch: list[str]) -> list[np.ndarray]: - embed_fn_inputs.append(batch) - return [np.ones(1) for _ in batch] - - result = list(compute_split_embeddings(docs, batch_size, embed_fn, char_splitter)) - assert embed_fn_inputs == [['1', '2', '3']] - - assert result == [ - None, - None, - [ - lilac_embedding(0, 1, np.array(1)), # 1 - lilac_embedding(1, 2, np.array(1)), # 2 - lilac_embedding(2, 3, np.array(1)), # 3 - ] - ] - - -def test_split_and_combine_text_embeddings_empty_docs_at_end() -> None: - docs: list[str] = ['123', '', ''] - batch_size = 3 - - embed_fn_inputs: list[list[str]] = [] - - def embed_fn(batch: list[str]) -> list[np.ndarray]: - embed_fn_inputs.append(batch) - return [np.ones(1) for _ in batch] - - result = list(compute_split_embeddings(docs, batch_size, embed_fn, char_splitter)) - assert embed_fn_inputs == [['1', '2', '3']] - - assert result == [ - [ - lilac_embedding(0, 1, np.array(1)), # 1 - lilac_embedding(1, 2, np.array(1)), # 2 - lilac_embedding(2, 3, np.array(1)), # 3 - ], - None, - None - ] diff --git a/lilac/embeddings/vector_store_test.py b/lilac/embeddings/vector_store_test.py deleted file mode 100644 index 6ca63d5872c9f301e2ba430538a64406cad8580b..0000000000000000000000000000000000000000 --- a/lilac/embeddings/vector_store_test.py +++ /dev/null @@ -1,72 +0,0 @@ -"""Tests the vector store interface.""" - -from typing import Type, cast - -import numpy as np -import pytest -from sklearn.preprocessing import normalize - -from .vector_store import VectorStore -from .vector_store_hnsw import HNSWVectorStore -from .vector_store_numpy import NumpyVectorStore - -ALL_STORES = [NumpyVectorStore, HNSWVectorStore] - - -@pytest.mark.parametrize('store_cls', ALL_STORES) -class VectorStoreSuite: - - def test_get_all(self, store_cls: Type[VectorStore]) -> None: - store = store_cls() - - store.add([('a',), ('b',), ('c',)], np.array([[1, 2], [3, 4], [5, 6]])) - - np.testing.assert_array_equal( - store.get([('a',), ('b',), ('c',)]), np.array([[1, 2], [3, 4], [5, 6]])) - - def 
test_get_subset(self, store_cls: Type[VectorStore]) -> None: - store = store_cls() - - store.add([('a',), ('b',), ('c',)], np.array([[1, 2], [3, 4], [5, 6]])) - - np.testing.assert_array_equal(store.get([('b',), ('c',)]), np.array([[3, 4], [5, 6]])) - - def test_topk(self, store_cls: Type[VectorStore]) -> None: - store = store_cls() - embedding = cast(np.ndarray, normalize(np.array([[1, 0], [0, 1], [1, 1]]))) - query = np.array([0.9, 1]) - query /= np.linalg.norm(query) - topk = 3 - store.add([('a',), ('b',), ('c',)], embedding) - result = store.topk(query, topk) - assert [key for key, _ in result] == [('c',), ('b',), ('a',)] - assert [score for _, score in result] == pytest.approx([0.999, 0.743, 0.669], abs=1e-3) - - def test_topk_with_restricted_keys(self, store_cls: Type[VectorStore]) -> None: - store = store_cls() - embedding = np.array([[0.45, 0.89], [0.6, 0.8], [0.64, 0.77]]) - query = np.array([0.89, 0.45]) - topk = 3 - store.add([('a',), ('b',), ('c',)], embedding) - result = store.topk(query, topk, keys=[('b',), ('a',)]) - assert [key for key, _ in result] == [('b',), ('a',)] - assert [score for _, score in result] == pytest.approx([0.894, 0.801], 1e-3) - - result = store.topk(query, topk, keys=[('a',), ('b',)]) - assert [key for key, _ in result] == [('b',), ('a',)] - assert [score for _, score in result] == pytest.approx([0.894, 0.801], 1e-3) - - result = store.topk(query, topk, keys=[('a',), ('c',)]) - assert [key for key, _ in result] == [('c',), ('a',)] - assert [score for _, score in result] == pytest.approx([0.9161, 0.801], 1e-3) - - def test_topk_with_keys(self, store_cls: Type[VectorStore]) -> None: - store = store_cls() - embedding = np.array([[8], [9], [3], [10]]) - store.add([('a', 0), ('a', 1), ('b', 0), ('c', 0)], embedding) - query = np.array([1]) - result = store.topk(query, k=2, keys=[('b', 0), ('c', 0)]) - assert result == [(('c', 0), 10.0), (('b', 0), 3.0)] - - result = store.topk(query, k=10, keys=[('b', 0), ('a', 1), ('a', 0)]) - assert result == [(('a', 1), 9.0), (('a', 0), 8.0), (('b', 0), 3.0)] diff --git a/lilac/schema_test.py b/lilac/schema_test.py deleted file mode 100644 index 26a936b410de3e96a6ef0118b424d937284eca30..0000000000000000000000000000000000000000 --- a/lilac/schema_test.py +++ /dev/null @@ -1,262 +0,0 @@ -"""Tests for item.py.""" - -import pyarrow as pa -import pytest - -from .schema import ( - PATH_WILDCARD, - TEXT_SPAN_END_FEATURE, - TEXT_SPAN_START_FEATURE, - VALUE_KEY, - DataType, - Field, - Item, - arrow_schema_to_schema, - child_item_from_column_path, - column_paths_match, - field, - schema, - schema_to_arrow_schema, -) - -NESTED_TEST_SCHEMA = schema({ - 'person': { - 'name': 'string', - 'last_name': 'string_span', - # Contains a double nested array of primitives. - 'data': [['float32']], - # Contains a value and children. 
- 'description': field( - 'string', - fields={ - 'toxicity': 'float32', - 'sentences': [field('string_span', fields={'len': 'int32'})] - }) - }, - 'addresses': [{ - 'city': 'string', - 'zipcode': 'int16', - 'current': 'boolean', - 'locations': [{ - 'latitude': 'float16', - 'longitude': 'float64' - }] - }], - 'blob': 'binary' -}) -NESTED_TEST_ITEM: Item = { - 'person': { - 'name': 'Test Name', - 'last_name': (5, 9) - }, - 'addresses': [{ - 'city': 'a', - 'zipcode': 1, - 'current': False, - 'locations': [{ - 'latitude': 1.5, - 'longitude': 3.8 - }, { - 'latitude': 2.9, - 'longitude': 15.3 - }], - }, { - 'city': 'b', - 'zipcode': 2, - 'current': True, - 'locations': [{ - 'latitude': 11.2, - 'longitude': 20.1 - }, { - 'latitude': 30.1, - 'longitude': 40.2 - }], - }] -} - - -def test_field_ctor_validation() -> None: - with pytest.raises( - ValueError, match='One of "fields", "repeated_field", or "dtype" should be defined'): - Field() - - with pytest.raises(ValueError, match='Both "fields" and "repeated_field" should not be defined'): - Field( - fields={'name': Field(dtype=DataType.STRING)}, - repeated_field=Field(dtype=DataType.INT32), - ) - - with pytest.raises(ValueError, match=f'{VALUE_KEY} is a reserved field name'): - Field(fields={VALUE_KEY: Field(dtype=DataType.STRING)},) - - -def test_schema_leafs() -> None: - expected = { - ('addresses', PATH_WILDCARD, 'city'): Field(dtype=DataType.STRING), - ('addresses', PATH_WILDCARD, 'current'): Field(dtype=DataType.BOOLEAN), - ('addresses', PATH_WILDCARD, 'locations', PATH_WILDCARD, 'latitude'): - Field(dtype=DataType.FLOAT16), - ('addresses', PATH_WILDCARD, 'locations', PATH_WILDCARD, 'longitude'): - Field(dtype=DataType.FLOAT64), - ('addresses', PATH_WILDCARD, 'zipcode'): Field(dtype=DataType.INT16), - ('blob',): Field(dtype=DataType.BINARY), - ('person', 'name'): Field(dtype=DataType.STRING), - ('person', 'last_name'): Field(dtype=DataType.STRING_SPAN), - ('person', 'data', PATH_WILDCARD, PATH_WILDCARD): Field(dtype=DataType.FLOAT32), - ('person', 'description'): Field( - dtype=DataType.STRING, - fields={ - 'toxicity': Field(dtype=DataType.FLOAT32), - 'sentences': Field( - repeated_field=Field( - dtype=DataType.STRING_SPAN, fields={'len': Field(dtype=DataType.INT32)})) - }), - ('person', 'description', 'toxicity'): Field(dtype=DataType.FLOAT32), - ('person', 'description', 'sentences', PATH_WILDCARD): Field( - fields={'len': Field(dtype=DataType.INT32)}, dtype=DataType.STRING_SPAN), - ('person', 'description', 'sentences', PATH_WILDCARD, 'len'): Field(dtype=DataType.INT32), - } - assert NESTED_TEST_SCHEMA.leafs == expected - - -def test_schema_to_arrow_schema() -> None: - arrow_schema = schema_to_arrow_schema(NESTED_TEST_SCHEMA) - - assert arrow_schema == pa.schema({ - 'person': pa.struct({ - 'name': pa.string(), - # The dtype for STRING_SPAN is implemented as a struct with a {start, end}. 
- 'last_name': pa.struct({ - VALUE_KEY: pa.struct({ - TEXT_SPAN_START_FEATURE: pa.int32(), - TEXT_SPAN_END_FEATURE: pa.int32(), - }) - }), - 'data': pa.list_(pa.list_(pa.float32())), - 'description': pa.struct({ - 'toxicity': pa.float32(), - 'sentences': pa.list_( - pa.struct({ - 'len': pa.int32(), - VALUE_KEY: pa.struct({ - TEXT_SPAN_START_FEATURE: pa.int32(), - TEXT_SPAN_END_FEATURE: pa.int32(), - }) - })), - VALUE_KEY: pa.string(), - }) - }), - 'addresses': pa.list_( - pa.struct({ - 'city': pa.string(), - 'zipcode': pa.int16(), - 'current': pa.bool_(), - 'locations': pa.list_(pa.struct({ - 'latitude': pa.float16(), - 'longitude': pa.float64() - })), - })), - 'blob': pa.binary(), - }) - - -def test_arrow_schema_to_schema() -> None: - arrow_schema = pa.schema({ - 'person': pa.struct({ - 'name': pa.string(), - 'data': pa.list_(pa.list_(pa.float32())) - }), - 'addresses': pa.list_( - pa.struct({ - 'city': pa.string(), - 'zipcode': pa.int16(), - 'current': pa.bool_(), - 'locations': pa.list_(pa.struct({ - 'latitude': pa.float16(), - 'longitude': pa.float64() - })), - })), - 'blob': pa.binary(), - }) - expected_schema = schema({ - 'person': { - 'name': 'string', - 'data': [['float32']] - }, - 'addresses': [{ - 'city': 'string', - 'zipcode': 'int16', - 'current': 'boolean', - 'locations': [{ - 'latitude': 'float16', - 'longitude': 'float64', - }] - }], - 'blob': 'binary', - }) - assert arrow_schema_to_schema(arrow_schema) == expected_schema - - -def test_simple_schema_str() -> None: - assert str(schema({'person': 'string'})) == 'person: string' - - -def test_child_item_from_column_path() -> None: - assert child_item_from_column_path(NESTED_TEST_ITEM, - ('addresses', '0', 'locations', '0', 'longitude')) == 3.8 - assert child_item_from_column_path(NESTED_TEST_ITEM, ('addresses', '1', 'city')) == 'b' - - -def test_child_item_from_column_path_raises_wildcard() -> None: - with pytest.raises( - ValueError, match='cannot be called with a path that contains a repeated wildcard'): - child_item_from_column_path(NESTED_TEST_ITEM, ('addresses', PATH_WILDCARD, 'city')) - - -def test_column_paths_match() -> None: - assert column_paths_match(path_match=('person', 'name'), specific_path=('person', 'name')) is True - assert column_paths_match( - path_match=('person', 'name'), specific_path=('person', 'not_name')) is False - - # Wildcards work for structs. - assert column_paths_match( - path_match=(PATH_WILDCARD, 'name'), specific_path=('person', 'name')) is True - assert column_paths_match( - path_match=(PATH_WILDCARD, 'name'), specific_path=('person', 'not_name')) is False - - # Wildcards work for repeateds. - assert column_paths_match( - path_match=('person', PATH_WILDCARD, 'name'), specific_path=('person', '0', 'name')) is True - assert column_paths_match( - path_match=('person', PATH_WILDCARD, 'name'), - specific_path=('person', '0', 'not_name')) is False - - # Sub-path matches always return False. 
- assert column_paths_match(path_match=(PATH_WILDCARD,), specific_path=('person', 'name')) is False - assert column_paths_match( - path_match=( - 'person', - PATH_WILDCARD, - ), specific_path=('person', '0', 'name')) is False - - -def test_nested_schema_str() -> None: - - assert str(NESTED_TEST_SCHEMA) == """\ -person: - name: string - last_name: string_span - data: list( list( float32)) - description: - toxicity: float32 - sentences: list( - len: int32) -addresses: list( - city: string - zipcode: int16 - current: boolean - locations: list( - latitude: float16 - longitude: float64)) -blob: binary\ -""" diff --git a/lilac/server_concept_test.py b/lilac/server_concept_test.py deleted file mode 100644 index 58ab35f95f07f9a1973992f944e98c095aa31585..0000000000000000000000000000000000000000 --- a/lilac/server_concept_test.py +++ /dev/null @@ -1,456 +0,0 @@ -"""Test the public REST API for concepts.""" -import os -import uuid -from pathlib import Path -from typing import Iterable, cast - -import numpy as np -import pytest -from fastapi.testclient import TestClient -from pydantic import parse_obj_as -from pytest_mock import MockerFixture -from typing_extensions import override - -from .concepts.concept import ( - DRAFT_MAIN, - Concept, - ConceptModel, - ConceptType, - Example, - ExampleIn, - ExampleOrigin, -) -from .concepts.db_concept import ConceptACL, ConceptInfo, ConceptUpdate -from .router_concept import ( - ConceptModelInfo, - CreateConceptOptions, - MergeConceptDraftOptions, - ScoreBody, - ScoreExample, -) -from .schema import Item, RichData, lilac_embedding, lilac_span -from .server import app -from .signal import TextEmbeddingSignal, clear_signal_registry, register_signal -from .test_utils import fake_uuid - -client = TestClient(app) - -EMBEDDINGS: list[tuple[str, list[float]]] = [('hello', [1.0, 0.0, 0.0]), ('hello2', [1.0, 1.0, - 0.0]), - ('hello world', [1.0, 1.0, 1.0]), - ('hello world2', [2.0, 1.0, 1.0])] - -STR_EMBEDDINGS: dict[str, list[float]] = {text: embedding for text, embedding in EMBEDDINGS} - - -class TestEmbedding(TextEmbeddingSignal): - """A test embed function.""" - name = 'test_embedding' - - @override - def compute(self, data: Iterable[RichData]) -> Iterable[Item]: - """Call the embedding function.""" - for example in data: - yield [lilac_embedding(0, len(example), np.array(STR_EMBEDDINGS[cast(str, example)]))] - - -@pytest.fixture(scope='module', autouse=True) -def setup_teardown() -> Iterable[None]: - # Setup. - register_signal(TestEmbedding) - # Unit test runs. - yield - # Teardown. - clear_signal_registry() - - -@pytest.fixture(scope='function', autouse=True) -def setup_data_dir(tmp_path: Path, mocker: MockerFixture) -> None: - mocker.patch.dict(os.environ, {'LILAC_DATA_PATH': str(tmp_path)}) - - -def _remove_lilac_concepts(concepts: list[ConceptInfo]) -> list[ConceptInfo]: - return list(filter(lambda c: c.namespace != 'lilac', concepts)) - - -def test_list_lilac_concepts() -> None: - url = '/api/v1/concepts/' - response = client.get(url) - - assert response.status_code == 200 - # Make sure lilac concepts exist. - assert filter(lambda c: c.concept_name == 'positive-sentiment' and c.namespace == 'lilac', - response.json()) - - -def test_concept_create() -> None: - url = '/api/v1/concepts/' - response = client.get(url) - - assert response.status_code == 200 - response_concepts = _remove_lilac_concepts(parse_obj_as(list[ConceptInfo], response.json())) - assert response_concepts == [] - - # Create a concept. 
- url = '/api/v1/concepts/create' - create_concept = CreateConceptOptions( - namespace='concept_namespace', name='concept', type=ConceptType.TEXT) - response = client.post(url, json=create_concept.dict()) - assert response.status_code == 200 - assert Concept.parse_obj(response.json()) == Concept( - namespace='concept_namespace', - concept_name='concept', - type=ConceptType.TEXT, - data={}, - version=0) - - # Make sure list shows us the new concept. - url = '/api/v1/concepts/' - response = client.get(url) - assert response.status_code == 200 - response_concepts = _remove_lilac_concepts(parse_obj_as(list[ConceptInfo], response.json())) - assert response_concepts == [ - ConceptInfo( - namespace='concept_namespace', - name='concept', - type=ConceptType.TEXT, - drafts=[DRAFT_MAIN], - acls=ConceptACL(read=True, write=True)) - ] - - -def test_concept_delete() -> None: - # Create a concept. - client.post( - '/api/v1/concepts/create', - json=CreateConceptOptions(namespace='concept_namespace', name='concept', - type=ConceptType.TEXT).dict()) - - response = client.get('/api/v1/concepts/') - response_concepts = _remove_lilac_concepts(parse_obj_as(list[ConceptInfo], response.json())) - assert len(response_concepts) == 1 - - # Delete the concept. - url = '/api/v1/concepts/concept_namespace/concept' - response = client.delete(url) - assert response.status_code == 200 - - # Make sure list shows no concepts. - response = client.get('/api/v1/concepts/') - response_concepts = _remove_lilac_concepts(parse_obj_as(list[ConceptInfo], response.json())) - assert response_concepts == [] - - -def test_concept_edits(mocker: MockerFixture) -> None: - mock_uuid = mocker.patch.object(uuid, 'uuid4', autospec=True) - - # Create the concept. - response = client.post( - '/api/v1/concepts/create', - json=CreateConceptOptions(namespace='concept_namespace', name='concept', - type=ConceptType.TEXT).dict()) - - # Make sure we can add an example. - mock_uuid.return_value = fake_uuid(b'1') - url = '/api/v1/concepts/concept_namespace/concept' - concept_update = ConceptUpdate(insert=[ - ExampleIn( - label=True, - text='hello', - origin=ExampleOrigin( - dataset_namespace='dataset_namespace', dataset_name='dataset', dataset_row_id='d1')) - ]) - response = client.post(url, json=concept_update.dict()) - assert response.status_code == 200 - assert Concept.parse_obj(response.json()) == Concept( - namespace='concept_namespace', - concept_name='concept', - type=ConceptType.TEXT, - data={ - fake_uuid(b'1').hex: Example( - id=fake_uuid(b'1').hex, - label=True, - text='hello', - origin=ExampleOrigin( - dataset_namespace='dataset_namespace', dataset_name='dataset', dataset_row_id='d1')) - }, - version=1) - - url = '/api/v1/concepts/' - response = client.get(url) - - assert response.status_code == 200 - response_concepts = _remove_lilac_concepts(parse_obj_as(list[ConceptInfo], response.json())) - assert response_concepts == [ - ConceptInfo( - namespace='concept_namespace', - name='concept', - type=ConceptType.TEXT, - drafts=[DRAFT_MAIN], - acls=ConceptACL(read=True, write=True)) - ] - - # Add another example. 
- mock_uuid.return_value = fake_uuid(b'2') - url = '/api/v1/concepts/concept_namespace/concept' - concept_update = ConceptUpdate(insert=[ - ExampleIn( - label=True, - text='hello2', - origin=ExampleOrigin( - dataset_namespace='dataset_namespace', dataset_name='dataset', dataset_row_id='d2')) - ]) - response = client.post(url, json=concept_update.dict()) - assert response.status_code == 200 - assert Concept.parse_obj(response.json()) == Concept( - namespace='concept_namespace', - concept_name='concept', - type=ConceptType.TEXT, - data={ - fake_uuid(b'1').hex: Example( - id=fake_uuid(b'1').hex, - label=True, - text='hello', - origin=ExampleOrigin( - dataset_namespace='dataset_namespace', dataset_name='dataset', dataset_row_id='d1')), - fake_uuid(b'2').hex: Example( - id=fake_uuid(b'2').hex, - label=True, - text='hello2', - origin=ExampleOrigin( - dataset_namespace='dataset_namespace', dataset_name='dataset', dataset_row_id='d2')) - }, - version=2) - - # Edit both examples. - url = '/api/v1/concepts/concept_namespace/concept' - concept_update = ConceptUpdate(update=[ - # Switch the label. - Example(id=fake_uuid(b'1').hex, label=False, text='hello'), - # Switch the text. - Example(id=fake_uuid(b'2').hex, label=True, text='hello world'), - ]) - response = client.post(url, json=concept_update.dict()) - assert response.status_code == 200 - assert Concept.parse_obj(response.json()) == Concept( - namespace='concept_namespace', - concept_name='concept', - type=ConceptType.TEXT, - data={ - fake_uuid(b'1').hex: Example(id=fake_uuid(b'1').hex, label=False, text='hello'), - fake_uuid(b'2').hex: Example(id=fake_uuid(b'2').hex, label=True, text='hello world') - }, - version=3) - - # Delete the first example. - url = '/api/v1/concepts/concept_namespace/concept' - concept_update = ConceptUpdate(remove=[fake_uuid(b'1').hex]) - response = client.post(url, json=concept_update.dict()) - assert response.status_code == 200 - assert Concept.parse_obj(response.json()) == Concept( - namespace='concept_namespace', - concept_name='concept', - type=ConceptType.TEXT, - data={fake_uuid(b'2').hex: Example(id=fake_uuid(b'2').hex, label=True, text='hello world')}, - version=4) - - # The concept still exists. - url = '/api/v1/concepts/' - response = client.get(url) - - assert response.status_code == 200 - response_concepts = _remove_lilac_concepts(parse_obj_as(list[ConceptInfo], response.json())) - assert response_concepts == [ - ConceptInfo( - namespace='concept_namespace', - name='concept', - type=ConceptType.TEXT, - drafts=[DRAFT_MAIN], - acls=ConceptACL(read=True, write=True)) - ] - - -def test_concept_drafts(mocker: MockerFixture) -> None: - mock_uuid = mocker.patch.object(uuid, 'uuid4', autospec=True) - - # Create the concept. - response = client.post( - '/api/v1/concepts/create', - json=CreateConceptOptions(namespace='concept_namespace', name='concept', - type=ConceptType.TEXT).dict()) - - # Add examples, some drafts. 
- mock_uuid.side_effect = [fake_uuid(b'1'), fake_uuid(b'2'), fake_uuid(b'3'), fake_uuid(b'4')] - url = '/api/v1/concepts/concept_namespace/concept' - concept_update = ConceptUpdate(insert=[ - ExampleIn(label=True, text='in concept'), - ExampleIn(label=False, text='out of concept'), - ExampleIn(label=False, text='in concept', draft='test_draft'), - ExampleIn(label=False, text='out of concept draft', draft='test_draft') - ]) - response = client.post(url, json=concept_update.dict()) - assert response.status_code == 200 - - # Make sure list shows us the drafts - url = '/api/v1/concepts/' - response = client.get(url) - assert response.status_code == 200 - # Remove lilac concepts for the test. - concepts = list( - filter(lambda c: c.namespace != 'lilac', parse_obj_as(list[ConceptInfo], response.json()))) - - assert concepts == [ - ConceptInfo( - namespace='concept_namespace', - name='concept', - type=ConceptType.TEXT, - drafts=[DRAFT_MAIN, 'test_draft'], - acls=ConceptACL(read=True, write=True)) - ] - - # Make sure when we request main, we only get data in main. - url = '/api/v1/concepts/concept_namespace/concept' - response = client.get(url) - assert response.status_code == 200 - assert Concept.parse_obj(response.json()) == Concept( - namespace='concept_namespace', - concept_name='concept', - type=ConceptType.TEXT, - data={ - # Only main are returned. - fake_uuid(b'1').hex: Example(id=fake_uuid(b'1').hex, label=True, text='in concept'), - fake_uuid(b'2').hex: Example(id=fake_uuid(b'2').hex, label=False, text='out of concept') - }, - version=1) - - # Make sure when we request the draft, we get the draft data deduped with main. - url = '/api/v1/concepts/concept_namespace/concept?draft=test_draft' - response = client.get(url) - assert response.status_code == 200 - assert Concept.parse_obj(response.json()) == Concept( - namespace='concept_namespace', - concept_name='concept', - type=ConceptType.TEXT, - data={ - # b'1' is deduped with b'3'. - fake_uuid(b'2').hex: Example(id=fake_uuid(b'2').hex, label=False, text='out of concept'), - # ID 3 is a duplicate of main's 1. - fake_uuid(b'3').hex: Example( - id=fake_uuid(b'3').hex, label=False, text='in concept', draft='test_draft'), - fake_uuid(b'4').hex: Example( - id=fake_uuid(b'4').hex, label=False, text='out of concept draft', draft='test_draft') - }, - version=1) - - # Merge the draft. - response = client.post( - '/api/v1/concepts/concept_namespace/concept/merge_draft', - json=MergeConceptDraftOptions(draft='test_draft').dict()) - assert response.status_code == 200 - - # Make sure we get the merged drafts. - url = '/api/v1/concepts/concept_namespace/concept' - response = client.get(url) - assert response.status_code == 200 - assert Concept.parse_obj(response.json()).dict() == Concept( - namespace='concept_namespace', - concept_name='concept', - type=ConceptType.TEXT, - data={ - # b'1' is deduped with b'3'. - fake_uuid(b'2').hex: Example(id=fake_uuid(b'2').hex, label=False, text='out of concept'), - # ID 3 is a duplicate of main's 1. - fake_uuid(b'3').hex: Example(id=fake_uuid(b'3').hex, label=False, text='in concept'), - fake_uuid(b'4').hex: Example( - id=fake_uuid(b'4').hex, label=False, text='out of concept draft') - }, - version=2).dict() - - -def test_concept_model_sync(mocker: MockerFixture) -> None: - mock_uuid = mocker.patch.object(uuid, 'uuid4', autospec=True) - - # Create the concept. 
- response = client.post( - '/api/v1/concepts/create', - json=CreateConceptOptions(namespace='concept_namespace', name='concept', - type=ConceptType.TEXT).dict()) - - # Add two examples. - mock_uuid.side_effect = [fake_uuid(b'1'), fake_uuid(b'2')] - url = '/api/v1/concepts/concept_namespace/concept' - concept_update = ConceptUpdate(insert=[ - ExampleIn( - label=True, - text='hello', - origin=ExampleOrigin( - dataset_namespace='dataset_namespace', dataset_name='dataset', dataset_row_id='d1')), - ExampleIn( - label=False, - text='hello world', - origin=ExampleOrigin( - dataset_namespace='dataset_namespace', dataset_name='dataset', dataset_row_id='d2')) - ]) - response = client.post(url, json=concept_update.dict()) - assert response.status_code == 200 - - # Get the concept model, without creating it. - url = '/api/v1/concepts/concept_namespace/concept/model/test_embedding' - response = client.get(url, params={'create_if_not_exists': False}) - assert response.status_code == 200 - assert response.json() is None - - # Get the concept model, and create it. - url = '/api/v1/concepts/concept_namespace/concept/model/test_embedding' - response = client.get(url, params={'create_if_not_exists': True}) - assert response.status_code == 200 - assert ConceptModelInfo.parse_obj(response.json()) == ConceptModelInfo( - namespace='concept_namespace', - concept_name='concept', - embedding_name='test_embedding', - version=1) - - # Score an example. - mock_score_emb = mocker.patch.object(ConceptModel, 'score_embeddings', autospec=True) - # The return value here is a batch of values. - mock_score_emb.return_value = np.array([0.9, 1.0]) - url = '/api/v1/concepts/concept_namespace/concept/model/test_embedding/score' - score_body = ScoreBody(examples=[ScoreExample(text='hello world'), ScoreExample(text='hello')]) - response = client.post(url, json=score_body.dict()) - assert response.status_code == 200 - assert response.json() == [[lilac_span(0, 11, {'score': 0.9})], - [lilac_span(0, 5, {'score': 1.0})]] - - -def test_concept_edits_error_before_create(mocker: MockerFixture) -> None: - url = '/api/v1/concepts/concept_namespace/concept' - concept_update = ConceptUpdate(insert=[ - ExampleIn( - label=True, - text='hello', - origin=ExampleOrigin( - dataset_namespace='dataset_namespace', dataset_name='dataset', dataset_row_id='d1')) - ]) - response = client.post(url, json=concept_update.dict()) - assert response.is_error is True - assert response.status_code == 500 - - -def test_concept_edits_wrong_type(mocker: MockerFixture) -> None: - # Create the concept. 
- response = client.post( - '/api/v1/concepts/create', - json=CreateConceptOptions( - namespace='concept_namespace', name='concept', type=ConceptType.IMAGE).dict()) - - url = '/api/v1/concepts/concept_namespace/concept' - concept_update = ConceptUpdate(insert=[ - ExampleIn( - label=True, - text='hello', - origin=ExampleOrigin( - dataset_namespace='dataset_namespace', dataset_name='dataset', dataset_row_id='d1')) - ]) - response = client.post(url, json=concept_update.dict()) - assert response.is_error is True - assert response.status_code == 500 diff --git a/lilac/server_dataset_test.py b/lilac/server_dataset_test.py deleted file mode 100644 index 750e8683c920dd6ac024b9ec01399420ce609148..0000000000000000000000000000000000000000 --- a/lilac/server_dataset_test.py +++ /dev/null @@ -1,304 +0,0 @@ -"""Test our public REST API.""" -import os -from typing import Iterable, Optional, Type - -import pytest -from fastapi.testclient import TestClient -from pytest_mock import MockerFixture - -from .config import DatasetSettings -from .data.dataset import Dataset, DatasetManifest, SelectRowsSchemaResult, SelectRowsSchemaUDF -from .data.dataset_duckdb import DatasetDuckDB -from .data.dataset_test_utils import TEST_DATASET_NAME, TEST_NAMESPACE, enriched_item, make_dataset -from .router_dataset import ( - Column, - ComputeSignalOptions, - DeleteSignalOptions, - SelectRowsOptions, - SelectRowsResponse, - SelectRowsSchemaOptions, - WebManifest, -) -from .schema import Field, Item, RichData, field, schema -from .server import app -from .signal import TextSignal, clear_signal_registry, register_signal - -client = TestClient(app) - -DATASET_CLASSES = [DatasetDuckDB] - -TEST_DATA: list[Item] = [{ - 'erased': False, - 'people': [{ - 'name': 'A', - 'zipcode': 0, - 'locations': [{ - 'city': 'city1', - 'state': 'state1' - }, { - 'city': 'city2', - 'state': 'state2' - }] - }] -}, { - 'erased': True, - 'people': [{ - 'name': 'B', - 'zipcode': 1, - 'locations': [{ - 'city': 'city3', - 'state': 'state3' - }, { - 'city': 'city4' - }, { - 'city': 'city5' - }] - }, { - 'name': 'C', - 'zipcode': 2, - 'locations': [{ - 'city': 'city1', - 'state': 'state1' - }] - }] -}, { - 'erased': True -}] - - -@pytest.fixture(scope='module', autouse=True) -def setup_teardown() -> Iterable[None]: - # Setup. - register_signal(LengthSignal) - # Unit test runs. - yield - # Teardown. 
- clear_signal_registry() - - -@pytest.fixture(scope='module', autouse=True, params=DATASET_CLASSES) -def test_data(tmp_path_factory: pytest.TempPathFactory, module_mocker: MockerFixture, - request: pytest.FixtureRequest) -> None: - tmp_path = tmp_path_factory.mktemp('data') - module_mocker.patch.dict(os.environ, {'LILAC_DATA_PATH': str(tmp_path)}) - - dataset_cls: Type[Dataset] = request.param - make_dataset(dataset_cls, tmp_path, TEST_DATA) - - -def test_get_manifest() -> None: - url = f'/api/v1/datasets/{TEST_NAMESPACE}/{TEST_DATASET_NAME}' - response = client.get(url) - assert response.status_code == 200 - assert WebManifest.parse_obj(response.json()) == WebManifest( - dataset_manifest=DatasetManifest( - namespace=TEST_NAMESPACE, - dataset_name=TEST_DATASET_NAME, - data_schema=schema({ - 'erased': 'boolean', - 'people': [{ - 'name': 'string', - 'zipcode': 'int32', - 'locations': [{ - 'city': 'string', - 'state': 'string' - }] - }] - }), - num_items=3)) - - -def test_select_rows_no_options() -> None: - url = f'/api/v1/datasets/{TEST_NAMESPACE}/{TEST_DATASET_NAME}/select_rows' - options = SelectRowsOptions() - response = client.post(url, json=options.dict()) - assert response.status_code == 200 - assert SelectRowsResponse.parse_obj(response.json()) == SelectRowsResponse( - rows=TEST_DATA, total_num_rows=3) - - -def test_select_rows_with_cols_and_limit() -> None: - url = f'/api/v1/datasets/{TEST_NAMESPACE}/{TEST_DATASET_NAME}/select_rows' - options = SelectRowsOptions( - columns=[('people', '*', 'zipcode'), ('people', '*', 'locations', '*', 'city')], - limit=1, - offset=1) - response = client.post(url, json=options.dict()) - assert response.status_code == 200 - assert SelectRowsResponse.parse_obj(response.json()) == SelectRowsResponse( - rows=[{ - 'people.*.zipcode': [1, 2], - 'people.*.locations.*.city': [['city3', 'city4', 'city5'], ['city1']] - }], - total_num_rows=3) - - -def test_select_rows_with_cols_and_combine() -> None: - url = f'/api/v1/datasets/{TEST_NAMESPACE}/{TEST_DATASET_NAME}/select_rows' - options = SelectRowsOptions( - columns=[('people', '*', 'zipcode'), ('people', '*', 'locations', '*', 'city')], - combine_columns=True) - response = client.post(url, json=options.dict()) - assert response.status_code == 200 - assert SelectRowsResponse.parse_obj(response.json()) == SelectRowsResponse( - rows=[{ - 'people': [{ - 'zipcode': 0, - 'locations': [{ - 'city': 'city1', - }, { - 'city': 'city2', - }] - }] - }, { - 'people': [{ - 'zipcode': 1, - 'locations': [{ - 'city': 'city3', - }, { - 'city': 'city4' - }, { - 'city': 'city5' - }] - }, { - 'zipcode': 2, - 'locations': [{ - 'city': 'city1' - }] - }] - }, { - 'people': None - }], - total_num_rows=3) - - -class LengthSignal(TextSignal): - name = 'length_signal' - - def fields(self) -> Field: - return field('int32') - - def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]: - for text_content in data: - yield len(text_content) if text_content is not None else None - - -def test_select_rows_star_plus_udf() -> None: - url = f'/api/v1/datasets/{TEST_NAMESPACE}/{TEST_DATASET_NAME}/select_rows' - udf = Column(path=('people', '*', 'name'), alias='len', signal_udf=LengthSignal()) - options = SelectRowsOptions(columns=['*', udf], combine_columns=True) - response = client.post(url, json=options.dict()) - assert response.status_code == 200 - assert SelectRowsResponse.parse_obj(response.json()) == SelectRowsResponse( - rows=[{ - 'erased': False, - 'people': [{ - 'name': enriched_item('A', {'length_signal': 1}), - 'zipcode': 
0, - 'locations': [{ - 'city': 'city1', - 'state': 'state1' - }, { - 'city': 'city2', - 'state': 'state2' - }] - }] - }, { - 'erased': True, - 'people': [{ - 'name': enriched_item('B', {'length_signal': 1}), - 'zipcode': 1, - 'locations': [{ - 'city': 'city3', - 'state': 'state3' - }, { - 'city': 'city4' - }, { - 'city': 'city5' - }] - }, { - 'name': enriched_item('C', {'length_signal': 1}), - 'zipcode': 2, - 'locations': [{ - 'city': 'city1', - 'state': 'state1' - }] - }] - }, { - 'erased': True - }], - total_num_rows=3) - - -def test_select_rows_schema_star_plus_udf() -> None: - url = f'/api/v1/datasets/{TEST_NAMESPACE}/{TEST_DATASET_NAME}/select_rows_schema' - signal = LengthSignal() - udf = Column(path=('people', '*', 'name'), alias='len', signal_udf=signal) - options = SelectRowsSchemaOptions(columns=['*', udf], combine_columns=True) - response = client.post(url, json=options.dict()) - assert response.status_code == 200 - assert SelectRowsSchemaResult.parse_obj(response.json()) == SelectRowsSchemaResult( - data_schema=schema({ - 'erased': 'boolean', - 'people': [{ - 'name': field( - 'string', fields={'length_signal': field('int32', signal.dict(exclude_none=True))}), - 'zipcode': 'int32', - 'locations': [{ - 'city': 'string', - 'state': 'string' - }] - }] - }), - udfs=[SelectRowsSchemaUDF(path=('people', '*', 'name', 'length_signal'), alias='len')]) - - -def test_select_rows_schema_no_cols() -> None: - url = f'/api/v1/datasets/{TEST_NAMESPACE}/{TEST_DATASET_NAME}/select_rows_schema' - options = SelectRowsSchemaOptions(combine_columns=True) - response = client.post(url, json=options.dict()) - assert response.status_code == 200 - assert SelectRowsSchemaResult.parse_obj(response.json()) == SelectRowsSchemaResult( - data_schema=schema({ - 'erased': 'boolean', - 'people': [{ - 'name': 'string', - 'zipcode': 'int32', - 'locations': [{ - 'city': 'string', - 'state': 'string' - }] - }] - })) - - -def test_compute_signal_auth(mocker: MockerFixture) -> None: - mocker.patch.dict(os.environ, {'LILAC_AUTH_ENABLED': 'True'}) - - url = f'/api/v1/datasets/{TEST_NAMESPACE}/{TEST_DATASET_NAME}/compute_signal' - response = client.post( - url, json=ComputeSignalOptions(signal=LengthSignal(), leaf_path=('people', 'name')).dict()) - assert response.status_code == 401 - assert response.is_error is True - assert 'User does not have access to compute signals over this dataset.' in response.text - - -def test_delete_signal_auth(mocker: MockerFixture) -> None: - mocker.patch.dict(os.environ, {'LILAC_AUTH_ENABLED': 'True'}) - - url = f'/api/v1/datasets/{TEST_NAMESPACE}/{TEST_DATASET_NAME}/delete_signal' - response = client.request( - 'DELETE', url, json=DeleteSignalOptions(signal_path=('doesnt', 'matter')).dict()) - assert response.status_code == 401 - assert response.is_error is True - assert 'User does not have access to delete this signal.' in response.text - - -def test_update_settings_auth(mocker: MockerFixture) -> None: - mocker.patch.dict(os.environ, {'LILAC_AUTH_ENABLED': 'True'}) - - url = f'/api/v1/datasets/{TEST_NAMESPACE}/{TEST_DATASET_NAME}/settings' - response = client.post(url, json=DatasetSettings().dict()) - assert response.status_code == 401 - assert response.is_error is True - assert 'User does not have access to update the settings of this dataset.' 
in response.text diff --git a/lilac/server_signal_test.py b/lilac/server_signal_test.py deleted file mode 100644 index 2894939d5e6208ea9a351d3651979fb653178761..0000000000000000000000000000000000000000 --- a/lilac/server_signal_test.py +++ /dev/null @@ -1,83 +0,0 @@ -"""Test the public REST API for signals.""" -import os -from pathlib import Path -from typing import Iterable, Optional - -import pytest -from fastapi.testclient import TestClient -from pytest_mock import MockerFixture -from typing_extensions import override - -from .router_signal import ( - SignalComputeOptions, - SignalComputeResponse, - SignalSchemaOptions, - SignalSchemaResponse, -) -from .schema import Field, Item, RichData, SignalInputType, field -from .server import app -from .signal import Signal, clear_signal_registry, register_signal - -client = TestClient(app) - -EMBEDDINGS: list[tuple[str, list[float]]] = [('hello', [1.0, 0.0, 0.0]), ('hello2', [1.0, 1.0, - 0.0]), - ('hello world', [1.0, 1.0, 1.0]), - ('hello world2', [2.0, 1.0, 1.0])] - -STR_EMBEDDINGS: dict[str, list[float]] = {text: embedding for text, embedding in EMBEDDINGS} - - -class TestQueryAndLengthSignal(Signal): - """A test signal.""" - - # Pydantic fields - name = 'test_signal' - input_type = SignalInputType.TEXT - - query: str - - @override - def fields(self) -> Field: - return field('int32') - - @override - def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]: - return [f'{self.query}_{len(e)}' for e in data] - - -@pytest.fixture(scope='module', autouse=True) -def setup_teardown() -> Iterable[None]: - # Setup. - register_signal(TestQueryAndLengthSignal) - # Unit test runs. - yield - # Teardown. - clear_signal_registry() - - -@pytest.fixture(scope='function', autouse=True) -def setup_data_dir(tmp_path: Path, mocker: MockerFixture) -> None: - mocker.patch.dict(os.environ, {'LILAC_DATA_PATH': str(tmp_path)}) - - -def test_compute() -> None: - # Compute the signal. - url = '/api/v1/signals/compute' - create_signal = SignalComputeOptions( - signal=TestQueryAndLengthSignal(query='hi'), inputs=['hello', 'hello2']) - response = client.post(url, json=create_signal.dict()) - assert response.status_code == 200 - assert SignalComputeResponse.parse_obj( - response.json()) == SignalComputeResponse(items=['hi_5', 'hi_6']) - - -def test_schema() -> None: - # Get the schema for the signal. 
- url = '/api/v1/signals/schema' - signal = TestQueryAndLengthSignal(query='hi') - create_signal = SignalSchemaOptions(signal=signal) - response = client.post(url, json=create_signal.dict()) - assert response.status_code == 200 - assert SignalSchemaResponse.parse_obj( - response.json()) == SignalSchemaResponse(fields=signal.fields()) diff --git a/lilac/signal_test.py b/lilac/signal_test.py deleted file mode 100644 index 1c2238e9f377a5e3c4cb8b985b76e7808813edca..0000000000000000000000000000000000000000 --- a/lilac/signal_test.py +++ /dev/null @@ -1,127 +0,0 @@ -"""Test signal base class.""" -from typing import Iterable, Optional - -import pytest -from typing_extensions import override - -from .schema import Field, Item, RichData, SignalInputType, field -from .signal import ( - Signal, - TextEmbeddingSignal, - TextSplitterSignal, - clear_signal_registry, - get_signal_by_type, - get_signal_cls, - get_signals_by_type, - register_signal, - resolve_signal, -) - - -class TestSignal(Signal): - """A test signal.""" - - # Pydantic fields - name = 'test_signal' - input_type = SignalInputType.TEXT - - query: str - - @override - def fields(self) -> Field: - return field('float32') - - @override - def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]: - del data - return [] - - -class TestTextSplitter(TextSplitterSignal): - """A test text splitter.""" - name = 'test_splitter' - - @override - def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]: - del data - return [] - - -class TestTextEmbedding(TextEmbeddingSignal): - """A test text embedding.""" - name = 'test_embedding' - - @override - def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]: - del data - return [] - - -@pytest.fixture(scope='module', autouse=True) -def setup_teardown() -> Iterable[None]: - # Setup. - register_signal(TestSignal) - register_signal(TestTextSplitter) - register_signal(TestTextEmbedding) - - # Unit test runs. - yield - - # Teardown. - clear_signal_registry() - - -def test_signal_serialization() -> None: - signal = TestSignal(query='test') - - # The class variables should not be included. - assert signal.dict() == {'signal_name': 'test_signal', 'query': 'test'} - - -def test_get_signal_cls() -> None: - """Test getting a signal.""" - assert TestSignal == get_signal_cls('test_signal') - - -def test_resolve_signal() -> None: - """Test resolving a signal.""" - test_signal = TestSignal(query='hello') - - # Signals pass through. - assert resolve_signal(test_signal) == test_signal - - # Dicts resolve to the base class. 
- assert resolve_signal(test_signal.dict()) == test_signal - - -def test_get_signal_by_type() -> None: - assert get_signal_by_type(TestTextSplitter.name, TextSplitterSignal) == TestTextSplitter - assert get_signal_by_type(TestTextEmbedding.name, TextEmbeddingSignal) == TestTextEmbedding - - -def test_get_signal_by_type_validation() -> None: - with pytest.raises(ValueError, match='Signal "invalid_signal" not found in the registry'): - get_signal_by_type('invalid_signal', TextSplitterSignal) - - with pytest.raises( - ValueError, match=f'"{TestTextSplitter.name}" is a `{TestTextSplitter.__name__}`'): - get_signal_by_type(TestTextSplitter.name, TextEmbeddingSignal) - - -def test_get_signals_by_type() -> None: - assert get_signals_by_type(TextSplitterSignal) == [TestTextSplitter] - assert get_signals_by_type(TextEmbeddingSignal) == [TestTextEmbedding] - - -class TestSignalNoDisplayName(Signal): - name = 'signal_no_name' - - -class TestSignalDisplayName(Signal): - name = 'signal_display_name' - display_name = 'test display name' - - -def test_signal_title_schema() -> None: - assert TestSignalNoDisplayName.schema()['title'] == TestSignalNoDisplayName.__name__ - assert TestSignalDisplayName.schema()['title'] == 'test display name' diff --git a/lilac/signals/__pycache__/__init__.cpython-39.pyc b/lilac/signals/__pycache__/__init__.cpython-39.pyc deleted file mode 100644 index a09e855fb917f0026e987c848032291cb82770b5..0000000000000000000000000000000000000000 Binary files a/lilac/signals/__pycache__/__init__.cpython-39.pyc and /dev/null differ diff --git a/lilac/signals/__pycache__/concept_labels.cpython-39.pyc b/lilac/signals/__pycache__/concept_labels.cpython-39.pyc deleted file mode 100644 index f405ff1a87de019661978582716ff3d4dc0a1c41..0000000000000000000000000000000000000000 Binary files a/lilac/signals/__pycache__/concept_labels.cpython-39.pyc and /dev/null differ diff --git a/lilac/signals/__pycache__/concept_labels_test.cpython-39-pytest-7.4.0.pyc b/lilac/signals/__pycache__/concept_labels_test.cpython-39-pytest-7.4.0.pyc deleted file mode 100644 index 0e3feb138b99355857870f522f0cdb45b0a4749e..0000000000000000000000000000000000000000 Binary files a/lilac/signals/__pycache__/concept_labels_test.cpython-39-pytest-7.4.0.pyc and /dev/null differ diff --git a/lilac/signals/__pycache__/concept_scorer.cpython-39.pyc b/lilac/signals/__pycache__/concept_scorer.cpython-39.pyc deleted file mode 100644 index e322b076cab22a0544c1e38f67c770fd48772680..0000000000000000000000000000000000000000 Binary files a/lilac/signals/__pycache__/concept_scorer.cpython-39.pyc and /dev/null differ diff --git a/lilac/signals/__pycache__/concept_scorer_test.cpython-39-pytest-7.4.0.pyc b/lilac/signals/__pycache__/concept_scorer_test.cpython-39-pytest-7.4.0.pyc deleted file mode 100644 index b01085ba0e37a5d66dc23782f9e28b3f083c9cd8..0000000000000000000000000000000000000000 Binary files a/lilac/signals/__pycache__/concept_scorer_test.cpython-39-pytest-7.4.0.pyc and /dev/null differ diff --git a/lilac/signals/__pycache__/default_signals.cpython-39.pyc b/lilac/signals/__pycache__/default_signals.cpython-39.pyc deleted file mode 100644 index 1b9174f66f7248710ea59e76448703efa954ed50..0000000000000000000000000000000000000000 Binary files a/lilac/signals/__pycache__/default_signals.cpython-39.pyc and /dev/null differ diff --git a/lilac/signals/__pycache__/lang_detection.cpython-39.pyc b/lilac/signals/__pycache__/lang_detection.cpython-39.pyc deleted file mode 100644 index 
b3732e05a20a84ced8570a4800967b7057e63afe..0000000000000000000000000000000000000000 Binary files a/lilac/signals/__pycache__/lang_detection.cpython-39.pyc and /dev/null differ diff --git a/lilac/signals/__pycache__/lang_detection_test.cpython-39-pytest-7.4.0.pyc b/lilac/signals/__pycache__/lang_detection_test.cpython-39-pytest-7.4.0.pyc deleted file mode 100644 index b01b22ca250ec066f11a6728874d437097afd337..0000000000000000000000000000000000000000 Binary files a/lilac/signals/__pycache__/lang_detection_test.cpython-39-pytest-7.4.0.pyc and /dev/null differ diff --git a/lilac/signals/__pycache__/minhash_dup.cpython-39.pyc b/lilac/signals/__pycache__/minhash_dup.cpython-39.pyc deleted file mode 100644 index 7ece4e75a46f098846e292f9bd1d725dfff7db22..0000000000000000000000000000000000000000 Binary files a/lilac/signals/__pycache__/minhash_dup.cpython-39.pyc and /dev/null differ diff --git a/lilac/signals/__pycache__/near_dup.cpython-39.pyc b/lilac/signals/__pycache__/near_dup.cpython-39.pyc deleted file mode 100644 index fc991d32fc20f22c32a008710ca1cc3fadd69e3f..0000000000000000000000000000000000000000 Binary files a/lilac/signals/__pycache__/near_dup.cpython-39.pyc and /dev/null differ diff --git a/lilac/signals/__pycache__/near_dup_test.cpython-39-pytest-7.4.0.pyc b/lilac/signals/__pycache__/near_dup_test.cpython-39-pytest-7.4.0.pyc deleted file mode 100644 index 5cad09729bfc1b9710fe9cf9dcd9affe137f06a1..0000000000000000000000000000000000000000 Binary files a/lilac/signals/__pycache__/near_dup_test.cpython-39-pytest-7.4.0.pyc and /dev/null differ diff --git a/lilac/signals/__pycache__/ner.cpython-39.pyc b/lilac/signals/__pycache__/ner.cpython-39.pyc deleted file mode 100644 index 35bfd1a2ee2dc873d99348853123a0d985bbfac7..0000000000000000000000000000000000000000 Binary files a/lilac/signals/__pycache__/ner.cpython-39.pyc and /dev/null differ diff --git a/lilac/signals/__pycache__/ner_test.cpython-39-pytest-7.4.0.pyc b/lilac/signals/__pycache__/ner_test.cpython-39-pytest-7.4.0.pyc deleted file mode 100644 index 1980bacc440f00abbbd589f8c9f4f94e266d8c5c..0000000000000000000000000000000000000000 Binary files a/lilac/signals/__pycache__/ner_test.cpython-39-pytest-7.4.0.pyc and /dev/null differ diff --git a/lilac/signals/__pycache__/pii.cpython-39.pyc b/lilac/signals/__pycache__/pii.cpython-39.pyc deleted file mode 100644 index ecd55828f2ff252d0d64b6875991d031c339308d..0000000000000000000000000000000000000000 Binary files a/lilac/signals/__pycache__/pii.cpython-39.pyc and /dev/null differ diff --git a/lilac/signals/__pycache__/pii_ip_address.cpython-39.pyc b/lilac/signals/__pycache__/pii_ip_address.cpython-39.pyc deleted file mode 100644 index 11c26be2e3316666ba688230629ee7b6e58affc1..0000000000000000000000000000000000000000 Binary files a/lilac/signals/__pycache__/pii_ip_address.cpython-39.pyc and /dev/null differ diff --git a/lilac/signals/__pycache__/pii_secrets.cpython-39.pyc b/lilac/signals/__pycache__/pii_secrets.cpython-39.pyc deleted file mode 100644 index b2b375d9ef8ad61444c72aa9d5966553a6f886f2..0000000000000000000000000000000000000000 Binary files a/lilac/signals/__pycache__/pii_secrets.cpython-39.pyc and /dev/null differ diff --git a/lilac/signals/__pycache__/pii_test.cpython-39-pytest-7.4.0.pyc b/lilac/signals/__pycache__/pii_test.cpython-39-pytest-7.4.0.pyc deleted file mode 100644 index 7f532cd71392ba7307356d9cf63aaba6a0d0f621..0000000000000000000000000000000000000000 Binary files a/lilac/signals/__pycache__/pii_test.cpython-39-pytest-7.4.0.pyc and /dev/null differ diff --git 
a/lilac/signals/__pycache__/semantic_similarity.cpython-39.pyc b/lilac/signals/__pycache__/semantic_similarity.cpython-39.pyc deleted file mode 100644 index b28308427a4a27dddbcc825eada2b131dd0ec784..0000000000000000000000000000000000000000 Binary files a/lilac/signals/__pycache__/semantic_similarity.cpython-39.pyc and /dev/null differ diff --git a/lilac/signals/__pycache__/semantic_similarity_test.cpython-39-pytest-7.4.0.pyc b/lilac/signals/__pycache__/semantic_similarity_test.cpython-39-pytest-7.4.0.pyc deleted file mode 100644 index 0abdd0fafb07463bfc12a6853c9aa71a8682259c..0000000000000000000000000000000000000000 Binary files a/lilac/signals/__pycache__/semantic_similarity_test.cpython-39-pytest-7.4.0.pyc and /dev/null differ diff --git a/lilac/signals/__pycache__/substring_search.cpython-39.pyc b/lilac/signals/__pycache__/substring_search.cpython-39.pyc deleted file mode 100644 index c5a2722831c4d3e97fa49d23f4d1bc190b2ebe9d..0000000000000000000000000000000000000000 Binary files a/lilac/signals/__pycache__/substring_search.cpython-39.pyc and /dev/null differ diff --git a/lilac/signals/__pycache__/substring_search_test.cpython-39-pytest-7.4.0.pyc b/lilac/signals/__pycache__/substring_search_test.cpython-39-pytest-7.4.0.pyc deleted file mode 100644 index 4b30c0e20ff019c55d7757c71858085ae295fbaa..0000000000000000000000000000000000000000 Binary files a/lilac/signals/__pycache__/substring_search_test.cpython-39-pytest-7.4.0.pyc and /dev/null differ diff --git a/lilac/signals/__pycache__/text_statistics.cpython-39.pyc b/lilac/signals/__pycache__/text_statistics.cpython-39.pyc deleted file mode 100644 index 67f8e19b0805b5b6b9c81a8822e1e7e6c508159d..0000000000000000000000000000000000000000 Binary files a/lilac/signals/__pycache__/text_statistics.cpython-39.pyc and /dev/null differ diff --git a/lilac/signals/__pycache__/text_statistics_test.cpython-39-pytest-7.4.0.pyc b/lilac/signals/__pycache__/text_statistics_test.cpython-39-pytest-7.4.0.pyc deleted file mode 100644 index 692102fda8b72a2b683c2fd498c32022dfbe427f..0000000000000000000000000000000000000000 Binary files a/lilac/signals/__pycache__/text_statistics_test.cpython-39-pytest-7.4.0.pyc and /dev/null differ diff --git a/lilac/signals/concept_labels_test.py b/lilac/signals/concept_labels_test.py deleted file mode 100644 index b155fbf5105bd2e8eb1306e9755071f6608df9b9..0000000000000000000000000000000000000000 --- a/lilac/signals/concept_labels_test.py +++ /dev/null @@ -1,128 +0,0 @@ -"""Test for the concept label signal.""" - -import os -import pathlib -from typing import Generator, Type - -import pytest -from pytest_mock import MockerFixture - -from ..concepts.concept import ExampleIn -from ..concepts.db_concept import ConceptDB, ConceptUpdate, DiskConceptDB, DiskConceptModelDB -from ..data.dataset_duckdb import DatasetDuckDB -from ..db_manager import set_default_dataset_cls -from ..schema import SignalInputType, lilac_span -from ..signal import clear_signal_registry -from .concept_labels import ConceptLabelsSignal - -ALL_CONCEPT_DBS = [DiskConceptDB] -ALL_CONCEPT_MODEL_DBS = [DiskConceptModelDB] - - -@pytest.fixture(autouse=True) -def set_data_path(tmp_path: pathlib.Path, mocker: MockerFixture) -> None: - mocker.patch.dict(os.environ, {'LILAC_DATA_PATH': str(tmp_path)}) - - -@pytest.fixture(scope='module', autouse=True) -def setup_teardown() -> Generator: - # Setup. - set_default_dataset_cls(DatasetDuckDB) - - # Unit test runs. - yield - - # Teardown. 
- clear_signal_registry() - - -def test_concept_does_not_exist() -> None: - signal = ConceptLabelsSignal(namespace='test', concept_name='concept_doesnt_exist') - with pytest.raises(ValueError, match='Concept "test/concept_doesnt_exist" does not exist'): - list(signal.compute(['a new data point', 'not in concept'])) - - -@pytest.mark.parametrize('concept_db_cls', ALL_CONCEPT_DBS) -def test_concept_labels(concept_db_cls: Type[ConceptDB]) -> None: - concept_db = concept_db_cls() - namespace = 'test' - concept_name = 'test_concept' - concept_db.create(namespace=namespace, name=concept_name, type=SignalInputType.TEXT) - - train_data = [ - ExampleIn(label=False, text='no in concept'), - ExampleIn(label=True, text='yes in concept'), - # This should never show since we request the main draft. - ExampleIn(label=False, text='this is unrelated', draft='test_draft') - ] - concept_db.edit(namespace, concept_name, ConceptUpdate(insert=train_data)) - - signal = ConceptLabelsSignal(namespace='test', concept_name='test_concept') - results = list( - signal.compute([ - 'this is no in concept', 'this is yes in concept', - 'this is no in concept. filler. this is yes in concept.', 'this is unrelated' - ])) - - assert results == [ - [lilac_span(8, 8 + len('no in concept'), {'label': False})], - [lilac_span(8, 8 + len('yes in concept'), {'label': True})], - [ - lilac_span(8, 8 + len('no in concept'), {'label': False}), - lilac_span(39, 39 + len('yes in concept'), {'label': True}) - ], - # This example is in the draft, which was not requested. - None - ] - - -@pytest.mark.parametrize('concept_db_cls', ALL_CONCEPT_DBS) -def test_concept_labels_draft(concept_db_cls: Type[ConceptDB]) -> None: - concept_db = concept_db_cls() - namespace = 'test' - concept_name = 'test_concept' - concept_db.create(namespace=namespace, name=concept_name, type=SignalInputType.TEXT) - - concept_update = ConceptUpdate(insert=[ - ExampleIn(label=True, text='in concept'), - ExampleIn(label=False, text='out of concept'), - ExampleIn(label=True, text='in draft', draft='test_draft'), - ExampleIn(label=False, text='out draft', draft='test_draft') - ]) - - concept_db.edit(namespace, concept_name, concept_update) - - signal = ConceptLabelsSignal(namespace='test', concept_name='test_concept', draft='test_draft') - results = list(signal.compute(['this is in concept', 'this is in draft', 'this is out draft'])) - - assert results == [[lilac_span(8, 8 + len('in concept'), {'label': True})], - [lilac_span(8, 8 + len('in draft'), { - 'label': True, - 'draft': 'test_draft' - })], - [lilac_span(8, 8 + len('out draft'), { - 'label': False, - 'draft': 'test_draft' - })]] - - -@pytest.mark.parametrize('concept_db_cls', ALL_CONCEPT_DBS) -def test_concept_labels_key(concept_db_cls: Type[ConceptDB]) -> None: - concept_db = concept_db_cls() - namespace = 'test' - concept_name = 'test_concept' - concept_db.create(namespace=namespace, name=concept_name, type=SignalInputType.TEXT) - - signal = ConceptLabelsSignal(namespace='test', concept_name='test_concept') - assert signal.key() == 'test/test_concept/labels' - - -@pytest.mark.parametrize('concept_db_cls', ALL_CONCEPT_DBS) -def test_concept_labels_compute_signal_key(concept_db_cls: Type[ConceptDB]) -> None: - concept_db = concept_db_cls() - namespace = 'test' - concept_name = 'test_concept' - concept_db.create(namespace=namespace, name=concept_name, type=SignalInputType.TEXT) - - signal = ConceptLabelsSignal(namespace='test', concept_name='test_concept') - assert signal.key(is_computed_signal=True) == 
'test/test_concept/labels/v0' diff --git a/lilac/signals/concept_scorer_test.py b/lilac/signals/concept_scorer_test.py deleted file mode 100644 index 8c16eef2e24dc456363d1ca025f265b8b2458ac3..0000000000000000000000000000000000000000 --- a/lilac/signals/concept_scorer_test.py +++ /dev/null @@ -1,271 +0,0 @@ -"""Test for the concept scorer.""" - -import os -import pathlib -from typing import Generator, Iterable, Type, cast - -import numpy as np -import pytest -from pytest_mock import MockerFixture -from typing_extensions import override - -from ..concepts.concept import ExampleIn -from ..concepts.db_concept import ( - ConceptDB, - ConceptModelDB, - ConceptUpdate, - DiskConceptDB, - DiskConceptModelDB, -) -from ..data.dataset_duckdb import DatasetDuckDB -from ..data.dataset_test_utils import make_vector_index -from ..db_manager import set_default_dataset_cls -from ..schema import Item, RichData, SignalInputType, lilac_embedding -from ..signal import TextEmbeddingSignal, clear_signal_registry, register_signal -from .concept_scorer import ConceptSignal - -ALL_CONCEPT_DBS = [DiskConceptDB] -ALL_CONCEPT_MODEL_DBS = [DiskConceptModelDB] -ALL_VECTOR_STORES = ['numpy', 'hnsw'] - - -@pytest.fixture(autouse=True) -def set_data_path(tmp_path: pathlib.Path, mocker: MockerFixture) -> None: - mocker.patch.dict(os.environ, {'LILAC_DATA_PATH': str(tmp_path)}) - - -EMBEDDING_MAP: dict[str, list[float]] = { - 'not in concept': [0.1, 0.9, 0.0], - 'in concept': [0.9, 0.1, 0.0], - 'a new data point': [0.1, 0.2, 0.3], - 'hello.': [0.1, 0.2, 0.3], - 'hello2.': [0.1, 0.2, 0.3], -} - - -class TestEmbedding(TextEmbeddingSignal): - """A test embed function.""" - name = 'test_embedding' - - @override - def compute(self, data: Iterable[RichData]) -> Iterable[Item]: - """Embed the examples, use a hashmap to the vector for simplicity.""" - for example in data: - if example not in EMBEDDING_MAP: - raise ValueError(f'Example "{str(example)}" not in embedding map') - yield [lilac_embedding(0, len(example), np.array(EMBEDDING_MAP[cast(str, example)]))] - - -@pytest.fixture(scope='module', autouse=True) -def setup_teardown() -> Generator: - # Setup. - set_default_dataset_cls(DatasetDuckDB) - register_signal(TestEmbedding) - - # Unit test runs. - yield - - # Teardown. 
- clear_signal_registry() - - -@pytest.mark.parametrize('db_cls', ALL_CONCEPT_DBS) -def test_embedding_does_not_exist(db_cls: Type[ConceptDB]) -> None: - db = db_cls() - namespace = 'test' - concept_name = 'test_concept' - db.create(namespace=namespace, name=concept_name, type=SignalInputType.TEXT) - - train_data = [ - ExampleIn(label=False, text='not in concept'), - ExampleIn(label=True, text='in concept') - ] - db.edit(namespace, concept_name, ConceptUpdate(insert=train_data)) - - signal = ConceptSignal( - namespace='test', concept_name='test_concept', embedding='unknown_embedding') - with pytest.raises(ValueError, match='Signal "unknown_embedding" not found in the registry'): - signal.compute(['a new data point']) - - -def test_concept_does_not_exist() -> None: - signal = ConceptSignal(namespace='test', concept_name='test_concept', embedding='test_embedding') - with pytest.raises(ValueError, match='Concept "test/test_concept" does not exist'): - signal.compute(['a new data point', 'not in concept']) - - -@pytest.mark.parametrize('concept_db_cls', ALL_CONCEPT_DBS) -@pytest.mark.parametrize('model_db_cls', ALL_CONCEPT_MODEL_DBS) -def test_concept_model_score(concept_db_cls: Type[ConceptDB], - model_db_cls: Type[ConceptModelDB]) -> None: - concept_db = concept_db_cls() - model_db = model_db_cls(concept_db) - namespace = 'test' - concept_name = 'test_concept' - concept_db.create(namespace=namespace, name=concept_name, type=SignalInputType.TEXT) - - train_data = [ - ExampleIn(label=False, text='not in concept'), - ExampleIn(label=True, text='in concept') - ] - concept_db.edit(namespace, concept_name, ConceptUpdate(insert=train_data)) - - signal = ConceptSignal(namespace='test', concept_name='test_concept', embedding='test_embedding') - - # Explicitly sync the model with the concept. - model_db.sync( - namespace='test', concept_name='test_concept', embedding_name='test_embedding', create=True) - - result_items = list(signal.compute(['a new data point', 'not in concept'])) - scores = [result_item[0]['score'] for result_item in result_items if result_item] - assert scores[0] > 0 and scores[0] < 1 - assert scores[1] < 0.5 - - -@pytest.mark.parametrize('concept_db_cls', ALL_CONCEPT_DBS) -@pytest.mark.parametrize('model_db_cls', ALL_CONCEPT_MODEL_DBS) -@pytest.mark.parametrize('vector_store', ALL_VECTOR_STORES) -def test_concept_model_vector_score(concept_db_cls: Type[ConceptDB], - model_db_cls: Type[ConceptModelDB], vector_store: str) -> None: - concept_db = concept_db_cls() - model_db = model_db_cls(concept_db) - namespace = 'test' - concept_name = 'test_concept' - concept_db.create(namespace=namespace, name=concept_name, type=SignalInputType.TEXT) - - train_data = [ - ExampleIn(label=False, text='not in concept'), - ExampleIn(label=True, text='in concept') - ] - concept_db.edit(namespace, concept_name, ConceptUpdate(insert=train_data)) - - signal = ConceptSignal(namespace='test', concept_name='test_concept', embedding='test_embedding') - - # Explicitly sync the model with the concept. - model_db.sync( - namespace='test', concept_name='test_concept', embedding_name='test_embedding', create=True) - - vector_index = make_vector_index( - vector_store, { - ('1',): [EMBEDDING_MAP['in concept']], - ('2',): [EMBEDDING_MAP['not in concept']], - ('3',): [EMBEDDING_MAP['a new data point']], - }) - - scores = cast(list[Item], list(signal.vector_compute([('1',), ('2',), ('3',)], vector_index))) - assert scores[0][0]['score'] > 0.5 # '1' is in the concept. 
- assert scores[1][0]['score'] < 0.5 # '2' is not in the concept. - assert scores[2][0]['score'] > 0 and scores[2][0][ - 'score'] < 1 # '3' may or may not be in the concept. - - -@pytest.mark.parametrize('concept_db_cls', ALL_CONCEPT_DBS) -@pytest.mark.parametrize('model_db_cls', ALL_CONCEPT_MODEL_DBS) -@pytest.mark.parametrize('vector_store', ALL_VECTOR_STORES) -def test_concept_model_topk_score(concept_db_cls: Type[ConceptDB], - model_db_cls: Type[ConceptModelDB], vector_store: str) -> None: - concept_db = concept_db_cls() - model_db = model_db_cls(concept_db) - namespace = 'test' - concept_name = 'test_concept' - concept_db.create(namespace=namespace, name=concept_name, type=SignalInputType.TEXT) - - train_data = [ - ExampleIn(label=False, text='not in concept'), - ExampleIn(label=True, text='in concept') - ] - concept_db.edit(namespace, concept_name, ConceptUpdate(insert=train_data)) - - signal = ConceptSignal(namespace='test', concept_name='test_concept', embedding='test_embedding') - - # Explicitly sync the model with the concept. - model_db.sync( - namespace='test', concept_name='test_concept', embedding_name='test_embedding', create=True) - vector_index = make_vector_index(vector_store, { - ('1',): [[0.1, 0.2, 0.3]], - ('2',): [[0.1, 0.87, 0.0]], - ('3',): [[1.0, 0.0, 0.0]], - }) - - # Compute topk without id restriction. - topk_result = signal.vector_compute_topk(3, vector_index) - expected_result = [('3',), ('1',), ('2',)] - for (id, _), expected_id in zip(topk_result, expected_result): - assert id == expected_id - - # Compute top 1. - topk_result = signal.vector_compute_topk(1, vector_index) - expected_result = [('3',)] - for (id, _), expected_id in zip(topk_result, expected_result): - assert id == expected_id - - # Compute topk with id restriction. - topk_result = signal.vector_compute_topk(3, vector_index, keys=[('1',), ('2',)]) - expected_result = [('1',), ('2',)] - for (id, _), expected_id in zip(topk_result, expected_result): - assert id == expected_id - - -@pytest.mark.parametrize('concept_db_cls', ALL_CONCEPT_DBS) -@pytest.mark.parametrize('model_db_cls', ALL_CONCEPT_MODEL_DBS) -@pytest.mark.parametrize('vector_store', ALL_VECTOR_STORES) -def test_concept_model_draft(concept_db_cls: Type[ConceptDB], model_db_cls: Type[ConceptModelDB], - vector_store: str) -> None: - concept_db = concept_db_cls() - model_db = model_db_cls(concept_db) - namespace = 'test' - concept_name = 'test_concept' - concept_db.create(namespace=namespace, name=concept_name, type=SignalInputType.TEXT) - - train_data = [ - ExampleIn(label=False, text='not in concept'), - ExampleIn(label=True, text='in concept'), - ExampleIn(label=False, text='a new data point', draft='test_draft'), - ] - concept_db.edit(namespace, concept_name, ConceptUpdate(insert=train_data)) - - signal = ConceptSignal(namespace='test', concept_name='test_concept', embedding='test_embedding') - draft_signal = ConceptSignal( - namespace='test', concept_name='test_concept', embedding='test_embedding', draft='test_draft') - - # Explicitly sync the model with the concept. 
- model_db.sync( - namespace='test', concept_name='test_concept', embedding_name='test_embedding', create=True) - - vector_index = make_vector_index(vector_store, { - ('1',): [[1.0, 0.0, 0.0]], - ('2',): [[0.9, 0.1, 0.0]], - ('3',): [[0.1, 0.9, 0.0]], - }) - - scores = cast(list[Item], list(signal.vector_compute([('1',), ('2',), ('3',)], vector_index))) - assert scores[0][0]['score'] > 0.5 - assert scores[1][0]['score'] > 0.5 - assert scores[2][0]['score'] < 0.5 - - # Make sure the draft signal works. It has different values than the original signal. - vector_index = make_vector_index(vector_store, { - ('1',): [[1.0, 0.0, 0.0]], - ('2',): [[0.9, 0.1, 0.0]], - ('3',): [[0.1, 0.2, 0.3]], - }) - draft_scores = draft_signal.vector_compute([('1',), ('2',), ('3',)], vector_index) - assert draft_scores != scores - - -def test_concept_score_key() -> None: - signal = ConceptSignal( - namespace='test', concept_name='test_concept', embedding=TestEmbedding.name) - assert signal.key() == 'test/test_concept/test_embedding' - - -@pytest.mark.parametrize('concept_db_cls', ALL_CONCEPT_DBS) -def test_concept_score_compute_signal_key(concept_db_cls: Type[ConceptDB]) -> None: - concept_db = concept_db_cls() - namespace = 'test' - concept_name = 'test_concept' - concept_db.create(namespace=namespace, name=concept_name, type=SignalInputType.TEXT) - - signal = ConceptSignal( - namespace='test', concept_name='test_concept', embedding=TestEmbedding.name) - assert signal.key(is_computed_signal=True) == 'test/test_concept/test_embedding/v0' diff --git a/lilac/signals/lang_detection_test.py b/lilac/signals/lang_detection_test.py deleted file mode 100644 index 592b7c43f42c4ff88280060b845469e6c193d163..0000000000000000000000000000000000000000 --- a/lilac/signals/lang_detection_test.py +++ /dev/null @@ -1,43 +0,0 @@ -"""Tests for the language detection signal.""" - -from pytest_mock import MockerFixture - -from ..schema import lilac_span -from . 
import lang_detection -from .lang_detection import LANG_CODE, LangDetectionSignal - - -def test_lang_detection_sentences(mocker: MockerFixture) -> None: - signal = LangDetectionSignal() - mocker.patch(f'{lang_detection.__name__}.TEXT_LEN_THRESHOLD', 1) - signal.setup() - docs = [ - 'War doesnt show whos right, just whos left.', - 'Ein, zwei, drei, vier', - ] - res = list(signal.compute(docs)) - assert res == ['en', 'de'] - - -def test_lang_detection_multiple_paragraphs(mocker: MockerFixture) -> None: - signal = LangDetectionSignal(split_by_paragraph=True) - mocker.patch(f'{lang_detection.__name__}.TEXT_LEN_THRESHOLD', 1) - signal.setup() - doc = 'War doesnt show whos right, just whos left.\n\nEin, zwei, drei, vier' - res = list(signal.compute([doc])) - assert res == [[ - lilac_span(0, 43, {LANG_CODE: 'en'}), - lilac_span(45, 66, {LANG_CODE: 'de'}), - ]] - - -def test_text_too_short(mocker: MockerFixture) -> None: - signal = LangDetectionSignal() - mocker.patch(f'{lang_detection.__name__}.TEXT_LEN_THRESHOLD', 25) - signal.setup() - docs = [ - 'War doesnt show whos right, just whos left.', - 'Ein, zwei, drei, vier', - ] - res = list(signal.compute(docs)) - assert res == ['en', 'TOO_SHORT'] diff --git a/lilac/signals/near_dup_test.py b/lilac/signals/near_dup_test.py deleted file mode 100644 index c8c0975e72b660361f3cef93945cf2849daa4146..0000000000000000000000000000000000000000 --- a/lilac/signals/near_dup_test.py +++ /dev/null @@ -1,18 +0,0 @@ -"""Test the Near duplicate signal.""" - -from .near_dup import CLUSTER_KEY, NearDuplicateSignal - - -def test_exact_duplicates() -> None: - signal = NearDuplicateSignal() - docs = ['Hello', 'Everyone', 'Hello', 'Hi'] - assert list(signal.compute(docs)) == [{CLUSTER_KEY: x} for x in [0, 1, 0, 3]] - - -def test_near_dups() -> None: - signal = NearDuplicateSignal() - docs = [ - 'Hello everyone. This is a test for near duplication with almost the same content', - 'Hello everyone. This is a test for near duplication with almost the same content [time]', - ] - assert list(signal.compute(docs)) == [{CLUSTER_KEY: x} for x in [0, 0]] diff --git a/lilac/signals/ner_test.py b/lilac/signals/ner_test.py deleted file mode 100644 index 62818696e9a22646ca42e8347b3fba093e3d5ce1..0000000000000000000000000000000000000000 --- a/lilac/signals/ner_test.py +++ /dev/null @@ -1,40 +0,0 @@ -"""Test the Spacy NER signal.""" - -from ..schema import field -from ..splitters.text_splitter_test_utils import text_to_expected_spans -from .ner import SpacyNER - - -def test_spacy_ner_fields() -> None: - signal = SpacyNER() - signal.setup() - assert signal.fields() == field(fields=[field('string_span', fields={'label': 'string'})]) - - -def test_ner() -> None: - signal = SpacyNER() - signal.setup() - - text = ('Net income was $9.4 million compared to the prior year of $2.7 million.' 
- 'Revenue exceeded twelve billion dollars, with a loss of $1b.') - emails = list(signal.compute([text])) - - expected_spans = text_to_expected_spans(text, [ - ('$9.4 million', { - 'label': 'MONEY' - }), - ('the prior year', { - 'label': 'DATE' - }), - ('$2.7 million', { - 'label': 'MONEY' - }), - ('twelve billion dollars', { - 'label': 'MONEY' - }), - ('1b', { - 'label': 'MONEY' - }), - ]) - - assert emails == [expected_spans] diff --git a/lilac/signals/pii_test.py b/lilac/signals/pii_test.py deleted file mode 100644 index 31a338f61a2f495aeb5b6ccad62b39def1d3dff0..0000000000000000000000000000000000000000 --- a/lilac/signals/pii_test.py +++ /dev/null @@ -1,55 +0,0 @@ -"""Test the PII signal.""" - -from ..schema import field -from ..splitters.text_splitter_test_utils import text_to_expected_spans -from .pii import EMAILS_KEY, IPS_KEY, SECRETS_KEY, PIISignal - - -def test_pii_fields() -> None: - signal = PIISignal() - assert signal.fields() == field(fields={ - EMAILS_KEY: ['string_span'], - IPS_KEY: ['string_span'], - SECRETS_KEY: ['string_span'] - }) - - -def test_pii_compute() -> None: - signal = PIISignal() - - text = 'This is an email nik@test.com. pii@gmail.com are where emails are read.' - emails = list(signal.compute([text])) - - expected_spans = text_to_expected_spans(text, ['nik@test.com', 'pii@gmail.com']) - - assert emails == [{EMAILS_KEY: expected_spans, IPS_KEY: [], SECRETS_KEY: []}] - - -def test_pii_case_insensitive() -> None: - signal = PIISignal() - - text = 'These are some emails: NIK@Test.com. pII@gmAIL.COM are where emails are read.' - emails = list(signal.compute([text])) - - expected_spans = text_to_expected_spans(text, ['NIK@Test.com', 'pII@gmAIL.COM']) - - assert emails == [{EMAILS_KEY: expected_spans, IPS_KEY: [], SECRETS_KEY: []}] - - -def test_ip_addresses() -> None: - signal = PIISignal() - - text = 'These are some ip addresses: 192.158.1.38 and 2001:db8:3333:4444:5555:6666:7777:8888' - pii = list(signal.compute([text])) - expected_spans = text_to_expected_spans( - text, ['192.158.1.38', '2001:db8:3333:4444:5555:6666:7777:8888']) - assert pii == [{EMAILS_KEY: [], IPS_KEY: expected_spans, SECRETS_KEY: []}] - - -def test_secrets() -> None: - signal = PIISignal() - - text = 'These are some secrets: AKIATESTTESTTESTTEST' - pii = list(signal.compute([text])) - expected_spans = text_to_expected_spans(text, ['AKIATESTTESTTESTTEST']) - assert pii == [{EMAILS_KEY: [], IPS_KEY: [], SECRETS_KEY: expected_spans}] diff --git a/lilac/signals/semantic_similarity_test.py b/lilac/signals/semantic_similarity_test.py deleted file mode 100644 index 20083302a35704c7cb599c02341f0e06ea630c53..0000000000000000000000000000000000000000 --- a/lilac/signals/semantic_similarity_test.py +++ /dev/null @@ -1,113 +0,0 @@ -"""Test the semantic search signal.""" - -from typing import Iterable, Optional, cast - -import numpy as np -import pytest -from pytest_mock import MockerFixture -from typing_extensions import override - -from ..data.dataset_test_utils import make_vector_index -from ..embeddings.vector_store import VectorStore, register_vector_store -from ..schema import Item, RichData, VectorKey, lilac_embedding, lilac_span -from ..signal import TextEmbeddingSignal, clear_signal_registry, register_signal -from .semantic_similarity import SemanticSimilaritySignal - -EMBEDDINGS: dict[VectorKey, list[list[float]]] = { - ('1',): [[1.0, 0.0, 0.0]], - ('2',): [[0.9, 0.1, 0.0]], - ('3',): [[0.0, 0.0, 1.0]] -} - -STR_EMBEDDINGS: dict[str, list[float]] = { - 'hello': [1.0, 0.0, 0.0], - 'hello world': 
[0.9, 0.1, 0.0], - 'far': [0.0, 0.0, 1.0] -} - - -class TestVectorStore(VectorStore): - """A test vector store with fixed embeddings.""" - - name = 'test_vector_store' - - @override - def size(self) -> int: - return len(EMBEDDINGS) - - @override - def load(self, base_path: str) -> None: - raise NotImplementedError - - @override - def save(self, base_path: str) -> None: - raise NotImplementedError - - @override - def add(self, keys: list[VectorKey], embeddings: np.ndarray) -> None: - # We fix the vectors for the test vector store. - pass - - @override - def get(self, keys: Optional[Iterable[VectorKey]] = None) -> np.ndarray: - keys = keys or [] - return np.array([EMBEDDINGS[tuple(path_key)][cast(int, index)] for *path_key, index in keys]) - - -class TestEmbedding(TextEmbeddingSignal): - """A test embed function.""" - name = 'test_embedding' - - @override - def compute(self, data: Iterable[RichData]) -> Iterable[Item]: - """Embed the examples, use a hashmap to the vector for simplicity.""" - for example in data: - yield [lilac_embedding(0, len(example), np.array(STR_EMBEDDINGS[cast(str, example)]))] - - -@pytest.fixture(scope='module', autouse=True) -def setup_teardown() -> Iterable[None]: - # Setup. - register_signal(TestEmbedding) - register_vector_store(TestVectorStore) - - # Unit test runs. - yield - - # Teardown. - clear_signal_registry() - - -def test_semantic_similarity_compute_keys(mocker: MockerFixture) -> None: - vector_index = make_vector_index('test_vector_store', EMBEDDINGS) - - embed_mock = mocker.spy(TestEmbedding, 'compute') - - signal = SemanticSimilaritySignal(query='hello', embedding=TestEmbedding.name) - scores = list(signal.vector_compute([('1',), ('2',), ('3',)], vector_index)) - - # Embeddings should be called only 1 time for the search. - assert embed_mock.call_count == 1 - - assert scores == [ - [lilac_span(0, 0, {'score': 1})], - [lilac_span(0, 0, {'score': 0.9})], - [lilac_span(0, 0, {'score': 0})], - ] - - -def test_semantic_similarity_compute_data(mocker: MockerFixture) -> None: - embed_mock = mocker.spy(TestEmbedding, 'compute') - - signal = SemanticSimilaritySignal(query='hello', embedding=TestEmbedding.name) - # Compute over the text. - scores = list(signal.compute(STR_EMBEDDINGS.keys())) - - # Embeddings should be called only 2 times, once for the search, once for the query itself. 
- assert embed_mock.call_count == 2 - - assert scores == [ - [lilac_span(0, 5, {'score': 1})], - [lilac_span(0, 11, {'score': 0.9})], - [lilac_span(0, 3, {'score': 0})], - ] diff --git a/lilac/signals/substring_search_test.py b/lilac/signals/substring_search_test.py deleted file mode 100644 index 6c12eb98d84e698f2eafef81a4c42805a9497c97..0000000000000000000000000000000000000000 --- a/lilac/signals/substring_search_test.py +++ /dev/null @@ -1,28 +0,0 @@ -"""Test the Substring Search signal.""" - -import pytest -from pydantic import ValidationError - -from ..schema import field -from ..splitters.text_splitter_test_utils import text_to_expected_spans -from .substring_search import SubstringSignal - - -def test_substring_fields() -> None: - signal = SubstringSignal(query='test') - assert signal.fields() == field(fields=['string_span']) - - -def test_query_is_required() -> None: - with pytest.raises(ValidationError): - SubstringSignal() - - -def test_compute() -> None: - signal = SubstringSignal(query='test') - - text = 'The word TEST shows up 3 times, teST and test' - spans = list(signal.compute([text])) - - expected_spans = text_to_expected_spans(text, ['TEST', 'teST', 'test']) - assert [expected_spans] == spans diff --git a/lilac/signals/text_statistics_test.py b/lilac/signals/text_statistics_test.py deleted file mode 100644 index 1ac7bacd18ab4b80044aca7bb23bf2d19c8d8590..0000000000000000000000000000000000000000 --- a/lilac/signals/text_statistics_test.py +++ /dev/null @@ -1,65 +0,0 @@ -"""Test the semantic search signal.""" - -from typing import cast - -from pytest import approx - -from ..schema import DataType, Field -from .text_statistics import ( - FRAC_NON_ASCII, - NUM_CHARS, - READABILITY, - TYPE_TOKEN_RATIO, - TextStatisticsSignal, -) - - -def test_text_statistics_fields() -> None: - signal = TextStatisticsSignal() - signal.setup() - assert signal.fields() == Field( - fields={ - NUM_CHARS: Field(dtype=DataType.INT32), - READABILITY: Field(dtype=DataType.FLOAT32), - TYPE_TOKEN_RATIO: Field(dtype=DataType.FLOAT32), - FRAC_NON_ASCII: Field( - dtype=DataType.FLOAT32, - bins=[('Low', None, 0.15), ('Medium', 0.15, 0.3), ('High', 0.3, None)]), - }) - - -def test_text_statistics_compute() -> None: - signal = TextStatisticsSignal() - signal.setup() - - scores = signal.compute(['hello', 'hello world']) - assert list(scores) == [{ - NUM_CHARS: 5, - READABILITY: approx(2.62), - TYPE_TOKEN_RATIO: 0.0, - FRAC_NON_ASCII: 0.0 - }, { - NUM_CHARS: 11, - READABILITY: approx(3.12), - TYPE_TOKEN_RATIO: 1.0, - FRAC_NON_ASCII: 0.0 - }] - - -def test_text_statistics_missing_value() -> None: - signal = TextStatisticsSignal() - signal.setup() - - scores = signal.compute(['hello', cast(str, None), 'everybody']) - - assert list(scores) == [{ - NUM_CHARS: 5, - READABILITY: approx(2.62), - TYPE_TOKEN_RATIO: 0.0, - FRAC_NON_ASCII: 0.0 - }, None, { - NUM_CHARS: 9, - READABILITY: approx(21.46), - TYPE_TOKEN_RATIO: 0.0, - FRAC_NON_ASCII: 0.0 - }] diff --git a/lilac/sources/__pycache__/__init__.cpython-39.pyc b/lilac/sources/__pycache__/__init__.cpython-39.pyc deleted file mode 100644 index d6cb3705e7e73299a76ebb83e344d21820a7e5b7..0000000000000000000000000000000000000000 Binary files a/lilac/sources/__pycache__/__init__.cpython-39.pyc and /dev/null differ diff --git a/lilac/sources/__pycache__/csv_source.cpython-39.pyc b/lilac/sources/__pycache__/csv_source.cpython-39.pyc deleted file mode 100644 index 2bbb5865b4ef556ac199404ce224c4e3fa43f816..0000000000000000000000000000000000000000 Binary files 
a/lilac/sources/__pycache__/csv_source.cpython-39.pyc and /dev/null differ diff --git a/lilac/sources/__pycache__/csv_source_test.cpython-39-pytest-7.4.0.pyc b/lilac/sources/__pycache__/csv_source_test.cpython-39-pytest-7.4.0.pyc deleted file mode 100644 index 84b8077e32fb1af83a16af284a904ab096e7984b..0000000000000000000000000000000000000000 Binary files a/lilac/sources/__pycache__/csv_source_test.cpython-39-pytest-7.4.0.pyc and /dev/null differ diff --git a/lilac/sources/__pycache__/default_sources.cpython-39.pyc b/lilac/sources/__pycache__/default_sources.cpython-39.pyc deleted file mode 100644 index b23b38d9a219c112a8518e9baf7641aec28bdc66..0000000000000000000000000000000000000000 Binary files a/lilac/sources/__pycache__/default_sources.cpython-39.pyc and /dev/null differ diff --git a/lilac/sources/__pycache__/duckdb_utils.cpython-39.pyc b/lilac/sources/__pycache__/duckdb_utils.cpython-39.pyc deleted file mode 100644 index 0a63b08fc53d67063140dedd5f3a2914a7e2e748..0000000000000000000000000000000000000000 Binary files a/lilac/sources/__pycache__/duckdb_utils.cpython-39.pyc and /dev/null differ diff --git a/lilac/sources/__pycache__/gmail_source.cpython-39.pyc b/lilac/sources/__pycache__/gmail_source.cpython-39.pyc deleted file mode 100644 index 1a4bf61d90159d101e2acd734d10bc080bdd8d45..0000000000000000000000000000000000000000 Binary files a/lilac/sources/__pycache__/gmail_source.cpython-39.pyc and /dev/null differ diff --git a/lilac/sources/__pycache__/huggingface_source.cpython-39.pyc b/lilac/sources/__pycache__/huggingface_source.cpython-39.pyc deleted file mode 100644 index 3a5f2d0ecafdaf3f37752f6ad04ced393d964c4e..0000000000000000000000000000000000000000 Binary files a/lilac/sources/__pycache__/huggingface_source.cpython-39.pyc and /dev/null differ diff --git a/lilac/sources/__pycache__/huggingface_source_test.cpython-39-pytest-7.4.0.pyc b/lilac/sources/__pycache__/huggingface_source_test.cpython-39-pytest-7.4.0.pyc deleted file mode 100644 index dd51dd8c0509ff6fccbd60c309b0253ee1a94734..0000000000000000000000000000000000000000 Binary files a/lilac/sources/__pycache__/huggingface_source_test.cpython-39-pytest-7.4.0.pyc and /dev/null differ diff --git a/lilac/sources/__pycache__/json_source.cpython-39.pyc b/lilac/sources/__pycache__/json_source.cpython-39.pyc deleted file mode 100644 index 16e13cd875521852cc3feb234857fea3a3055abf..0000000000000000000000000000000000000000 Binary files a/lilac/sources/__pycache__/json_source.cpython-39.pyc and /dev/null differ diff --git a/lilac/sources/__pycache__/json_source_test.cpython-39-pytest-7.4.0.pyc b/lilac/sources/__pycache__/json_source_test.cpython-39-pytest-7.4.0.pyc deleted file mode 100644 index 554d6044b2d9d71d8ad9bf76475b8f34899b33cb..0000000000000000000000000000000000000000 Binary files a/lilac/sources/__pycache__/json_source_test.cpython-39-pytest-7.4.0.pyc and /dev/null differ diff --git a/lilac/sources/__pycache__/pandas_source.cpython-39.pyc b/lilac/sources/__pycache__/pandas_source.cpython-39.pyc deleted file mode 100644 index faa3cdc9c4187852ae996e16c257bf34e0c8c05f..0000000000000000000000000000000000000000 Binary files a/lilac/sources/__pycache__/pandas_source.cpython-39.pyc and /dev/null differ diff --git a/lilac/sources/__pycache__/pandas_source_test.cpython-39-pytest-7.4.0.pyc b/lilac/sources/__pycache__/pandas_source_test.cpython-39-pytest-7.4.0.pyc deleted file mode 100644 index 208cfeebe366e97edc4c6d049f307ea41311f190..0000000000000000000000000000000000000000 Binary files 
a/lilac/sources/__pycache__/pandas_source_test.cpython-39-pytest-7.4.0.pyc and /dev/null differ diff --git a/lilac/sources/__pycache__/parquet_source.cpython-39.pyc b/lilac/sources/__pycache__/parquet_source.cpython-39.pyc deleted file mode 100644 index 23942acba104c87c004fb86fac7dfc980581f5b9..0000000000000000000000000000000000000000 Binary files a/lilac/sources/__pycache__/parquet_source.cpython-39.pyc and /dev/null differ diff --git a/lilac/sources/__pycache__/parquet_source_test.cpython-39-pytest-7.4.0.pyc b/lilac/sources/__pycache__/parquet_source_test.cpython-39-pytest-7.4.0.pyc deleted file mode 100644 index 89779ce58f2134fae8f8ff9103da7f93a408bddc..0000000000000000000000000000000000000000 Binary files a/lilac/sources/__pycache__/parquet_source_test.cpython-39-pytest-7.4.0.pyc and /dev/null differ diff --git a/lilac/sources/__pycache__/source.cpython-39.pyc b/lilac/sources/__pycache__/source.cpython-39.pyc deleted file mode 100644 index d04a9069e2d189264492e8a9edab5843eece3312..0000000000000000000000000000000000000000 Binary files a/lilac/sources/__pycache__/source.cpython-39.pyc and /dev/null differ diff --git a/lilac/sources/__pycache__/source_registry.cpython-39.pyc b/lilac/sources/__pycache__/source_registry.cpython-39.pyc deleted file mode 100644 index 2c4418b550ce978f09c90ac246d0917f1aadde12..0000000000000000000000000000000000000000 Binary files a/lilac/sources/__pycache__/source_registry.cpython-39.pyc and /dev/null differ diff --git a/lilac/sources/__pycache__/source_registry_test.cpython-39-pytest-7.4.0.pyc b/lilac/sources/__pycache__/source_registry_test.cpython-39-pytest-7.4.0.pyc deleted file mode 100644 index b9b7288f53bdfa6a64b790f80b0ac72c678d8436..0000000000000000000000000000000000000000 Binary files a/lilac/sources/__pycache__/source_registry_test.cpython-39-pytest-7.4.0.pyc and /dev/null differ diff --git a/lilac/sources/csv_source_test.py b/lilac/sources/csv_source_test.py deleted file mode 100644 index 6f94dbc8c8f472166eee3ebc8be86bfd6903eeb3..0000000000000000000000000000000000000000 --- a/lilac/sources/csv_source_test.py +++ /dev/null @@ -1,42 +0,0 @@ -"""Tests for the CSV source.""" -import csv -import os -import pathlib - -from ..schema import schema -from .csv_source import LINE_NUMBER_COLUMN, CSVSource -from .source import SourceSchema - - -def test_csv(tmp_path: pathlib.Path) -> None: - csv_rows = [{'x': 1, 'y': 'ten'}, {'x': 2, 'y': 'twenty'}] - - filename = 'test-dataset.csv' - filepath = os.path.join(tmp_path, filename) - with open(filepath, 'w') as f: - writer = csv.DictWriter(f, fieldnames=list(csv_rows[0].keys())) - writer.writeheader() - writer.writerows(csv_rows) - - source = CSVSource(filepaths=[filepath]) - source.setup() - - source_schema = source.source_schema() - assert source_schema == SourceSchema( - fields=schema({ - LINE_NUMBER_COLUMN: 'int64', - 'x': 'int64', - 'y': 'string' - }).fields, num_items=2) - - items = list(source.process()) - - assert items == [{ - LINE_NUMBER_COLUMN: 0, - 'x': 1, - 'y': 'ten' - }, { - LINE_NUMBER_COLUMN: 1, - 'x': 2, - 'y': 'twenty' - }] diff --git a/lilac/sources/huggingface_source_test.py b/lilac/sources/huggingface_source_test.py deleted file mode 100644 index bac1d75d13464f3e838370495c5cf641119b1d23..0000000000000000000000000000000000000000 --- a/lilac/sources/huggingface_source_test.py +++ /dev/null @@ -1,170 +0,0 @@ -"""Tests for the pandas source.""" -import os -import pathlib - -# mypy: disable-error-code="attr-defined" -from datasets import Dataset, Features, Sequence, Value - -from ..schema import 
schema -from .huggingface_source import HF_SPLIT_COLUMN, HuggingFaceSource -from .source import SourceSchema - - -def test_hf(tmp_path: pathlib.Path) -> None: - dataset = Dataset.from_list([{'x': 1, 'y': 'ten'}, {'x': 2, 'y': 'twenty'}]) - - dataset_name = os.path.join(tmp_path, 'hf-test-dataset') - dataset.save_to_disk(dataset_name) - - source = HuggingFaceSource(dataset_name=dataset_name, load_from_disk=True) - - items = source.process() - source.setup() - - source_schema = source.source_schema() - assert source_schema == SourceSchema( - fields=schema({ - HF_SPLIT_COLUMN: 'string', - 'x': 'int64', - 'y': 'string' - }).fields, num_items=2) - - items = list(source.process()) - - assert items == [{ - HF_SPLIT_COLUMN: 'default', - 'x': 1, - 'y': 'ten' - }, { - HF_SPLIT_COLUMN: 'default', - 'x': 2, - 'y': 'twenty' - }] - - -def test_hf_sequence(tmp_path: pathlib.Path) -> None: - dataset = Dataset.from_list([{ - 'scalar': 1, - 'seq': [1, 0], - 'seq_dict': { - 'x': [1, 2, 3], - 'y': ['four', 'five', 'six'] - } - }, { - 'scalar': 2, - 'seq': [2, 0], - 'seq_dict': { - 'x': [10, 20, 30], - 'y': ['forty', 'fifty', 'sixty'] - } - }], - features=Features({ - 'scalar': Value(dtype='int64'), - 'seq': Sequence(feature=Value(dtype='int64')), - 'seq_dict': Sequence(feature={ - 'x': Value(dtype='int64'), - 'y': Value(dtype='string') - }) - })) - - dataset_name = os.path.join(tmp_path, 'hf-test-dataset') - dataset.save_to_disk(dataset_name) - - source = HuggingFaceSource(dataset_name=dataset_name, load_from_disk=True) - - items = source.process() - source.setup() - - source_schema = source.source_schema() - assert source_schema == SourceSchema( - fields=schema({ - HF_SPLIT_COLUMN: 'string', - 'scalar': 'int64', - 'seq': ['int64'], - 'seq_dict': { - 'x': ['int64'], - 'y': ['string'], - }, - }).fields, - num_items=2) - - items = list(source.process()) - - assert items == [{ - HF_SPLIT_COLUMN: 'default', - 'scalar': 1, - 'seq': [1, 0], - 'seq_dict': { - 'x': [1, 2, 3], - 'y': ['four', 'five', 'six'] - } - }, { - HF_SPLIT_COLUMN: 'default', - 'scalar': 2, - 'seq': [2, 0], - 'seq_dict': { - 'x': [10, 20, 30], - 'y': ['forty', 'fifty', 'sixty'] - } - }] - - -def test_hf_list(tmp_path: pathlib.Path) -> None: - dataset = Dataset.from_list([{ - 'scalar': 1, - 'list': [{ - 'x': 1, - 'y': 'two' - }] - }, { - 'scalar': 2, - 'list': [{ - 'x': 3, - 'y': 'four' - }] - }], - features=Features({ - 'scalar': Value(dtype='int64'), - 'list': [{ - 'x': Value(dtype='int64'), - 'y': Value(dtype='string') - }] - })) - - dataset_name = os.path.join(tmp_path, 'hf-test-dataset') - dataset.save_to_disk(dataset_name) - - source = HuggingFaceSource(dataset_name=dataset_name, load_from_disk=True) - - items = source.process() - source.setup() - - source_schema = source.source_schema() - assert source_schema == SourceSchema( - fields=schema({ - HF_SPLIT_COLUMN: 'string', - 'scalar': 'int64', - 'list': [{ - 'x': 'int64', - 'y': 'string', - }], - }).fields, - num_items=2) - - items = list(source.process()) - - assert items == [{ - HF_SPLIT_COLUMN: 'default', - 'scalar': 1, - 'list': [{ - 'x': 1, - 'y': 'two' - }] - }, { - HF_SPLIT_COLUMN: 'default', - 'scalar': 2, - 'list': [{ - 'x': 3, - 'y': 'four' - }] - }] diff --git a/lilac/sources/json_source_test.py b/lilac/sources/json_source_test.py deleted file mode 100644 index e853117fc7114f2a73617dce2badc8c52c4118fd..0000000000000000000000000000000000000000 --- a/lilac/sources/json_source_test.py +++ /dev/null @@ -1,74 +0,0 @@ -"""Tests for the JSON source.""" -import json -import os -import 
pathlib - -from ..schema import schema -from .json_source import ROW_ID_COLUMN, JSONSource -from .source import SourceSchema - - -def test_simple_json(tmp_path: pathlib.Path) -> None: - json_records = [{'x': 1, 'y': 'ten'}, {'x': 2, 'y': 'twenty'}] - - filename = 'test-dataset.jsonl' - filepath = os.path.join(tmp_path, filename) - with open(filepath, 'w') as f: - f.write(json.dumps(json_records)) - - source = JSONSource(filepaths=[filepath]) - source.setup() - - source_schema = source.source_schema() - assert source_schema == SourceSchema( - fields=schema({ - ROW_ID_COLUMN: 'int64', - 'x': 'int64', - 'y': 'string' - }).fields, num_items=2) - - items = list(source.process()) - - assert items == [{ - ROW_ID_COLUMN: 0, - 'x': 1, - 'y': 'ten' - }, { - ROW_ID_COLUMN: 1, - 'x': 2, - 'y': 'twenty' - }] - - -def test_simple_jsonl(tmp_path: pathlib.Path) -> None: - json_records = [{'x': 1, 'y': 'ten'}, {'x': 2, 'y': 'twenty'}] - json_lines = [json.dumps(record) + '\n' for record in json_records] - - filename = 'test-dataset.jsonl' - filepath = os.path.join(tmp_path, filename) - with open(filepath, 'w') as f: - f.writelines(json_lines) - - source = JSONSource(filepaths=[filepath]) - source.setup() - - source_schema = source.source_schema() - - assert source_schema == SourceSchema( - fields=schema({ - ROW_ID_COLUMN: 'int64', - 'x': 'int64', - 'y': 'string' - }).fields, num_items=2) - - items = list(source.process()) - - assert items == [{ - ROW_ID_COLUMN: 0, - 'x': 1, - 'y': 'ten' - }, { - ROW_ID_COLUMN: 1, - 'x': 2, - 'y': 'twenty' - }] diff --git a/lilac/sources/pandas_source_test.py b/lilac/sources/pandas_source_test.py deleted file mode 100644 index 1694c3399b9cc2bd46ec163c6467657a04a671a8..0000000000000000000000000000000000000000 --- a/lilac/sources/pandas_source_test.py +++ /dev/null @@ -1,91 +0,0 @@ -"""Tests for the pandas source.""" - -import pandas as pd - -from ..schema import schema -from .pandas_source import PANDAS_INDEX_COLUMN, PandasSource -from .source import SourceSchema - - -def test_simple_dataframe() -> None: - df = pd.DataFrame.from_records([{ - 'name': 'a', - 'age': 1 - }, { - 'name': 'b', - 'age': 2 - }, { - 'name': 'c', - 'age': 3 - }]) - - source = PandasSource(df) - source.setup() - - source_schema = source.source_schema() - assert source_schema == SourceSchema( - fields=schema({ - PANDAS_INDEX_COLUMN: 'int64', - 'name': 'string', - 'age': 'int64' - }).fields, - num_items=3) - - items = list(source.process()) - - assert items == [{ - PANDAS_INDEX_COLUMN: 0, - 'name': 'a', - 'age': 1 - }, { - PANDAS_INDEX_COLUMN: 1, - 'name': 'b', - 'age': 2 - }, { - PANDAS_INDEX_COLUMN: 2, - 'name': 'c', - 'age': 3 - }] - - -def test_simple_dataframe_with_index() -> None: - df = pd.DataFrame.from_records([{ - 'name': 'a', - 'age': 1 - }, { - 'name': 'b', - 'age': 2 - }, { - 'name': 'c', - 'age': 3 - }], - index=['id1', 'id2', 'id3']) - - source = PandasSource(df) - source.setup() - - source_schema = source.source_schema() - assert source_schema == SourceSchema( - fields=schema({ - PANDAS_INDEX_COLUMN: 'string', - 'name': 'string', - 'age': 'int64' - }).fields, - num_items=3) - - items = list(source.process()) - - # The PANDAS_INDEX_COLUMN aligns with the pandas index. 
- assert items == [{ - PANDAS_INDEX_COLUMN: 'id1', - 'name': 'a', - 'age': 1 - }, { - PANDAS_INDEX_COLUMN: 'id2', - 'name': 'b', - 'age': 2 - }, { - PANDAS_INDEX_COLUMN: 'id3', - 'name': 'c', - 'age': 3 - }] diff --git a/lilac/sources/parquet_source_test.py b/lilac/sources/parquet_source_test.py deleted file mode 100644 index cd1093121eb85f2d22409ef9b200100fd215f587..0000000000000000000000000000000000000000 --- a/lilac/sources/parquet_source_test.py +++ /dev/null @@ -1,39 +0,0 @@ -"""Tests for the parquet source.""" - -import os -import pathlib - -import pyarrow as pa -import pyarrow.parquet as pq - -from ..schema import schema -from .parquet_source import ParquetSource -from .source import SourceSchema - - -def test_simple_rows(tmp_path: pathlib.Path) -> None: - table = pa.Table.from_pylist([{ - 'name': 'a', - 'age': 1 - }, { - 'name': 'b', - 'age': 2 - }, { - 'name': 'c', - 'age': 3 - }]) - - out_file = os.path.join(tmp_path, 'test.parquet') - pq.write_table(table, out_file) - - source = ParquetSource(filepaths=[out_file]) - source.setup() - source_schema = source.source_schema() - assert source_schema == SourceSchema( - fields=schema({ - 'name': 'string', - 'age': 'int64' - }).fields, num_items=3) - - items = list(source.process()) - assert items == [{'name': 'a', 'age': 1}, {'name': 'b', 'age': 2}, {'name': 'c', 'age': 3}] diff --git a/lilac/sources/source_registry_test.py b/lilac/sources/source_registry_test.py deleted file mode 100644 index 4ef78ea181c8df2b90e285c3ed368f625f16dedf..0000000000000000000000000000000000000000 --- a/lilac/sources/source_registry_test.py +++ /dev/null @@ -1,55 +0,0 @@ -"""Tests for the source registry.""" -from typing import Iterable, cast - -import pytest -from typing_extensions import override - -from ..schema import Item -from .source import Source, SourceSchema -from .source_registry import clear_source_registry, get_source_cls, register_source, resolve_source - - -class TestSource(Source): - """A test source.""" - name = 'test_source' - - @override - def setup(self) -> None: - pass - - @override - def source_schema(self) -> SourceSchema: - """Return the source schema.""" - return cast(SourceSchema, None) - - @override - def process(self) -> Iterable[Item]: - yield None - - -@pytest.fixture(scope='module', autouse=True) -def setup_teardown() -> Iterable[None]: - # Setup. - register_source(TestSource) - - # Unit test runs. - yield - - # Teardown. - clear_source_registry() - - -def test_get_source_cls() -> None: - """Test getting a source.""" - assert TestSource == get_source_cls('test_source') - - -def test_resolve_source() -> None: - """Test resolving a source.""" - test_source = TestSource() - - # Sources pass through. - assert resolve_source(test_source) == test_source - - # Dicts resolve to the base class. 
- assert resolve_source(test_source.dict()) == test_source diff --git a/lilac/splitters/__pycache__/__init__.cpython-39.pyc b/lilac/splitters/__pycache__/__init__.cpython-39.pyc deleted file mode 100644 index a4d606665fa178317670c0463fd3f27d5f04ab10..0000000000000000000000000000000000000000 Binary files a/lilac/splitters/__pycache__/__init__.cpython-39.pyc and /dev/null differ diff --git a/lilac/splitters/__pycache__/chunk_splitter.cpython-39.pyc b/lilac/splitters/__pycache__/chunk_splitter.cpython-39.pyc deleted file mode 100644 index 1d273ca2da3ec0eb0a9c30e6b94e7e779493db78..0000000000000000000000000000000000000000 Binary files a/lilac/splitters/__pycache__/chunk_splitter.cpython-39.pyc and /dev/null differ diff --git a/lilac/splitters/__pycache__/chunk_splitter_test.cpython-39-pytest-7.4.0.pyc b/lilac/splitters/__pycache__/chunk_splitter_test.cpython-39-pytest-7.4.0.pyc deleted file mode 100644 index eb9c5926cf69b73666b46a77e52fa0395a164eec..0000000000000000000000000000000000000000 Binary files a/lilac/splitters/__pycache__/chunk_splitter_test.cpython-39-pytest-7.4.0.pyc and /dev/null differ diff --git a/lilac/splitters/__pycache__/text_splitter_spacy.cpython-39.pyc b/lilac/splitters/__pycache__/text_splitter_spacy.cpython-39.pyc deleted file mode 100644 index 745d7747462b0ee36e461db1c4f72ed5205bdba7..0000000000000000000000000000000000000000 Binary files a/lilac/splitters/__pycache__/text_splitter_spacy.cpython-39.pyc and /dev/null differ diff --git a/lilac/splitters/__pycache__/text_splitter_spacy_test.cpython-39-pytest-7.4.0.pyc b/lilac/splitters/__pycache__/text_splitter_spacy_test.cpython-39-pytest-7.4.0.pyc deleted file mode 100644 index 7e1d47257a047119ac29e4dae1863f3f3225c5ec..0000000000000000000000000000000000000000 Binary files a/lilac/splitters/__pycache__/text_splitter_spacy_test.cpython-39-pytest-7.4.0.pyc and /dev/null differ diff --git a/lilac/splitters/__pycache__/text_splitter_test_utils.cpython-39.pyc b/lilac/splitters/__pycache__/text_splitter_test_utils.cpython-39.pyc deleted file mode 100644 index 94733f5cee642e7326e543a161fd066cd5cabd3d..0000000000000000000000000000000000000000 Binary files a/lilac/splitters/__pycache__/text_splitter_test_utils.cpython-39.pyc and /dev/null differ diff --git a/lilac/splitters/__pycache__/text_splitter_test_utils_test.cpython-39-pytest-7.4.0.pyc b/lilac/splitters/__pycache__/text_splitter_test_utils_test.cpython-39-pytest-7.4.0.pyc deleted file mode 100644 index acde55f865da6473656d7c202a94e5932c70560a..0000000000000000000000000000000000000000 Binary files a/lilac/splitters/__pycache__/text_splitter_test_utils_test.cpython-39-pytest-7.4.0.pyc and /dev/null differ diff --git a/lilac/splitters/chunk_splitter_test.py b/lilac/splitters/chunk_splitter_test.py deleted file mode 100644 index 4032e84eb8375e31452721699e751255186790a0..0000000000000000000000000000000000000000 --- a/lilac/splitters/chunk_splitter_test.py +++ /dev/null @@ -1,73 +0,0 @@ -"""Tests the chunk splitter.""" - -from .chunk_splitter import ChunkSplitter -from .text_splitter_test_utils import spans_to_text, text_to_expected_spans - - -def test_paragraphs_no_overlap() -> None: - signal = ChunkSplitter(chunk_size=12, chunk_overlap=0) - text = 'Hello.\n\nThis will get split.\n\nThe sentence\n\nA.\n\nB.\n\nC.' - split_items = list(signal.compute([text])) - - # "This will get split" should split in 2 chunks, and "A.\n\nB.\n\nC." should be 1 chunk. 
- expected_spans = text_to_expected_spans( - text, ['Hello.', 'This will', 'get split.', 'The sentence', 'A.\n\nB.\n\nC.']) - assert split_items == [expected_spans] - - -def test_single_word_is_too_long_no_overlap() -> None: - signal = ChunkSplitter(chunk_size=6, chunk_overlap=0) - text = 'ThisIsASingleWordThatIsTooLong' - split_items = list(signal.compute([text])) - - expected_spans = text_to_expected_spans(text, ['ThisIs', 'ASingl', 'eWordT', 'hatIsT', 'ooLong']) - assert split_items == [expected_spans] - - -def test_newlines_with_overlap() -> None: - signal = ChunkSplitter(chunk_size=12, chunk_overlap=5) - text = 'Hello.\n\nWorld.\n\nThis will get split.' - spans = list(signal.compute([text]))[0] - - expected_chunks = ['Hello.', 'World.', 'This will', 'will get', 'get split.'] - assert spans_to_text(text, spans) == expected_chunks - - -def test_serialization() -> None: - signal = ChunkSplitter(chunk_size=12, chunk_overlap=5) - assert signal.dict() == { - 'signal_name': 'chunk', - 'chunk_size': 12, - 'chunk_overlap': 5, - 'separators': ['```', '\n\n', '\n', ' ', ''] - } - - -def test_split_code() -> None: - signal = ChunkSplitter(chunk_size=60, chunk_overlap=0) - text = """ - We expected the entire code to be one span. - - ```python - def hello(): - echo('hello') - ``` - - This is the rest of the text. - """ - spans = list(signal.compute([text]))[0] - expected_chunks = [ - """ - We expected the entire code to be one span. - - """, - """```python - def hello(): - echo('hello') - ```""", - """ - - This is the rest of the text. - """, - ] - assert spans_to_text(text, spans) == expected_chunks diff --git a/lilac/splitters/text_splitter_spacy_test.py b/lilac/splitters/text_splitter_spacy_test.py deleted file mode 100644 index 4d00c754df489fd82dc4cdae2ca065fc26569606..0000000000000000000000000000000000000000 --- a/lilac/splitters/text_splitter_spacy_test.py +++ /dev/null @@ -1,40 +0,0 @@ -"""Tests the spacy sentence splitter.""" - -from typing import cast - -from .text_splitter_spacy import SentenceSplitterSpacy -from .text_splitter_test_utils import text_to_expected_spans - - -def test_splitter_spacy() -> None: - signal = SentenceSplitterSpacy() - signal.setup() - text = 'Hello. This is a test. Final sentence.' - - # Compute over the text. - split_items = list(signal.compute([text])) - - expected_spans = text_to_expected_spans(text, ['Hello.', 'This is a test.', 'Final sentence.']) - assert split_items == [expected_spans] - - -def test_spacy_key() -> None: - signal = SentenceSplitterSpacy() - assert signal.key() == 'sentences' - - -def test_spacy_non_en_key() -> None: - signal = SentenceSplitterSpacy(language='es') - assert signal.key() == 'sentences(language=es)' - - -def test_splitter_spacy_float() -> None: - signal = SentenceSplitterSpacy() - signal.setup() - text = 1.2 - - # Compute over the input, make sure it doesn't crash when we pass a non-string value which can - # happen accidentally in user data. 
- split_items = list(signal.compute([cast(str, text)])) - - assert split_items == [None] diff --git a/lilac/splitters/text_splitter_test_utils_test.py b/lilac/splitters/text_splitter_test_utils_test.py deleted file mode 100644 index eb18bf1a4e09f51518487eafa7fc20c5f7a4d176..0000000000000000000000000000000000000000 --- a/lilac/splitters/text_splitter_test_utils_test.py +++ /dev/null @@ -1,13 +0,0 @@ -"""Test the text splitter utils.""" - -from ..schema import lilac_span -from .text_splitter_test_utils import text_to_expected_spans - - -def test_text_to_expected_spans() -> None: - """Tests the text_to_expected_spans function.""" - text = 'Hello. Hello. Final sentence.' - sentences = ['Hello.', 'Hello.', 'Final sentence.'] - assert text_to_expected_spans( - text, sentences) == [lilac_span(0, 6), lilac_span(7, 13), - lilac_span(14, 29)]