nsthorat committed on
Commit
55dc3dd
1 Parent(s): 2debc48
This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. Dockerfile +4 -1
  2. requirements.txt +40 -39
  3. src/__pycache__/conftest.cpython-39-pytest-7.4.0.pyc +0 -0
  4. src/__pycache__/data_loader.cpython-39.pyc +0 -0
  5. src/__pycache__/data_loader_test.cpython-39-pytest-7.4.0.pyc +0 -0
  6. src/__pycache__/router_concept.cpython-39.pyc +0 -0
  7. src/__pycache__/router_dataset.cpython-39.pyc +0 -0
  8. src/__pycache__/router_signal.cpython-39.pyc +0 -0
  9. src/__pycache__/schema.cpython-39.pyc +0 -0
  10. src/__pycache__/schema_test.cpython-39-pytest-7.4.0.pyc +0 -0
  11. src/__pycache__/server.cpython-39.pyc +0 -0
  12. src/__pycache__/server_concept_test.cpython-39-pytest-7.3.1.pyc +0 -0
  13. src/__pycache__/server_concept_test.cpython-39-pytest-7.4.0.pyc +0 -0
  14. src/__pycache__/server_test.cpython-39-pytest-7.4.0.pyc +0 -0
  15. src/__pycache__/tasks.cpython-39.pyc +0 -0
  16. src/__pycache__/test_utils.cpython-39-pytest-7.4.0.pyc +0 -0
  17. src/concepts/__pycache__/concept.cpython-39.pyc +0 -0
  18. src/concepts/__pycache__/concept_test.cpython-39-pytest-7.4.0.pyc +0 -0
  19. src/concepts/__pycache__/db_concept.cpython-39.pyc +0 -0
  20. src/concepts/__pycache__/db_concept_test.cpython-39-pytest-7.3.1.pyc +0 -0
  21. src/concepts/__pycache__/db_concept_test.cpython-39-pytest-7.4.0.pyc +0 -0
  22. src/concepts/concept.py +167 -50
  23. src/concepts/db_concept.py +80 -12
  24. src/concepts/db_concept_test.py +33 -27
  25. src/data/__pycache__/dataset.cpython-39.pyc +0 -0
  26. src/data/__pycache__/dataset_compute_signal_chain_test.cpython-39-pytest-7.4.0.pyc +0 -0
  27. src/data/__pycache__/dataset_compute_signal_test.cpython-39-pytest-7.4.0.pyc +0 -0
  28. src/data/__pycache__/dataset_duckdb.cpython-39.pyc +0 -0
  29. src/data/__pycache__/dataset_select_groups_test.cpython-39-pytest-7.3.1.pyc +0 -0
  30. src/data/__pycache__/dataset_select_groups_test.cpython-39-pytest-7.4.0.pyc +0 -0
  31. src/data/__pycache__/dataset_select_rows_filter_test.cpython-39-pytest-7.3.1.pyc +0 -0
  32. src/data/__pycache__/dataset_select_rows_filter_test.cpython-39-pytest-7.4.0.pyc +0 -0
  33. src/data/__pycache__/dataset_select_rows_schema_test.cpython-39-pytest-7.4.0.pyc +0 -0
  34. src/data/__pycache__/dataset_select_rows_search_test.cpython-39-pytest-7.3.1.pyc +0 -0
  35. src/data/__pycache__/dataset_select_rows_search_test.cpython-39-pytest-7.4.0.pyc +0 -0
  36. src/data/__pycache__/dataset_select_rows_sort_test.cpython-39-pytest-7.4.0.pyc +0 -0
  37. src/data/__pycache__/dataset_select_rows_udf_test.cpython-39-pytest-7.4.0.pyc +0 -0
  38. src/data/__pycache__/dataset_stats_test.cpython-39-pytest-7.3.1.pyc +0 -0
  39. src/data/__pycache__/dataset_stats_test.cpython-39-pytest-7.4.0.pyc +0 -0
  40. src/data/__pycache__/dataset_test.cpython-39-pytest-7.4.0.pyc +0 -0
  41. src/data/__pycache__/dataset_utils.cpython-39.pyc +0 -0
  42. src/data/__pycache__/dataset_utils_test.cpython-39-pytest-7.4.0.pyc +0 -0
  43. src/data/__pycache__/duckdb_utils.cpython-39.pyc +0 -0
  44. src/data/dataset.py +1 -1
  45. src/data/dataset_duckdb.py +59 -33
  46. src/data/dataset_select_groups_test.py +9 -5
  47. src/data/dataset_select_rows_filter_test.py +88 -0
  48. src/data/dataset_select_rows_search_test.py +1 -9
  49. src/data/dataset_stats_test.py +5 -2
  50. src/data/dataset_utils.py +5 -1
Dockerfile CHANGED
@@ -13,7 +13,7 @@ COPY requirements.txt .
 RUN pip install --no-cache-dir -r requirements.txt
 
 # Copy the data to /data, the HF persistent storage. We do this after pip install to avoid
-# re-installing dependencies if the data changes.
+# re-installing dependencies if the data changes, which is likely more often.
 WORKDIR /
 COPY /data /data
 WORKDIR /server
@@ -27,4 +27,7 @@ COPY /web/blueprint/build ./web/blueprint/build
 # Copy python files.
 COPY /src ./src/
 
+# Copy the entrypoint file.
+COPY docker_entrypoint.sh .
+
 CMD ["uvicorn", "src.server:app", "--host", "0.0.0.0", "--port", "5432"]
requirements.txt CHANGED
@@ -1,6 +1,6 @@
 aiohttp==3.8.4 ; python_version >= "3.9" and python_version < "3.10"
 aiosignal==1.3.1 ; python_version >= "3.9" and python_version < "3.10"
-anyio==3.7.0 ; python_version >= "3.9" and python_version < "3.10"
+anyio==3.7.1 ; python_version >= "3.9" and python_version < "3.10"
 async-timeout==4.0.2 ; python_version >= "3.9" and python_version < "3.10"
 attrs==23.1.0 ; python_version >= "3.9" and python_version < "3.10"
 blis==0.7.9 ; python_version >= "3.9" and python_version < "3.10"
@@ -12,43 +12,43 @@ click==8.1.3 ; python_version >= "3.9" and python_version < "3.10"
 cloudpickle==2.2.1 ; python_version >= "3.9" and python_version < "3.10"
 cohere==3.10.0 ; python_version >= "3.9" and python_version < "3.10"
 colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.10" and (platform_system == "Windows" or sys_platform == "win32")
-confection==0.0.4 ; python_version >= "3.9" and python_version < "3.10"
+confection==0.1.0 ; python_version >= "3.9" and python_version < "3.10"
 cymem==2.0.7 ; python_version >= "3.9" and python_version < "3.10"
 cytoolz==0.12.1 ; python_version >= "3.9" and python_version < "3.10"
-dask==2023.5.1 ; python_version >= "3.9" and python_version < "3.10"
-datasets==2.12.0 ; python_version >= "3.9" and python_version < "3.10"
+dask==2023.6.1 ; python_version >= "3.9" and python_version < "3.10"
+datasets==2.13.1 ; python_version >= "3.9" and python_version < "3.10"
 decorator==5.1.1 ; python_version >= "3.9" and python_version < "3.10"
 dill==0.3.6 ; python_version >= "3.9" and python_version < "3.10"
-distributed==2023.5.1 ; python_version >= "3.9" and python_version < "3.10"
-duckdb==0.8.0 ; python_version >= "3.9" and python_version < "3.10"
+distributed==2023.6.1 ; python_version >= "3.9" and python_version < "3.10"
+duckdb==0.8.1 ; python_version >= "3.9" and python_version < "3.10"
 email-reply-parser==0.5.12 ; python_version >= "3.9" and python_version < "3.10"
-exceptiongroup==1.1.1 ; python_version >= "3.9" and python_version < "3.10"
-fastapi==0.95.2 ; python_version >= "3.9" and python_version < "3.10"
-filelock==3.12.0 ; python_version >= "3.9" and python_version < "3.10"
+exceptiongroup==1.1.2 ; python_version >= "3.9" and python_version < "3.10"
+fastapi==0.98.0 ; python_version >= "3.9" and python_version < "3.10"
+filelock==3.12.2 ; python_version >= "3.9" and python_version < "3.10"
 floret==0.10.3 ; python_version >= "3.9" and python_version < "3.10"
 frozenlist==1.3.3 ; python_version >= "3.9" and python_version < "3.10"
-fsspec==2023.5.0 ; python_version >= "3.9" and python_version < "3.10"
-fsspec[http]==2023.5.0 ; python_version >= "3.9" and python_version < "3.10"
-gcsfs==2023.5.0 ; python_version >= "3.9" and python_version < "3.10"
-google-api-core==2.11.0 ; python_version >= "3.9" and python_version < "3.10"
-google-api-python-client==2.88.0 ; python_version >= "3.9" and python_version < "3.10"
+fsspec==2023.6.0 ; python_version >= "3.9" and python_version < "3.10"
+fsspec[http]==2023.6.0 ; python_version >= "3.9" and python_version < "3.10"
+gcsfs==2023.6.0 ; python_version >= "3.9" and python_version < "3.10"
+google-api-core==2.11.1 ; python_version >= "3.9" and python_version < "3.10"
+google-api-python-client==2.92.0 ; python_version >= "3.9" and python_version < "3.10"
 google-auth-httplib2==0.1.0 ; python_version >= "3.9" and python_version < "3.10"
 google-auth-oauthlib==1.0.0 ; python_version >= "3.9" and python_version < "3.10"
-google-auth==2.19.1 ; python_version >= "3.9" and python_version < "3.10"
-google-cloud-core==2.3.2 ; python_version >= "3.9" and python_version < "3.10"
-google-cloud-storage==2.9.0 ; python_version >= "3.9" and python_version < "3.10"
+google-auth==2.21.0 ; python_version >= "3.9" and python_version < "3.10"
+google-cloud-core==2.3.3 ; python_version >= "3.9" and python_version < "3.10"
+google-cloud-storage==2.10.0 ; python_version >= "3.9" and python_version < "3.10"
 google-crc32c==1.5.0 ; python_version >= "3.9" and python_version < "3.10"
 google-resumable-media==2.5.0 ; python_version >= "3.9" and python_version < "3.10"
-googleapis-common-protos==1.59.0 ; python_version >= "3.9" and python_version < "3.10"
+googleapis-common-protos==1.59.1 ; python_version >= "3.9" and python_version < "3.10"
 h11==0.14.0 ; python_version >= "3.9" and python_version < "3.10"
 httplib2==0.22.0 ; python_version >= "3.9" and python_version < "3.10"
 httptools==0.5.0 ; python_version >= "3.9" and python_version < "3.10"
 huggingface-hub==0.15.1 ; python_version >= "3.9" and python_version < "3.10"
 idna==3.4 ; python_version >= "3.9" and python_version < "3.10"
-importlib-metadata==6.6.0 ; python_version >= "3.9" and python_version < "3.10"
-jellyfish==0.11.2 ; python_version >= "3.9" and python_version < "3.10"
+importlib-metadata==6.7.0 ; python_version >= "3.9" and python_version < "3.10"
+jellyfish==1.0.0 ; python_version >= "3.9" and python_version < "3.10"
 jinja2==3.1.2 ; python_version >= "3.9" and python_version < "3.10"
-joblib==1.2.0 ; python_version >= "3.9" and python_version < "3.10"
+joblib==1.3.1 ; python_version >= "3.9" and python_version < "3.10"
 langcodes==3.3.0 ; python_version >= "3.9" and python_version < "3.10"
 locket==1.0.0 ; python_version >= "3.9" and python_version < "3.10"
 markupsafe==2.1.3 ; python_version >= "3.9" and python_version < "3.10"
@@ -59,24 +59,24 @@ multiprocess==0.70.14 ; python_version >= "3.9" and python_version < "3.10"
 murmurhash==1.0.9 ; python_version >= "3.9" and python_version < "3.10"
 networkx==3.1 ; python_version >= "3.9" and python_version < "3.10"
 nltk==3.8.1 ; python_version >= "3.9" and python_version < "3.10"
-numpy==1.24.3 ; python_version >= "3.9" and python_version < "3.10"
+numpy==1.25.0 ; python_version >= "3.9" and python_version < "3.10"
 oauthlib==3.2.2 ; python_version >= "3.9" and python_version < "3.10"
 openai-function-call==0.0.5 ; python_version >= "3.9" and python_version < "3.10"
 openai==0.27.8 ; python_version >= "3.9" and python_version < "3.10"
-orjson==3.9.0 ; python_version >= "3.9" and python_version < "3.10"
+orjson==3.9.1 ; python_version >= "3.9" and python_version < "3.10"
 packaging==23.1 ; python_version >= "3.9" and python_version < "3.10"
-pandas==2.0.2 ; python_version >= "3.9" and python_version < "3.10"
+pandas==2.0.3 ; python_version >= "3.9" and python_version < "3.10"
 partd==1.4.0 ; python_version >= "3.9" and python_version < "3.10"
-pathy==0.10.1 ; python_version >= "3.9" and python_version < "3.10"
+pathy==0.10.2 ; python_version >= "3.9" and python_version < "3.10"
 pillow==9.5.0 ; python_version >= "3.9" and python_version < "3.10"
 preshed==3.0.8 ; python_version >= "3.9" and python_version < "3.10"
-protobuf==4.23.2 ; python_version >= "3.9" and python_version < "3.10"
+protobuf==4.23.3 ; python_version >= "3.9" and python_version < "3.10"
 psutil==5.9.5 ; python_version >= "3.9" and python_version < "3.10"
 pyarrow==9.0.0 ; python_version >= "3.9" and python_version < "3.10"
 pyasn1-modules==0.3.0 ; python_version >= "3.9" and python_version < "3.10"
 pyasn1==0.5.0 ; python_version >= "3.9" and python_version < "3.10"
-pydantic==1.10.9 ; python_version >= "3.9" and python_version < "3.10"
-pyparsing==3.0.9 ; python_version >= "3.9" and python_version < "3.10"
+pydantic==1.10.11 ; python_version >= "3.9" and python_version < "3.10"
+pyparsing==3.1.0 ; python_version >= "3.9" and python_version < "3.10"
 pyphen==0.14.0 ; python_version >= "3.9" and python_version < "3.10"
 python-dateutil==2.8.2 ; python_version >= "3.9" and python_version < "3.10"
 python-dotenv==1.0.0 ; python_version >= "3.9" and python_version < "3.10"
@@ -85,10 +85,10 @@ pyyaml==6.0 ; python_version >= "3.9" and python_version < "3.10"
 regex==2023.6.3 ; python_version >= "3.9" and python_version < "3.10"
 requests-oauthlib==1.3.1 ; python_version >= "3.9" and python_version < "3.10"
 requests==2.31.0 ; python_version >= "3.9" and python_version < "3.10"
-responses==0.18.0 ; python_version >= "3.9" and python_version < "3.10"
 rsa==4.9 ; python_version >= "3.9" and python_version < "3.10"
-scikit-learn==1.2.2 ; python_version >= "3.9" and python_version < "3.10"
-scipy==1.10.1 ; python_version >= "3.9" and python_version < "3.10"
+safetensors==0.3.1 ; python_version >= "3.9" and python_version < "3.10"
+scikit-learn==1.3.0 ; python_version >= "3.9" and python_version < "3.10"
+scipy==1.11.1 ; python_version >= "3.9" and python_version < "3.10"
 sentence-transformers==2.2.2 ; python_version >= "3.9" and python_version < "3.10"
 sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.10"
 setuptools==65.7.0 ; python_version >= "3.9" and python_version < "3.10"
@@ -98,11 +98,12 @@ sniffio==1.3.0 ; python_version >= "3.9" and python_version < "3.10"
 sortedcontainers==2.4.0 ; python_version >= "3.9" and python_version < "3.10"
 spacy-legacy==3.0.12 ; python_version >= "3.9" and python_version < "3.10"
 spacy-loggers==1.0.4 ; python_version >= "3.9" and python_version < "3.10"
-spacy==3.5.3 ; python_version >= "3.9" and python_version < "3.10"
+spacy==3.5.4 ; python_version >= "3.9" and python_version < "3.10"
 srsly==2.4.6 ; python_version >= "3.9" and python_version < "3.10"
 starlette==0.27.0 ; python_version >= "3.9" and python_version < "3.10"
 sympy==1.12 ; python_version >= "3.9" and python_version < "3.10"
-tblib==1.7.0 ; python_version >= "3.9" and python_version < "3.10"
+tblib==2.0.0 ; python_version >= "3.9" and python_version < "3.10"
+tenacity==8.2.2 ; python_version >= "3.9" and python_version < "3.10"
 textacy==0.13.0 ; python_version >= "3.9" and python_version < "3.10"
 thinc==8.1.10 ; python_version >= "3.9" and python_version < "3.10"
 threadpoolctl==3.1.0 ; python_version >= "3.9" and python_version < "3.10"
@@ -112,16 +113,16 @@ torch==2.0.1 ; python_version >= "3.9" and python_version < "3.10"
 torchvision==0.15.2 ; python_version >= "3.9" and python_version < "3.10"
 tornado==6.3.2 ; python_version >= "3.9" and python_version < "3.10"
 tqdm==4.65.0 ; python_version >= "3.9" and python_version < "3.10"
-transformers==4.29.2 ; python_version >= "3.9" and python_version < "3.10"
-typer==0.7.0 ; python_version >= "3.9" and python_version < "3.10"
-types-psutil==5.9.5.13 ; python_version >= "3.9" and python_version < "3.10"
-typing-extensions==4.6.3 ; python_version >= "3.9" and python_version < "3.10"
+transformers==4.30.2 ; python_version >= "3.9" and python_version < "3.10"
+typer==0.9.0 ; python_version >= "3.9" and python_version < "3.10"
+types-psutil==5.9.5.15 ; python_version >= "3.9" and python_version < "3.10"
+typing-extensions==4.7.1 ; python_version >= "3.9" and python_version < "3.10"
 tzdata==2023.3 ; python_version >= "3.9" and python_version < "3.10"
 uritemplate==4.1.1 ; python_version >= "3.9" and python_version < "3.10"
 urllib3==1.26.16 ; python_version >= "3.9" and python_version < "3.10"
-uvicorn[standard]==0.20.0 ; python_version >= "3.9" and python_version < "3.10"
+uvicorn[standard]==0.22.0 ; python_version >= "3.9" and python_version < "3.10"
 uvloop==0.17.0 ; (sys_platform != "win32" and sys_platform != "cygwin") and platform_python_implementation != "PyPy" and python_version >= "3.9" and python_version < "3.10"
-wasabi==1.1.1 ; python_version >= "3.9" and python_version < "3.10"
+wasabi==1.1.2 ; python_version >= "3.9" and python_version < "3.10"
 watchfiles==0.19.0 ; python_version >= "3.9" and python_version < "3.10"
 websockets==11.0.3 ; python_version >= "3.9" and python_version < "3.10"
 xxhash==3.2.0 ; python_version >= "3.9" and python_version < "3.10"
src/__pycache__/conftest.cpython-39-pytest-7.4.0.pyc ADDED
Binary file (1.41 kB).

src/__pycache__/data_loader.cpython-39.pyc CHANGED
Binary files a/src/__pycache__/data_loader.cpython-39.pyc and b/src/__pycache__/data_loader.cpython-39.pyc differ

src/__pycache__/data_loader_test.cpython-39-pytest-7.4.0.pyc ADDED
Binary file (4.35 kB).

src/__pycache__/router_concept.cpython-39.pyc CHANGED
Binary files a/src/__pycache__/router_concept.cpython-39.pyc and b/src/__pycache__/router_concept.cpython-39.pyc differ

src/__pycache__/router_dataset.cpython-39.pyc CHANGED
Binary files a/src/__pycache__/router_dataset.cpython-39.pyc and b/src/__pycache__/router_dataset.cpython-39.pyc differ

src/__pycache__/router_signal.cpython-39.pyc CHANGED
Binary files a/src/__pycache__/router_signal.cpython-39.pyc and b/src/__pycache__/router_signal.cpython-39.pyc differ

src/__pycache__/schema.cpython-39.pyc CHANGED
Binary files a/src/__pycache__/schema.cpython-39.pyc and b/src/__pycache__/schema.cpython-39.pyc differ

src/__pycache__/schema_test.cpython-39-pytest-7.4.0.pyc ADDED
Binary file (9.4 kB).

src/__pycache__/server.cpython-39.pyc CHANGED
Binary files a/src/__pycache__/server.cpython-39.pyc and b/src/__pycache__/server.cpython-39.pyc differ

src/__pycache__/server_concept_test.cpython-39-pytest-7.3.1.pyc CHANGED
Binary files a/src/__pycache__/server_concept_test.cpython-39-pytest-7.3.1.pyc and b/src/__pycache__/server_concept_test.cpython-39-pytest-7.3.1.pyc differ

src/__pycache__/server_concept_test.cpython-39-pytest-7.4.0.pyc ADDED
Binary file (19.2 kB).

src/__pycache__/server_test.cpython-39-pytest-7.4.0.pyc ADDED
Binary file (12.5 kB).

src/__pycache__/tasks.cpython-39.pyc CHANGED
Binary files a/src/__pycache__/tasks.cpython-39.pyc and b/src/__pycache__/tasks.cpython-39.pyc differ

src/__pycache__/test_utils.cpython-39-pytest-7.4.0.pyc ADDED
Binary file (1.1 kB).

src/concepts/__pycache__/concept.cpython-39.pyc CHANGED
Binary files a/src/concepts/__pycache__/concept.cpython-39.pyc and b/src/concepts/__pycache__/concept.cpython-39.pyc differ

src/concepts/__pycache__/concept_test.cpython-39-pytest-7.4.0.pyc ADDED
Binary file (3.74 kB).

src/concepts/__pycache__/db_concept.cpython-39.pyc CHANGED
Binary files a/src/concepts/__pycache__/db_concept.cpython-39.pyc and b/src/concepts/__pycache__/db_concept.cpython-39.pyc differ

src/concepts/__pycache__/db_concept_test.cpython-39-pytest-7.3.1.pyc CHANGED
Binary files a/src/concepts/__pycache__/db_concept_test.cpython-39-pytest-7.3.1.pyc and b/src/concepts/__pycache__/db_concept_test.cpython-39-pytest-7.3.1.pyc differ

src/concepts/__pycache__/db_concept_test.cpython-39-pytest-7.4.0.pyc ADDED
Binary file (27.7 kB).
src/concepts/concept.py CHANGED
@@ -1,23 +1,37 @@
 """Defines the concept and the concept models."""
+import dataclasses
 import random
-from typing import Iterable, Literal, Optional, Union
+from enum import Enum
+from typing import Callable, Iterable, Literal, Optional, Union
 
 import numpy as np
+from joblib import Parallel, delayed
 from pydantic import BaseModel, validator
+from scipy.interpolate import interp1d
+from sklearn.base import BaseEstimator, clone
 from sklearn.exceptions import NotFittedError
 from sklearn.linear_model import LogisticRegression
+from sklearn.metrics import precision_recall_curve, roc_auc_score
+from sklearn.model_selection import KFold
 
 from ..db_manager import get_dataset
 from ..embeddings.embedding import get_embed_fn
 from ..schema import Path, RichData, SignalInputType, normalize_path
-from ..signals.signal import TextEmbeddingSignal, get_signal_cls
+from ..signals.signal import EMBEDDING_KEY, TextEmbeddingSignal, get_signal_cls
 from ..utils import DebugTimer
 
 LOCAL_CONCEPT_NAMESPACE = 'local'
 
 # Number of randomly sampled negative examples to use for training. This is used to obtain a more
 # balanced model that works with a specific dataset.
-DEFAULT_NUM_NEG_EXAMPLES = 300
+DEFAULT_NUM_NEG_EXAMPLES = 100
+
+# The maximum number of cross-validation models to train.
+MAX_NUM_CROSS_VAL_MODELS = 15
+# The β weight to use for the F-beta score: https://scikit-learn.org/stable/modules/generated/sklearn.metrics.fbeta_score.html
+# β = 0.5 means we value precision 2x as much as recall.
+# β = 2 means we value recall 2x as much as precision.
+F_BETA_WEIGHT = 0.5
 
 
 class ConceptColumnInfo(BaseModel):
@@ -29,6 +43,13 @@ class ConceptColumnInfo(BaseModel):
   # Path holding the text to use for negative examples.
   path: Path
 
+  @validator('path')
+  def _path_points_to_text_field(cls, path: Path) -> Path:
+    if path[-1] == EMBEDDING_KEY:
+      raise ValueError(
+        f'The path should point to the text field, not its embedding field. Provided path: {path}')
+    return path
+
   num_negative_examples = DEFAULT_NUM_NEG_EXAMPLES
 
 
@@ -71,7 +92,7 @@
 class Concept(BaseModel):
   """A concept is a collection of examples."""
   # The namespace of the concept.
-  namespace: str = LOCAL_CONCEPT_NAMESPACE
+  namespace: str
   # The name of the concept.
   concept_name: str
   # The type of the data format that this concept represents.
@@ -79,6 +100,8 @@ class Concept(BaseModel):
   data: dict[str, Example]
   version: int = 0
 
+  description: Optional[str] = None
+
   def drafts(self) -> list[DraftId]:
     """Gets all the drafts for the concept."""
     drafts: set[DraftId] = set([DRAFT_MAIN])  # Always return the main draft.
@@ -88,39 +111,141 @@
     return list(sorted(drafts))
 
 
-class LogisticEmbeddingModel(BaseModel):
-  """A model that uses logistic regression with embeddings."""
-
-  class Config:
-    arbitrary_types_allowed = True
-    underscore_attrs_are_private = True
-
-  version: int = -1
-
-  # The following fields are excluded from JSON serialization, but still pickleable.
-  # See `notebooks/Toxicity.ipynb` for an example of training a concept model.
-  _model: LogisticRegression = LogisticRegression(
-    class_weight=None, C=30, tol=1e-5, warm_start=True, max_iter=1_000, n_jobs=-1)
+class OverallScore(str, Enum):
+  """Enum holding the overall score."""
+  NOT_GOOD = 'not_good'
+  OK = 'ok'
+  GOOD = 'good'
+  VERY_GOOD = 'very_good'
+  GREAT = 'great'
+
+
+def _get_overall_score(f1_score: float) -> OverallScore:
+  if f1_score < 0.5:
+    return OverallScore.NOT_GOOD
+  if f1_score < 0.8:
+    return OverallScore.OK
+  if f1_score < 0.9:
+    return OverallScore.GOOD
+  if f1_score < 0.95:
+    return OverallScore.VERY_GOOD
+  return OverallScore.GREAT
+
+
+class ConceptMetrics(BaseModel):
+  """Metrics for a concept."""
+  # The average F1 score for the concept computed using cross validation.
+  f1: float
+  precision: float
+  recall: float
+  roc_auc: float
+  overall: OverallScore
+
+
+@dataclasses.dataclass
+class LogisticEmbeddingModel:
+  """A model that uses logistic regression with embeddings."""
+
+  _metrics: Optional[ConceptMetrics] = None
+  _threshold: float = 0.5
+
+  def __post_init__(self) -> None:
+    # See `notebooks/Toxicity.ipynb` for an example of training a concept model.
+    self._model = LogisticRegression(
+      class_weight=None, C=30, tol=1e-5, warm_start=True, max_iter=1_000, n_jobs=-1)
 
   def score_embeddings(self, embeddings: np.ndarray) -> np.ndarray:
     """Get the scores for the provided embeddings."""
     try:
-      return self._model.predict_proba(embeddings)[:, 1]
+      y_probs = self._model.predict_proba(embeddings)[:, 1]
+      # Map [0, threshold, 1] to [0, 0.5, 1].
+      interpolate_fn = interp1d([0, self._threshold, 1], [0, 0.4999, 1])
+      return interpolate_fn(y_probs)
     except NotFittedError:
       return np.random.rand(len(embeddings))
 
-  def fit(self, embeddings: np.ndarray, labels: list[bool], sample_weights: list[float]) -> None:
+  def _setup_training(
+      self, X_train: np.ndarray, y_train: list[bool],
+      implicit_negatives: Optional[np.ndarray]) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
+    num_pos_labels = len([y for y in y_train if y])
+    num_neg_labels = len([y for y in y_train if not y])
+    sample_weights = [(1.0 / num_pos_labels if y else 1.0 / num_neg_labels) for y in y_train]
+
+    if implicit_negatives is not None:
+      num_implicit_labels = len(implicit_negatives)
+      implicit_labels = [False] * num_implicit_labels
+      X_train = np.concatenate([implicit_negatives, X_train])
+      y_train = np.concatenate([implicit_labels, y_train])
+      sample_weights = [1.0 / num_implicit_labels] * num_implicit_labels + sample_weights
+
+    # Normalize sample weights to sum to the number of training examples.
+    weights = np.array(sample_weights)
+    weights *= (X_train.shape[0] / np.sum(weights))
+    return X_train, np.array(y_train), weights
+
+  def fit(self, embeddings: np.ndarray, labels: list[bool],
+          implicit_negatives: Optional[np.ndarray]) -> None:
     """Fit the model to the provided embeddings and labels."""
-    if len(set(labels)) < 2:
+    label_set = set(labels)
+    if implicit_negatives is not None:
+      label_set.add(False)
+    if len(label_set) < 2:
      return
    if len(labels) != len(embeddings):
      raise ValueError(
        f'Length of embeddings ({len(embeddings)}) must match length of labels ({len(labels)})')
-    if len(sample_weights) != len(labels):
-      raise ValueError(
-        f'Length of sample_weights ({len(sample_weights)}) must match length of labels '
-        f'({len(labels)})')
-    self._model.fit(embeddings, labels, sample_weights)
+    X_train, y_train, sample_weights = self._setup_training(embeddings, labels, implicit_negatives)
+    self._model.fit(X_train, y_train, sample_weights)
+    self._metrics, self._threshold = self._compute_metrics(embeddings, labels, implicit_negatives)
+
+  def _compute_metrics(
+      self, embeddings: np.ndarray, labels: list[bool],
+      implicit_negatives: Optional[np.ndarray]) -> tuple[Optional[ConceptMetrics], float]:
+    """Return the concept metrics."""
+    labels = np.array(labels)
+    n_splits = min(len(labels), MAX_NUM_CROSS_VAL_MODELS)
+    fold = KFold(n_splits, shuffle=True, random_state=42)
+
+    def _fit_and_score(model: BaseEstimator, X_train: np.ndarray, y_train: np.ndarray,
+                       sample_weights: np.ndarray, X_test: np.ndarray,
+                       y_test: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
+      if len(set(y_train)) < 2:
+        return np.array([]), np.array([])
+      model.fit(X_train, y_train, sample_weights)
+      y_pred = model.predict_proba(X_test)[:, 1]
+      return y_test, y_pred
+
+    # Compute the metrics for each validation fold in parallel.
+    jobs: list[Callable] = []
+    for (train_index, test_index) in fold.split(embeddings):
+      X_train, y_train = embeddings[train_index], labels[train_index]
+      X_train, y_train, sample_weights = self._setup_training(X_train, y_train, implicit_negatives)
+      X_test, y_test = embeddings[test_index], labels[test_index]
+      model = clone(self._model)
+      jobs.append(delayed(_fit_and_score)(model, X_train, y_train, sample_weights, X_test, y_test))
+    results = Parallel(n_jobs=-1)(jobs)
+
+    y_test = np.concatenate([y_test for y_test, _ in results], axis=0)
+    y_pred = np.concatenate([y_pred for _, y_pred in results], axis=0)
+    if len(set(y_test)) < 2:
+      return None, 0.5
+    roc_auc_val = roc_auc_score(y_test, y_pred)
+    precision, recall, thresholds = precision_recall_curve(y_test, y_pred)
+    numerator = (1 + F_BETA_WEIGHT**2) * precision * recall
+    denom = (F_BETA_WEIGHT**2 * precision) + recall
+    f1_scores = np.divide(numerator, denom, out=np.zeros_like(denom), where=(denom != 0))
+    max_f1: float = np.max(f1_scores)
+    max_f1_index = np.argmax(f1_scores)
+    max_f1_thresh: float = thresholds[max_f1_index]
+    max_f1_prec: float = precision[max_f1_index]
+    max_f1_recall: float = recall[max_f1_index]
+    metrics = ConceptMetrics(
+      f1=max_f1,
+      precision=max_f1_prec,
+      recall=max_f1_recall,
+      roc_auc=roc_auc_val,
+      overall=_get_overall_score(max_f1))
+    return metrics, max_f1_thresh
 
 
 def draft_examples(concept: Concept, draft: DraftId) -> dict[str, Example]:
@@ -136,7 +261,7 @@ def draft_examples(concept: Concept, draft: DraftId) -> dict[str, Example]:
     raise ValueError(
       f'Draft {draft} not found in concept. Found drafts: {list(draft_examples.keys())}')
 
-  # Map the text of the draft to its id so we can dedup with main.
+  # Map the text of the draft to its id so we can dedupe with main.
   draft_text_ids = {example.text: id for id, example in draft_examples[draft].items()}
 
   # Write each of examples from main to the draft examples only if the text does not appear in the
@@ -148,7 +273,8 @@
   return draft_examples[draft]
 
 
-class ConceptModel(BaseModel):
+@dataclasses.dataclass
+class ConceptModel:
   """A concept model. Stores all concept model drafts and manages syncing."""
   # The concept that this model is for.
   namespace: str
@@ -158,20 +284,27 @@ class ConceptModel(BaseModel):
   embedding_name: str
   version: int = -1
 
-  # The following fields are excluded from JSON serialization, but still pickleable.
+  column_info: Optional[ConceptColumnInfo] = None
+
+  # The following fields are excluded from JSON serialization, but still pickle-able.
   # Maps a concept id to the embeddings.
-  _embeddings: dict[str, np.ndarray] = {}
-  _logistic_models: dict[DraftId, LogisticEmbeddingModel] = {}
+  _embeddings: dict[str, np.ndarray] = dataclasses.field(default_factory=dict)
+  _logistic_models: dict[DraftId, LogisticEmbeddingModel] = dataclasses.field(default_factory=dict)
   _negative_vectors: Optional[np.ndarray] = None
 
-  class Config:
-    arbitrary_types_allowed = True
-    underscore_attrs_are_private = True
+  def get_metrics(self, concept: Concept) -> Optional[ConceptMetrics]:
+    """Return the metrics for this model."""
+    return self._get_logistic_model(DRAFT_MAIN)._metrics
+
+  def __post_init__(self) -> None:
+    if self.column_info:
+      self.column_info.path = normalize_path(self.column_info.path)
+      self._calibrate_on_dataset(self.column_info)
 
-  def calibrate_on_dataset(self, column_info: ConceptColumnInfo) -> None:
+  def _calibrate_on_dataset(self, column_info: ConceptColumnInfo) -> None:
    """Calibrate the model on the embeddings in the provided vector store."""
    db = get_dataset(column_info.namespace, column_info.name)
-    vector_store = db.get_vector_store(normalize_path(column_info.path))
+    vector_store = db.get_vector_store(self.embedding_name, normalize_path(column_info.path))
    keys = vector_store.keys()
    num_samples = min(column_info.num_negative_examples, len(keys))
    sample_keys = random.sample(keys, num_samples)
@@ -199,11 +332,7 @@ class ConceptModel(BaseModel):
   def _get_logistic_model(self, draft: DraftId) -> LogisticEmbeddingModel:
     """Get the logistic model for the provided draft."""
     if draft not in self._logistic_models:
-      self._logistic_models[draft] = LogisticEmbeddingModel(
-        namespace=self.namespace,
-        concept_name=self.concept_name,
-        embedding_name=self.embedding_name,
-        version=-1)
+      self._logistic_models[draft] = LogisticEmbeddingModel()
     return self._logistic_models[draft]
 
   def sync(self, concept: Concept) -> bool:
@@ -222,21 +351,9 @@ class ConceptModel(BaseModel):
     examples = draft_examples(concept, draft)
     embeddings = np.array([self._embeddings[id] for id in examples.keys()])
     labels = [example.label for example in examples.values()]
-    num_pos_labels = len([x for x in labels if x])
-    num_neg_labels = len([x for x in labels if not x])
-    sample_weights = [(1.0 / num_pos_labels if x else 1.0 / num_neg_labels) for x in labels]
-    if self._negative_vectors is not None:
-      num_implicit_labels = len(self._negative_vectors)
-      embeddings = np.concatenate([self._negative_vectors, embeddings])
-      labels = [False] * num_implicit_labels + labels
-      sample_weights = [1.0 / num_implicit_labels] * num_implicit_labels + sample_weights
-
     model = self._get_logistic_model(draft)
     with DebugTimer(f'Fitting model for "{concept_path}"'):
-      model.fit(embeddings, labels, sample_weights)
-
-    # Synchronize the model version with the concept version.
-    model.version = concept.version
+      model.fit(embeddings, labels, self._negative_vectors)
 
     # Synchronize the model version with the concept version.
     self.version = concept.version
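
The new `_compute_metrics` picks the decision threshold that maximizes an F-beta score (β = 0.5, favoring precision) on the pooled cross-validation predictions, and `score_embeddings` then uses `interp1d` to recenter raw probabilities so the chosen threshold lands at roughly 0.5. A minimal, self-contained sketch of those two steps; the toy `y_true`/`y_score` arrays below are made up for illustration and are not from this commit:

import numpy as np
from scipy.interpolate import interp1d
from sklearn.metrics import precision_recall_curve

F_BETA_WEIGHT = 0.5  # Same beta as the commit: precision is valued 2x over recall.

# Toy ground-truth labels and model probabilities, for illustration only.
y_true = np.array([0, 0, 0, 1, 0, 1, 1, 1])
y_score = np.array([0.1, 0.3, 0.35, 0.4, 0.45, 0.6, 0.8, 0.9])

precision, recall, thresholds = precision_recall_curve(y_true, y_score)
# F-beta = (1 + beta^2) * P * R / (beta^2 * P + R), guarding against division by zero.
numerator = (1 + F_BETA_WEIGHT**2) * precision * recall
denom = F_BETA_WEIGHT**2 * precision + recall
fbeta = np.divide(numerator, denom, out=np.zeros_like(denom), where=denom != 0)

best = int(np.argmax(fbeta))
threshold = float(thresholds[best])  # Operating point with the best F-beta (0.6 here).

# Recenter scores so `threshold` maps to ~0.5, as the new score_embeddings does.
calibrate = interp1d([0, threshold, 1], [0, 0.4999, 1])
print(threshold, calibrate(y_score))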
src/concepts/db_concept.py CHANGED
@@ -2,15 +2,18 @@
 
 import abc
 import glob
+import json
 import os
 import pickle
 import shutil
 
 # NOTE: We have to import the module for uuid so it can be mocked.
 import uuid
+from pathlib import Path
 from typing import List, Optional, Union, cast
 
 from pydantic import BaseModel
+from pyparsing import Any
 from typing_extensions import override
 
 from ..config import data_path
@@ -27,6 +30,7 @@ from .concept import (
   ExampleIn,
 )
 
+CONCEPTS_DIR = 'concept'
 DATASET_CONCEPTS_DIR = '.concepts'
 CONCEPT_JSON_FILENAME = 'concept.json'
 
@@ -65,8 +69,19 @@ class ConceptDB(abc.ABC):
     pass
 
   @abc.abstractmethod
-  def create(self, namespace: str, name: str, type: SignalInputType) -> Concept:
-    """Create a concept."""
+  def create(self,
+             namespace: str,
+             name: str,
+             type: SignalInputType,
+             description: Optional[str] = None) -> Concept:
+    """Create a concept.
+
+    Args:
+      namespace: The namespace of the concept.
+      name: The name of the concept.
+      type: The input type of the concept.
+      description: The description of the concept.
+    """
     pass
 
   @abc.abstractmethod
@@ -115,7 +130,7 @@ class ConceptModelDB(abc.ABC):
     pass
 
   @abc.abstractmethod
-  def _save(self, model: ConceptModel, column_info: Optional[ConceptColumnInfo]) -> None:
+  def _save(self, model: ConceptModel) -> None:
     """Save the concept model."""
     pass
 
@@ -126,13 +141,14 @@
       raise ValueError(f'Concept "{model.namespace}/{model.concept_name}" does not exist.')
     return concept.version == model.version
 
-  def sync(self, model: ConceptModel, column_info: Optional[ConceptColumnInfo]) -> bool:
+  def sync(self, model: ConceptModel) -> bool:
     """Sync the concept model. Returns true if the model was updated."""
     concept = self._concept_db.get(model.namespace, model.concept_name)
     if not concept:
       raise ValueError(f'Concept "{model.namespace}/{model.concept_name}" does not exist.')
     model_updated = model.sync(concept)
-    self._save(model, column_info)
+    if model_updated:
+      self._save(model)
     return model_updated
 
   @abc.abstractmethod
@@ -149,6 +165,16 @@ class ConceptModelDB(abc.ABC):
     """Remove all the models associated with a concept."""
     pass
 
+  @abc.abstractmethod
+  def get_models(self, namespace: str, concept_name: str) -> list[ConceptModel]:
+    """List all the models associated with a concept."""
+    pass
+
+  @abc.abstractmethod
+  def get_column_infos(self, namespace: str, concept_name: str) -> list[ConceptColumnInfo]:
+    """Get the dataset columns where this concept was applied to."""
+    pass
+
 
 class DiskConceptModelDB(ConceptModelDB):
   """Interface for the concept model database."""
@@ -191,10 +217,10 @@
     with open_file(concept_model_path, 'rb') as f:
       return pickle.load(f)
 
-  def _save(self, model: ConceptModel, column_info: Optional[ConceptColumnInfo]) -> None:
+  def _save(self, model: ConceptModel) -> None:
     """Save the concept model."""
     concept_model_path = _concept_model_path(model.namespace, model.concept_name,
-                                             model.embedding_name, column_info)
+                                             model.embedding_name, model.column_info)
     with open_file(concept_model_path, 'wb') as f:
       pickle.dump(model, f)
 
@@ -224,10 +250,39 @@
     for dir in dirs:
       shutil.rmtree(dir, ignore_errors=True)
 
+  @override
+  def get_models(self, namespace: str, concept_name: str) -> list[ConceptModel]:
+    """List all the models associated with a concept."""
+    model_files = glob.iglob(os.path.join(_concept_output_dir(namespace, concept_name), '*.pkl'))
+    models: list[ConceptModel] = []
+    for model_file in model_files:
+      embedding_name = os.path.basename(model_file)[:-len('.pkl')]
+      model = self.get(namespace, concept_name, embedding_name)
+      if model:
+        models.append(model)
+    return models
+
+  @override
+  def get_column_infos(self, namespace: str, concept_name: str) -> list[ConceptColumnInfo]:
+    datasets_path = os.path.join(data_path(), DATASETS_DIR_NAME)
+    # Skip if 'datasets' doesn't exist.
+    if not os.path.isdir(datasets_path):
+      return []
+
+    dirs = glob.iglob(
+      os.path.join(datasets_path, '**', DATASET_CONCEPTS_DIR, namespace, concept_name),
+      recursive=True)
+    result: list[ConceptColumnInfo] = []
+    for dir in dirs:
+      dir = os.path.relpath(dir, datasets_path)
+      dataset_namespace, dataset_name, *path, _, _, _ = Path(dir).parts
+      result.append(ConceptColumnInfo(namespace=dataset_namespace, name=dataset_name, path=path))
+    return result
+
 
 def _concept_output_dir(namespace: str, name: str) -> str:
   """Return the output directory for a given concept."""
-  return os.path.join(data_path(), 'concept', namespace, name)
+  return os.path.join(data_path(), CONCEPTS_DIR, namespace, name)
 
 
 def _concept_json_path(namespace: str, name: str) -> str:
@@ -246,7 +301,7 @@ def _concept_model_path(namespace: str,
   path_without_wildcards = (p for p in path_tuple if p != PATH_WILDCARD)
   path_dir = os.path.join(dataset_dir, *path_without_wildcards)
   return os.path.join(path_dir, DATASET_CONCEPTS_DIR, namespace, concept_name,
-                      f'{embedding_name}.pkl')
+                      f'{embedding_name}-neg-{column_info.num_negative_examples}.pkl')
 
 
 class DiskConceptDB(ConceptDB):
@@ -280,16 +335,29 @@ class DiskConceptDB(ConceptDB):
       return None
 
     with open_file(concept_json_path) as f:
-      return Concept.parse_raw(f.read())
+      obj: dict[str, Any] = json.load(f)
+      if 'namespace' not in obj:
+        obj['namespace'] = namespace
+      return Concept.parse_obj(obj)
 
   @override
-  def create(self, namespace: str, name: str, type: SignalInputType) -> Concept:
+  def create(self,
+             namespace: str,
+             name: str,
+             type: SignalInputType,
+             description: Optional[str] = None) -> Concept:
     """Create a concept."""
     concept_json_path = _concept_json_path(namespace, name)
     if file_exists(concept_json_path):
      raise ValueError(f'Concept with namespace "{namespace}" and name "{name}" already exists.')
 
-    concept = Concept(namespace=namespace, concept_name=name, type=type, data={}, version=0)
+    concept = Concept(
+      namespace=namespace,
+      concept_name=name,
+      type=type,
+      data={},
+      version=0,
+      description=description)
    self._save(concept)
    return concept
 
src/concepts/db_concept_test.py CHANGED
@@ -1,7 +1,7 @@
 """Tests for the the database concept."""
 
 from pathlib import Path
-from typing import Generator, Iterable, Type, cast
+from typing import Generator, Iterable, Optional, Type, cast
 
 import numpy as np
 import pytest
@@ -423,7 +423,8 @@ class TestLogisticModel(LogisticEmbeddingModel):
     return np.array([.1])
 
   @override
-  def fit(self, embeddings: np.ndarray, labels: list[bool], sample_weights: list[float]) -> None:
+  def fit(self, embeddings: np.ndarray, labels: list[bool],
+          implicit_negatives: Optional[np.ndarray]) -> None:
     pass
@@ -436,20 +437,24 @@
     concept_db = concept_db_cls()
     model_db = model_db_cls(concept_db)
     model = _make_test_concept_model(concept_db)
-    model_db.sync(model, column_info=None)
+    model_db.sync(model)
     retrieved_model = model_db.get(
       namespace='test', concept_name='test_concept', embedding_name='test_embedding')
     if not retrieved_model:
       retrieved_model = model_db.create(
         namespace='test', concept_name='test_concept', embedding_name='test_embedding')
-    assert retrieved_model == model
+    assert retrieved_model.namespace == model.namespace
 
   def test_sync_model(self, concept_db_cls: Type[ConceptDB], model_db_cls: Type[ConceptModelDB],
                       mocker: MockerFixture) -> None:
 
     concept_db = concept_db_cls()
     model_db = model_db_cls(concept_db)
-    logistic_model = TestLogisticModel(embedding_name='test_embedding')
+    logistic_model = TestLogisticModel()
     score_embeddings_mock = mocker.spy(TestLogisticModel, 'score_embeddings')
     fit_mock = mocker.spy(TestLogisticModel, 'fit')
@@ -459,7 +464,7 @@
     assert score_embeddings_mock.call_count == 0
     assert fit_mock.call_count == 0
 
-    model_db.sync(model, column_info=None)
+    model_db.sync(model)
 
     assert model_db.in_sync(model) is True
     assert score_embeddings_mock.call_count == 0
@@ -471,20 +476,20 @@
     model_db = model_db_cls(concept_db)
     score_embeddings_mock = mocker.spy(TestLogisticModel, 'score_embeddings')
     fit_mock = mocker.spy(TestLogisticModel, 'fit')
-    logistic_model = TestLogisticModel(embedding_name='test_embedding')
+    logistic_model = TestLogisticModel()
     model = _make_test_concept_model(concept_db, logistic_models={DRAFT_MAIN: logistic_model})
-    model_db.sync(model, column_info=None)
+    model_db.sync(model)
     assert model_db.in_sync(model) is True
     assert score_embeddings_mock.call_count == 0
     assert fit_mock.call_count == 1
 
     (called_model, called_embeddings, called_labels,
-     called_weights) = fit_mock.call_args_list[-1].args
+     called_implicit_negatives) = fit_mock.call_args_list[-1].args
     assert called_model == logistic_model
     np.testing.assert_array_equal(
       called_embeddings, np.array([EMBEDDING_MAP['not in concept'], EMBEDDING_MAP['in concept']]))
     assert called_labels == [False, True]
-    assert called_weights == [1.0, 1.0]
+    assert called_implicit_negatives is None
 
     # Edit the concept.
     concept_db.edit('test', 'test_concept',
@@ -495,13 +500,13 @@
     assert score_embeddings_mock.call_count == 0
     assert fit_mock.call_count == 1
 
-    model_db.sync(model, column_info=None)
+    model_db.sync(model)
     assert model_db.in_sync(model) is True
     assert score_embeddings_mock.call_count == 0
     assert fit_mock.call_count == 2
     # Fit is called again with new points on main only.
     (called_model, called_embeddings, called_labels,
-     called_weights) = fit_mock.call_args_list[-1].args
+     called_implicit_negatives) = fit_mock.call_args_list[-1].args
     assert called_model == logistic_model
     np.testing.assert_array_equal(
       called_embeddings,
@@ -510,7 +515,7 @@
       EMBEDDING_MAP['a new data point']
     ]))
     assert called_labels == [False, True, False]
-    assert called_weights == pytest.approx([1 / 2, 1.0, 1 / 2])
+    assert called_implicit_negatives is None
 
   def test_out_of_sync_draft_model(self, concept_db_cls: Type[ConceptDB],
                                    model_db_cls: Type[ConceptModelDB],
@@ -519,14 +524,14 @@
     model_db = model_db_cls(concept_db)
     score_embeddings_mock = mocker.spy(TestLogisticModel, 'score_embeddings')
     fit_mock = mocker.spy(TestLogisticModel, 'fit')
-    logistic_model = TestLogisticModel(embedding_name='test_embedding')
-    draft_model = TestLogisticModel(embedding_name='test_embedding')
+    logistic_model = TestLogisticModel()
+    draft_model = TestLogisticModel()
     model = _make_test_concept_model(
       concept_db, logistic_models={
         DRAFT_MAIN: logistic_model,
         'test_draft': draft_model
       })
-    model_db.sync(model, column_info=None)
+    model_db.sync(model)
     assert model_db.in_sync(model) is True
     assert score_embeddings_mock.call_count == 0
     assert fit_mock.call_count == 1
@@ -547,15 +552,16 @@
     assert score_embeddings_mock.call_count == 0
     assert fit_mock.call_count == 1
 
-    model_db.sync(model, column_info=None)
+    model_db.sync(model)
     assert model_db.in_sync(model) is True
     assert score_embeddings_mock.call_count == 0
     assert fit_mock.call_count == 3  # Fit is called on both the draft, and main.
 
     # Fit is called again with the same points.
-    ((called_model, called_embeddings, called_labels, called_weights),
-     (called_draft_model, called_draft_embeddings, called_draft_labels, called_draft_weights)) = (
-       c.args for c in fit_mock.call_args_list[-2:])
+    ((called_model, called_embeddings, called_labels, called_implicit_negatives),
+     (called_draft_model, called_draft_embeddings, called_draft_labels,
+      called_draft_implicit_negatives)) = (c.args for c in fit_mock.call_args_list[-2:])
 
     # The draft model is called with the data from main, and the data from draft.
     assert called_draft_model == draft_model
@@ -572,21 +578,21 @@
       False,
       False
     ]
-    assert called_draft_weights == pytest.approx([1.0, 1 / 3, 1 / 3, 1 / 3])
+    assert called_draft_implicit_negatives is None
 
     # The main model was fit without the data from the draft.
-    assert called_model == draft_model
+    assert called_model == logistic_model
     np.testing.assert_array_equal(
       called_embeddings, np.array([EMBEDDING_MAP['not in concept'], EMBEDDING_MAP['in concept']]))
     assert called_labels == [False, True]
-    assert called_weights == pytest.approx([1.0, 1.0])
+    assert called_implicit_negatives is None
 
   def test_embedding_not_found_in_map(self, concept_db_cls: Type[ConceptDB],
                                       model_db_cls: Type[ConceptModelDB]) -> None:
     concept_db = concept_db_cls()
     model_db = model_db_cls(concept_db)
     model = _make_test_concept_model(concept_db)
-    model_db.sync(model, column_info=None)
+    model_db.sync(model)
 
     # Edit the concept.
     concept_db.edit('test', 'test_concept',
@@ -596,5 +602,5 @@
     assert model_db.in_sync(model) is False
 
     with pytest.raises(ValueError, match='Example "unknown text" not in embedding map'):
-      model_db.sync(model, column_info=None)
-      model_db.sync(model, column_info=None)
+      model_db.sync(model)
+      model_db.sync(model)
447
+ assert retrieved_model.concept_name == model.concept_name
448
+ assert retrieved_model.embedding_name == model.embedding_name
449
+ assert retrieved_model.version == model.version
450
+ assert retrieved_model.column_info == model.column_info
451
 
452
  def test_sync_model(self, concept_db_cls: Type[ConceptDB], model_db_cls: Type[ConceptModelDB],
453
  mocker: MockerFixture) -> None:
454
 
455
  concept_db = concept_db_cls()
456
  model_db = model_db_cls(concept_db)
457
+ logistic_model = TestLogisticModel()
458
  score_embeddings_mock = mocker.spy(TestLogisticModel, 'score_embeddings')
459
  fit_mock = mocker.spy(TestLogisticModel, 'fit')
460
 
 
464
  assert score_embeddings_mock.call_count == 0
465
  assert fit_mock.call_count == 0
466
 
467
+ model_db.sync(model)
468
 
469
  assert model_db.in_sync(model) is True
470
  assert score_embeddings_mock.call_count == 0
 
476
  model_db = model_db_cls(concept_db)
477
  score_embeddings_mock = mocker.spy(TestLogisticModel, 'score_embeddings')
478
  fit_mock = mocker.spy(TestLogisticModel, 'fit')
479
+ logistic_model = TestLogisticModel()
480
  model = _make_test_concept_model(concept_db, logistic_models={DRAFT_MAIN: logistic_model})
481
+ model_db.sync(model)
482
  assert model_db.in_sync(model) is True
483
  assert score_embeddings_mock.call_count == 0
484
  assert fit_mock.call_count == 1
485
 
486
  (called_model, called_embeddings, called_labels,
487
+ called_implicit_negatives) = fit_mock.call_args_list[-1].args
488
  assert called_model == logistic_model
489
  np.testing.assert_array_equal(
490
  called_embeddings, np.array([EMBEDDING_MAP['not in concept'], EMBEDDING_MAP['in concept']]))
491
  assert called_labels == [False, True]
492
+ assert called_implicit_negatives is None
493
 
494
  # Edit the concept.
495
  concept_db.edit('test', 'test_concept',
 
500
  assert score_embeddings_mock.call_count == 0
501
  assert fit_mock.call_count == 1
502
 
503
+ model_db.sync(model)
504
  assert model_db.in_sync(model) is True
505
  assert score_embeddings_mock.call_count == 0
506
  assert fit_mock.call_count == 2
507
  # Fit is called again with new points on main only.
508
  (called_model, called_embeddings, called_labels,
509
+ called_implicit_negatives) = fit_mock.call_args_list[-1].args
510
  assert called_model == logistic_model
511
  np.testing.assert_array_equal(
512
  called_embeddings,
 
515
  EMBEDDING_MAP['a new data point']
516
  ]))
517
  assert called_labels == [False, True, False]
518
+ assert called_implicit_negatives is None
519
 
520
  def test_out_of_sync_draft_model(self, concept_db_cls: Type[ConceptDB],
521
  model_db_cls: Type[ConceptModelDB],
 
524
  model_db = model_db_cls(concept_db)
525
  score_embeddings_mock = mocker.spy(TestLogisticModel, 'score_embeddings')
526
  fit_mock = mocker.spy(TestLogisticModel, 'fit')
527
+ main_model = TestLogisticModel()
528
+ draft_model = TestLogisticModel()
529
  model = _make_test_concept_model(
530
  concept_db, logistic_models={
531
+ DRAFT_MAIN: main_model,
532
  'test_draft': draft_model
533
  })
534
+ model_db.sync(model)
535
  assert model_db.in_sync(model) is True
536
  assert score_embeddings_mock.call_count == 0
537
  assert fit_mock.call_count == 1
 
552
  assert score_embeddings_mock.call_count == 0
553
  assert fit_mock.call_count == 1
554
 
555
+ model_db.sync(model)
556
  assert model_db.in_sync(model) is True
557
  assert score_embeddings_mock.call_count == 0
558
  assert fit_mock.call_count == 3 # Fit is called on both the draft, and main.
559
 
560
  # Fit is called again with the same points.
561
+ ((called_model, called_embeddings, called_labels, called_implicit_negatives),
562
+ (called_draft_model, called_draft_embeddings, called_draft_labels,
563
+ called_draft_implicit_negatives)) = (
564
+ c.args for c in fit_mock.call_args_list[-2:])
565
 
566
  # The draft model is called with the data from main, and the data from draft.
567
  assert called_draft_model == draft_model
 
578
  False,
579
  False
580
  ]
581
+ assert called_draft_implicit_negatives is None
582
 
583
  # The main model was fit without the data from the draft.
584
+ assert called_model == main_model
585
  np.testing.assert_array_equal(
586
  called_embeddings, np.array([EMBEDDING_MAP['not in concept'], EMBEDDING_MAP['in concept']]))
587
  assert called_labels == [False, True]
588
+ assert called_implicit_negatives is None
589
 
590
  def test_embedding_not_found_in_map(self, concept_db_cls: Type[ConceptDB],
591
  model_db_cls: Type[ConceptModelDB]) -> None:
592
  concept_db = concept_db_cls()
593
  model_db = model_db_cls(concept_db)
594
  model = _make_test_concept_model(concept_db)
595
+ model_db.sync(model)
596
 
597
  # Edit the concept.
598
  concept_db.edit('test', 'test_concept',
 
602
  assert model_db.in_sync(model) is False
603
 
604
  with pytest.raises(ValueError, match='Example "unknown text" not in embedding map'):
605
+ model_db.sync(model)
606
+ model_db.sync(model)
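Note: the suite above tracks two API changes at once. `ConceptModelDB.sync` no longer takes `column_info`, and `fit` now receives an optional array of implicit negative embeddings instead of per-example `sample_weights`. A minimal sketch of a logistic embedding model honoring the new `fit` contract; the class below is illustrative, not the library's actual implementation:

  from typing import Optional

  import numpy as np
  from sklearn.linear_model import LogisticRegression

  class SketchLogisticModel:
    """Hypothetical stand-in showing one way to consume implicit negatives."""

    def __init__(self) -> None:
      self._model = LogisticRegression(class_weight='balanced')

    def fit(self, embeddings: np.ndarray, labels: list[bool],
            implicit_negatives: Optional[np.ndarray]) -> None:
      # Implicit negatives, when present, are appended as extra rows labeled False.
      if implicit_negatives is not None:
        embeddings = np.concatenate([embeddings, implicit_negatives])
        labels = list(labels) + [False] * len(implicit_negatives)
      self._model.fit(embeddings, labels)

    def score_embeddings(self, embeddings: np.ndarray) -> np.ndarray:
      # Probability of the positive class for each row.
      return self._model.predict_proba(embeddings)[:, 1]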
src/data/__pycache__/dataset.cpython-39.pyc CHANGED
Binary files a/src/data/__pycache__/dataset.cpython-39.pyc and b/src/data/__pycache__/dataset.cpython-39.pyc differ
 
src/data/__pycache__/dataset_compute_signal_chain_test.cpython-39-pytest-7.4.0.pyc ADDED
Binary file (9.78 kB).
 
src/data/__pycache__/dataset_compute_signal_test.cpython-39-pytest-7.4.0.pyc ADDED
Binary file (20.7 kB).
 
src/data/__pycache__/dataset_duckdb.cpython-39.pyc CHANGED
Binary files a/src/data/__pycache__/dataset_duckdb.cpython-39.pyc and b/src/data/__pycache__/dataset_duckdb.cpython-39.pyc differ
 
src/data/__pycache__/dataset_select_groups_test.cpython-39-pytest-7.3.1.pyc CHANGED
Binary files a/src/data/__pycache__/dataset_select_groups_test.cpython-39-pytest-7.3.1.pyc and b/src/data/__pycache__/dataset_select_groups_test.cpython-39-pytest-7.3.1.pyc differ
 
src/data/__pycache__/dataset_select_groups_test.cpython-39-pytest-7.4.0.pyc ADDED
Binary file (8.86 kB).
 
src/data/__pycache__/dataset_select_rows_filter_test.cpython-39-pytest-7.3.1.pyc CHANGED
Binary files a/src/data/__pycache__/dataset_select_rows_filter_test.cpython-39-pytest-7.3.1.pyc and b/src/data/__pycache__/dataset_select_rows_filter_test.cpython-39-pytest-7.3.1.pyc differ
 
src/data/__pycache__/dataset_select_rows_filter_test.cpython-39-pytest-7.4.0.pyc ADDED
Binary file (6.8 kB).
 
src/data/__pycache__/dataset_select_rows_schema_test.cpython-39-pytest-7.4.0.pyc ADDED
Binary file (17.3 kB).
 
src/data/__pycache__/dataset_select_rows_search_test.cpython-39-pytest-7.3.1.pyc CHANGED
Binary files a/src/data/__pycache__/dataset_select_rows_search_test.cpython-39-pytest-7.3.1.pyc and b/src/data/__pycache__/dataset_select_rows_search_test.cpython-39-pytest-7.3.1.pyc differ
 
src/data/__pycache__/dataset_select_rows_search_test.cpython-39-pytest-7.4.0.pyc ADDED
Binary file (11.8 kB).
 
src/data/__pycache__/dataset_select_rows_sort_test.cpython-39-pytest-7.4.0.pyc ADDED
Binary file (20.8 kB).
 
src/data/__pycache__/dataset_select_rows_udf_test.cpython-39-pytest-7.4.0.pyc ADDED
Binary file (16 kB).
 
src/data/__pycache__/dataset_stats_test.cpython-39-pytest-7.3.1.pyc CHANGED
Binary files a/src/data/__pycache__/dataset_stats_test.cpython-39-pytest-7.3.1.pyc and b/src/data/__pycache__/dataset_stats_test.cpython-39-pytest-7.3.1.pyc differ
 
src/data/__pycache__/dataset_stats_test.cpython-39-pytest-7.4.0.pyc ADDED
Binary file (5.66 kB).
 
src/data/__pycache__/dataset_test.cpython-39-pytest-7.4.0.pyc ADDED
Binary file (21.8 kB).
 
src/data/__pycache__/dataset_utils.cpython-39.pyc CHANGED
Binary files a/src/data/__pycache__/dataset_utils.cpython-39.pyc and b/src/data/__pycache__/dataset_utils.cpython-39.pyc differ
 
src/data/__pycache__/dataset_utils_test.cpython-39-pytest-7.4.0.pyc ADDED
Binary file (4.22 kB).
 
src/data/__pycache__/duckdb_utils.cpython-39.pyc CHANGED
Binary files a/src/data/__pycache__/duckdb_utils.cpython-39.pyc and b/src/data/__pycache__/duckdb_utils.cpython-39.pyc differ
 
src/data/dataset.py CHANGED
@@ -255,7 +255,7 @@ class Dataset(abc.ABC):
     pass
 
   @abc.abstractmethod
-  def get_vector_store(self, path: PathTuple) -> VectorStore:
+  def get_vector_store(self, embedding: str, path: PathTuple) -> VectorStore:
     # TODO: Instead of this, allow selecting vectors via select_rows.
     """Get the vector store for a column."""
     pass
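The abstract signature change above means every call site must now name which embedding's vectors it wants, since one column can carry stores for several embeddings. A hedged before/after sketch (dataset construction elided; 'test_embedding' and the ('text',) path are placeholder values, not fixed API names):

  # Before: the store was keyed by path alone.
  # store = dataset.get_vector_store(('text',))

  # After: the embedding name disambiguates between stores on the same column.
  store = dataset.get_vector_store('test_embedding', ('text',))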
src/data/dataset_duckdb.py CHANGED
@@ -241,15 +241,23 @@ class DatasetDuckDB(Dataset):
     raise NotImplementedError('count is not yet implemented for DuckDB.')
 
   @override
-  def get_vector_store(self, path: PathTuple) -> VectorStore:
+  def get_vector_store(self, embedding: str, path: PathTuple) -> VectorStore:
     # Refresh the manifest to make sure we have the latest signal manifests.
     self.manifest()
 
+    if path[-1] != EMBEDDING_KEY:
+      path = (*path, embedding, PATH_WILDCARD, EMBEDDING_KEY)
+
     if path not in self._col_vector_stores:
-      manifest = next(
-        m for m in self._signal_manifests if schema_contains_path(m.data_schema, path))
-      if not manifest:
+      manifests = [
+        m for m in self._signal_manifests
+        if schema_contains_path(m.data_schema, path) and m.embedding_filename_prefix
+      ]
+      if not manifests:
         raise ValueError(f'No embedding found for path {path}.')
+      if len(manifests) > 1:
+        raise ValueError(f'Multiple embeddings found for path {path}. Got: {manifests}')
+      manifest = manifests[0]
       if not manifest.embedding_filename_prefix:
         raise ValueError(f'Signal manifest for path {path} is not an embedding. '
                          f'Got signal manifest: {manifest}')
@@ -273,7 +281,7 @@ class DatasetDuckDB(Dataset):
                       manifest: DatasetManifest,
                       compute_dependencies: Optional[bool] = False,
                       task_step_id: Optional[TaskStepId] = None) -> tuple[PathTuple, Optional[TaskStepId]]:
-    """Run all the signals depedencies required to run this signal.
+    """Run all the signals dependencies required to run this signal.
 
     Args:
       signal: The signal to prepare.
@@ -560,7 +568,8 @@ class DatasetDuckDB(Dataset):
     if is_ordinal(leaf.dtype):
       min_max_query = f"""
         SELECT MIN(val) AS minVal, MAX(val) AS maxVal
-        FROM (SELECT {inner_select} as val FROM t);
+        FROM (SELECT {inner_select} as val FROM t)
+        WHERE NOT isnan(val)
       """
       row = self._query(min_max_query)[0]
       result.min_val, result.max_val = row
@@ -590,7 +599,9 @@ class DatasetDuckDB(Dataset):
     named_bins = _normalize_bins(bins or leaf.bins)
     stats = self.stats(leaf_path)
 
-    if is_float(leaf.dtype) or is_integer(leaf.dtype):
+    leaf_is_float = is_float(leaf.dtype)
+    leaf_is_integer = is_integer(leaf.dtype)
+    if leaf_is_float or leaf_is_integer:
       if named_bins is None:
         # Auto-bin.
         named_bins = _auto_bins(stats, NUM_AUTO_BINS)
@@ -606,11 +617,14 @@ class DatasetDuckDB(Dataset):
       bin_index_col = 'col0'
       bin_min_col = 'col1'
       bin_max_col = 'col2'
-      # We cast the field to `double` so bining works for both `float` and `int` fields.
+      is_nan_filter = f'NOT isnan({inner_val}) AND' if leaf_is_float else ''
+
+      # We cast the field to `double` so binning works for both `float` and `int` fields.
       outer_select = f"""(
         SELECT {bin_index_col} FROM (
           VALUES {', '.join(sql_bounds)}
-        ) WHERE {inner_val}::DOUBLE >= {bin_min_col} AND {inner_val}::DOUBLE < {bin_max_col}
+        ) WHERE {is_nan_filter}
+          {inner_val}::DOUBLE >= {bin_min_col} AND {inner_val}::DOUBLE < {bin_max_col}
       )"""
     else:
       if stats.approx_count_distinct >= dataset.TOO_MANY_DISTINCT:
@@ -625,6 +639,7 @@ class DatasetDuckDB(Dataset):
 
     filters, _ = self._normalize_filters(filters, col_aliases={}, udf_aliases={}, manifest=manifest)
     filter_queries = self._create_where(manifest, filters, searches=[])
+
     where_query = ''
     if filter_queries:
       where_query = f"WHERE {' AND '.join(filter_queries)}"
@@ -756,8 +771,9 @@ class DatasetDuckDB(Dataset):
     for udf_col in udf_columns:
       if isinstance(udf_col.signal_udf, ConceptScoreSignal):
         # Set dataset information on the signal.
+        source_path = udf_col.path if udf_col.path[-1] != EMBEDDING_KEY else udf_col.path[:-3]
         udf_col.signal_udf.set_column_info(
-          ConceptColumnInfo(namespace=self.namespace, name=self.dataset_name, path=udf_col.path))
+          ConceptColumnInfo(namespace=self.namespace, name=self.dataset_name, path=source_path))
 
     # Decide on the exact sorting order.
     sort_results = self._merge_sorts(search_udfs, sort_by, sort_order)
@@ -791,19 +807,19 @@ class DatasetDuckDB(Dataset):
     if topk_udf_col:
       key_prefixes: Optional[list[VectorKey]] = None
       if where_query:
-        # If there are filters, we need to send UUIDs to the topk query.
+        # If there are filters, we need to send UUIDs to the top k query.
         df = con.execute(f'SELECT {UUID_COLUMN} FROM t {where_query}').df()
         total_num_rows = len(df)
         key_prefixes = df[UUID_COLUMN]
 
-      signal = cast(Signal, topk_udf_col.signal_udf)
+      topk_signal = cast(TextEmbeddingModelSignal, topk_udf_col.signal_udf)
       # The input is an embedding.
-      vector_store = self.get_vector_store(topk_udf_col.path)
+      vector_store = self.get_vector_store(topk_signal.embedding, topk_udf_col.path)
       k = (limit or 0) + (offset or 0)
-      topk = signal.vector_compute_topk(k, vector_store, key_prefixes)
+      topk = topk_signal.vector_compute_topk(k, vector_store, key_prefixes)
       topk_uuids = list(dict.fromkeys([cast(str, key[0]) for key, _ in topk]))
 
-      # Ignore all the other filters and filter DuckDB results only by the topk UUIDs.
+      # Ignore all the other filters and filter DuckDB results only by the top k UUIDs.
       uuid_filter = Filter(path=(UUID_COLUMN,), op=ListOp.IN, value=topk_uuids)
       filter_query = self._create_where(manifest, [uuid_filter])[0]
       where_query = f'WHERE {filter_query}'
@@ -923,7 +939,8 @@ class DatasetDuckDB(Dataset):
 
       if signal.compute_type in [SignalInputType.TEXT_EMBEDDING]:
         # The input is an embedding.
-        vector_store = self.get_vector_store(udf_col.path)
+        embedding_signal = cast(TextEmbeddingModelSignal, signal)
+        vector_store = self.get_vector_store(embedding_signal.embedding, udf_col.path)
         flat_keys = flatten_keys(df[UUID_COLUMN], input)
         signal_out = signal.vector_compute(flat_keys, vector_store)
         # Add progress.
@@ -1273,39 +1290,48 @@ class DatasetDuckDB(Dataset):
     binary_ops = set(BinaryOp)
     unary_ops = set(UnaryOp)
     list_ops = set(ListOp)
-    for filter in filters:
-      duckdb_path = self._leaf_path_to_duckdb_path(filter.path, manifest.data_schema)
+    for f in filters:
+      duckdb_path = self._leaf_path_to_duckdb_path(f.path, manifest.data_schema)
       select_str = _select_sql(duckdb_path, flatten=True, unnest=False)
-      is_array = any(subpath == PATH_WILDCARD for subpath in filter.path)
+      is_array = any(subpath == PATH_WILDCARD for subpath in f.path)
+
+      nan_filter = ''
+      field = manifest.data_schema.get_field(f.path)
+      filter_nans = field.dtype and is_float(field.dtype)
 
-      if filter.op in binary_ops:
-        sql_op = BINARY_OP_TO_SQL[cast(BinaryOp, filter.op)]
-        filter_val = cast(FeatureValue, filter.value)
+      if f.op in binary_ops:
+        sql_op = BINARY_OP_TO_SQL[cast(BinaryOp, f.op)]
+        filter_val = cast(FeatureValue, f.value)
         if isinstance(filter_val, str):
           filter_val = f"'{filter_val}'"
         elif isinstance(filter_val, bytes):
          filter_val = _bytes_to_blob_literal(filter_val)
        else:
          filter_val = str(filter_val)
-        filter_query = (f'len(list_filter({select_str}, x -> x {sql_op} {filter_val})) > 0'
-                        if is_array else f'{select_str} {sql_op} {filter_val}')
-      elif filter.op in unary_ops:
-        if filter.op == UnaryOp.EXISTS:
+        if is_array:
+          nan_filter = 'NOT isnan(x) AND' if filter_nans else ''
+          filter_query = (f'len(list_filter({select_str}, '
+                          f'x -> {nan_filter} x {sql_op} {filter_val})) > 0')
+        else:
+          nan_filter = f'NOT isnan({select_str}) AND' if filter_nans else ''
+          filter_query = f'{nan_filter} {select_str} {sql_op} {filter_val}'
+      elif f.op in unary_ops:
+        if f.op == UnaryOp.EXISTS:
          filter_query = f'len({select_str}) > 0' if is_array else f'{select_str} IS NOT NULL'
        else:
-          raise ValueError(f'Unary op: {filter.op} is not yet supported')
-      elif filter.op in list_ops:
-        if filter.op == ListOp.IN:
-          filter_list_val = cast(FeatureListValue, filter.value)
+          raise ValueError(f'Unary op: {f.op} is not yet supported')
+      elif f.op in list_ops:
+        if f.op == ListOp.IN:
+          filter_list_val = cast(FeatureListValue, f.value)
          if not isinstance(filter_list_val, list):
            raise ValueError('filter with array value can only use the IN comparison')
          wrapped_filter_val = [f"'{part}'" for part in filter_list_val]
          filter_val = f'({", ".join(wrapped_filter_val)})'
          filter_query = f'{select_str} IN {filter_val}'
        else:
-          raise ValueError(f'List op: {filter.op} is not yet supported')
+          raise ValueError(f'List op: {f.op} is not yet supported')
      else:
-        raise ValueError(f'Invalid filter op: {filter.op}')
+        raise ValueError(f'Invalid filter op: {f.op}')
      sql_filter_queries.append(filter_query)
    return sql_filter_queries
@@ -1330,7 +1356,7 @@ class DatasetDuckDB(Dataset):
     return rows
 
   def _query_df(self, query: str) -> pd.DataFrame:
-    """Execute a query that returns a dataframe."""
+    """Execute a query that returns a data frame."""
     result = self._execute(query)
     df = _replace_nan_with_none(result.df())
     result.close()
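The recurring `NOT isnan(...)` guards in the new SQL exist because DuckDB orders NaN above every other floating-point value, so an unguarded comparison such as `val >= 2.0` would also match NaN rows. A small self-contained sketch of the behavior the guards rely on (assumes only the `duckdb` Python package):

  import duckdb

  con = duckdb.connect()
  # Without the isnan() guard, the NaN row would match too, because DuckDB
  # treats NaN as greater than all other floating-point values.
  rows = con.execute("""
    SELECT val
    FROM (VALUES (1.0), (3.0), ('NaN'::DOUBLE)) AS t(val)
    WHERE NOT isnan(val) AND val >= 2.0
  """).fetchall()
  print(rows)  # [(3.0,)]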
src/data/dataset_select_groups_test.py CHANGED
@@ -167,6 +167,8 @@ def test_named_bins(make_test_data: TestDataMaker) -> None:
     'age': 80
   }, {
     'age': 55
+  }, {
+    'age': float('nan')
   }]
   dataset = make_test_data(items)
 
@@ -178,7 +180,7 @@ def test_named_bins(make_test_data: TestDataMaker) -> None:
     ('middle-aged', 50, 65),
     ('senior', 65, None),
   ])
-  assert result.counts == [('adult', 2), ('young', 1), ('senior', 1), ('middle-aged', 1)]
+  assert result.counts == [('adult', 2), ('young', 1), ('senior', 1), ('middle-aged', 1), (None, 1)]
 
 
 def test_schema_with_bins(make_test_data: TestDataMaker) -> None:
@@ -192,11 +194,13 @@ def test_schema_with_bins(make_test_data: TestDataMaker) -> None:
     'age': 80
   }, {
     'age': 55
+  }, {
+    'age': float('nan')
   }]
   data_schema = schema({
     UUID_COLUMN: 'string',
     'age': field(
-      'int32',
+      'float32',
       bins=[
         ('young', None, 20),
         ('adult', 20, 50),
@@ -207,7 +211,7 @@ def test_schema_with_bins(make_test_data: TestDataMaker) -> None:
   dataset = make_test_data(items, data_schema)
 
   result = dataset.select_groups(leaf_path='age')
-  assert result.counts == [('adult', 2), ('young', 1), ('senior', 1), ('middle-aged', 1)]
+  assert result.counts == [('adult', 2), ('young', 1), ('senior', 1), ('middle-aged', 1), (None, 1)]
 
 
 def test_filters(make_test_data: TestDataMaker) -> None:
@@ -304,10 +308,10 @@ def test_too_many_distinct(make_test_data: TestDataMaker, mocker: MockerFixture)
 
 
 def test_auto_bins_for_float(make_test_data: TestDataMaker) -> None:
-  items: list[Item] = [{'feature': float(i)} for i in range(5)]
+  items: list[Item] = [{'feature': float(i)} for i in range(5)] + [{'feature': float('nan')}]
   dataset = make_test_data(items)
 
   res = dataset.select_groups('feature')
-  assert res.counts == [('0', 1), ('3', 1), ('7', 1), ('11', 1), ('14', 1)]
+  assert res.counts == [('0', 1), ('3', 1), ('7', 1), ('11', 1), ('14', 1), (None, 1)]
   assert res.too_many_distinct is False
   assert res.bins
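Net effect for `select_groups` callers: NaN values no longer fall into an arbitrary numeric bin; they surface as their own `(None, count)` group. A short usage sketch mirroring `test_auto_bins_for_float` (assumes the suite's `make_test_data` fixture):

  items = [{'feature': float(i)} for i in range(5)] + [{'feature': float('nan')}]
  dataset = make_test_data(items)  # test fixture from the suite above

  res = dataset.select_groups('feature')
  # The NaN row lands in a None-keyed group alongside the auto-bins.
  assert res.counts[-1] == (None, 1)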
src/data/dataset_select_rows_filter_test.py CHANGED
@@ -24,6 +24,9 @@ TEST_DATA: list[Item] = [{
   'int': 2,
   'bool': True,
   'float': 1.0
+}, {
+  UUID_COLUMN: '4',
+  'float': float('nan')
 }]
 
 
@@ -46,6 +49,91 @@ def test_filter_by_ids(make_test_data: TestDataMaker) -> None:
   assert list(result) == []
 
 
+def test_filter_greater(make_test_data: TestDataMaker) -> None:
+  dataset = make_test_data(TEST_DATA)
+
+  id_filter: BinaryFilterTuple = ('float', BinaryOp.GREATER, 2.0)
+  result = dataset.select_rows(filters=[id_filter])
+
+  assert list(result) == [{UUID_COLUMN: '1', 'str': 'a', 'int': 1, 'bool': False, 'float': 3.0}]
+
+
+def test_filter_greater_equal(make_test_data: TestDataMaker) -> None:
+  dataset = make_test_data(TEST_DATA)
+
+  id_filter: BinaryFilterTuple = ('float', BinaryOp.GREATER_EQUAL, 2.0)
+  result = dataset.select_rows(filters=[id_filter])
+
+  assert list(result) == [{
+    UUID_COLUMN: '1',
+    'str': 'a',
+    'int': 1,
+    'bool': False,
+    'float': 3.0
+  }, {
+    UUID_COLUMN: '2',
+    'str': 'b',
+    'int': 2,
+    'bool': True,
+    'float': 2.0
+  }]
+
+
+def test_filter_less(make_test_data: TestDataMaker) -> None:
+  dataset = make_test_data(TEST_DATA)
+
+  id_filter: BinaryFilterTuple = ('float', BinaryOp.LESS, 2.0)
+  result = dataset.select_rows(filters=[id_filter])
+
+  assert list(result) == [{UUID_COLUMN: '3', 'str': 'b', 'int': 2, 'bool': True, 'float': 1.0}]
+
+
+def test_filter_less_equal(make_test_data: TestDataMaker) -> None:
+  dataset = make_test_data(TEST_DATA)
+
+  id_filter: BinaryFilterTuple = ('float', BinaryOp.LESS_EQUAL, 2.0)
+  result = dataset.select_rows(filters=[id_filter])
+
+  assert list(result) == [{
+    UUID_COLUMN: '2',
+    'str': 'b',
+    'int': 2,
+    'bool': True,
+    'float': 2.0
+  }, {
+    UUID_COLUMN: '3',
+    'str': 'b',
+    'int': 2,
+    'bool': True,
+    'float': 1.0
+  }]
+
+
+def test_filter_not_equal(make_test_data: TestDataMaker) -> None:
+  dataset = make_test_data(TEST_DATA)
+
+  id_filter: BinaryFilterTuple = ('float', BinaryOp.NOT_EQUAL, 2.0)
+  result = dataset.select_rows(filters=[id_filter])
+
+  assert list(result) == [
+    {
+      UUID_COLUMN: '1',
+      'str': 'a',
+      'int': 1,
+      'bool': False,
+      'float': 3.0
+    },
+    {
+      UUID_COLUMN: '3',
+      'str': 'b',
+      'int': 2,
+      'bool': True,
+      'float': 1.0
+    },
+    # NaNs are not counted when we are filtering a field.
+  ]
+
+
 def test_filter_by_list_of_ids(make_test_data: TestDataMaker) -> None:
   dataset = make_test_data(TEST_DATA)
 
src/data/dataset_select_rows_search_test.py CHANGED
@@ -288,17 +288,9 @@ def test_concept_search(make_test_data: TestDataMaker, mocker: MockerFixture) ->
     },
   ]
 
-  # Make sure fit was called with negative examples.
   (_, embeddings, labels, _) = concept_model_mock.call_args_list[-1].args
-  assert embeddings.shape == (8, 3)
+  assert embeddings.shape == (2, 3)
   assert labels == [
-    # Negative implicit labels.
-    False,
-    False,
-    False,
-    False,
-    False,
-    False,
     # Explicit labels.
     False,
     True
src/data/dataset_stats_test.py CHANGED
@@ -15,7 +15,7 @@ SIMPLE_ITEMS: list[Item] = [{
   'str': 'a',
   'int': 1,
   'bool': False,
-  'float': 3.0
+  'float': 3.0,
 }, {
   UUID_COLUMN: '2',
   'str': 'b',
@@ -28,6 +28,9 @@ SIMPLE_ITEMS: list[Item] = [{
   'int': 2,
   'bool': True,
   'float': 1.0
+}, {
+  UUID_COLUMN: '4',
+  'float': float('nan')
 }]
 
 
@@ -40,7 +43,7 @@ def test_simple_stats(make_test_data: TestDataMaker) -> None:
 
   result = dataset.stats(leaf_path='float')
   assert result == StatsResult(
-    path=('float',), total_count=3, approx_count_distinct=3, min_val=1.0, max_val=3.0)
+    path=('float',), total_count=4, approx_count_distinct=4, min_val=1.0, max_val=3.0)
 
   result = dataset.stats(leaf_path='bool')
   assert result == StatsResult(path=('bool',), total_count=3, approx_count_distinct=2)
src/data/dataset_utils.py CHANGED
@@ -1,5 +1,6 @@
 """Utilities for working with datasets."""
 
+import json
 import math
 import os
 import pickle
@@ -283,7 +284,10 @@ def write_items_to_parquet(items: Iterable[Item], output_dir: str, schema: Schem
     if UUID_COLUMN not in item:
       item[UUID_COLUMN] = secrets.token_urlsafe(nbytes=12)  # 16 base64 characters.
     if os.getenv('DEBUG'):
-      _validate(item, arrow_schema)
+      try:
+        _validate(item, arrow_schema)
+      except Exception as e:
+        raise ValueError(f'Error validating item: {json.dumps(item)}') from e
     writer.write(item)
     num_items += 1
   writer.close()
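Wrapping `_validate` in a try/except keeps the offending item in the error message while `from e` chains the original failure as `__cause__`, so the full traceback survives. The pattern in isolation (the names below are stand-ins, not the module's real helpers):

  import json

  def validate(item: dict) -> None:
    # Stand-in for the real Arrow-schema validation.
    if not isinstance(item.get('age'), int):
      raise TypeError('age must be an int')

  item = {'age': 'forty'}
  try:
    validate(item)
  except Exception as e:
    # The ValueError names the bad item; the TypeError stays attached as
    # __cause__, so both appear in the traceback.
    raise ValueError(f'Error validating item: {json.dumps(item)}') from e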