nikhil_no_persistent / lilac /data /dataset_test_utils.py
nsthorat-lilac's picture
Duplicate from lilacai/nikhil_staging
bfc0ec6
raw
history blame contribute delete
No virus
5.02 kB
"""Test utilities for the dataset tests."""
import os
import pathlib
from copy import deepcopy
from datetime import datetime
from typing import Optional, Type, cast
import numpy as np
from typing_extensions import Protocol
from ..config import CONFIG_FILENAME, DatasetConfig
from ..embeddings.vector_store import VectorDBIndex
from ..schema import (
MANIFEST_FILENAME,
PARQUET_FILENAME_PREFIX,
ROWID,
VALUE_KEY,
DataType,
Field,
Item,
PathKey,
Schema,
SourceManifest,
)
from ..sources.source import Source
from ..utils import get_dataset_output_dir, open_file, to_yaml
from .dataset import Dataset, default_settings
from .dataset_utils import is_primitive, write_items_to_parquet
# Namespace under which the test dataset is written and loaded by the helpers in this module.
TEST_NAMESPACE = 'test_namespace'
# Name of the dataset that `make_dataset` creates.
TEST_DATASET_NAME = 'test_dataset'
def _infer_dtype(value: Item) -> DataType:
  """Map a primitive Python value to the matching `DataType`.

  Args:
    value: A primitive value (str, bool, bytes, float, int or datetime).

  Returns:
    The `DataType` associated with the first matching Python type.

  Raises:
    ValueError: If the value is not an instance of any supported primitive type.
  """
  # Checked in order: `bool` is a subclass of `int`, so it has to be tested before `int`.
  dtype_by_type = (
    (str, DataType.STRING),
    (bool, DataType.BOOLEAN),
    (bytes, DataType.BINARY),
    (float, DataType.FLOAT32),
    (int, DataType.INT32),
    (datetime, DataType.TIMESTAMP),
  )
  for python_type, dtype in dtype_by_type:
    if isinstance(value, python_type):
      return dtype
  raise ValueError(f'Cannot infer dtype of primitive value: {value}')
def _infer_field(item: Item) -> Field:
  """Infer the `Field` that describes a single item.

  - A dict becomes a field with one child field per key. A `VALUE_KEY` entry is not kept as a
    child; its dtype becomes the dtype of the dict's own field instead.
  - A primitive becomes a leaf field whose dtype comes from `_infer_dtype`.
  - A list becomes a repeated field whose element field is inferred from the first element only.

  Args:
    item: The item (or nested sub-item) to inspect.

  Returns:
    The inferred field.

  Raises:
    ValueError: If the item is an empty list (there is no element to infer from), or is neither
      a dict, a primitive, nor a list.
  """
  if isinstance(item, dict):
    fields: dict[str, Field] = {
      key: _infer_field(cast(Item, value)) for key, value in item.items()
    }
    # The value key describes the dict's own field, so it is lifted out of the children.
    value_field = fields.pop(VALUE_KEY, None)
    dtype = value_field.dtype if value_field is not None else None
    return Field(fields=fields, dtype=dtype)
  if is_primitive(item):
    return Field(dtype=_infer_dtype(item))
  if isinstance(item, list):
    if not item:
      # Without a first element there is nothing to infer the element type from; fail with a
      # descriptive error rather than a bare IndexError.
      raise ValueError('Cannot infer schema of an empty list. Pass an explicit schema instead.')
    return Field(repeated_field=_infer_field(item[0]))
  raise ValueError(f'Cannot infer schema of item: {item}')
def _infer_schema(items: list[Item]) -> Schema:
  """Build a `Schema` by merging the top-level fields inferred from every item.

  Fields are merged key by key in item order, so a later item's field replaces an earlier
  item's field of the same name.

  Args:
    items: The items to inspect. Each one must infer to a field with child fields.

  Returns:
    The merged schema; it has no fields when `items` is empty.

  Raises:
    ValueError: If an item infers to a field without any child fields.
  """
  inferred = Schema(fields={})
  for row in items:
    row_fields = _infer_field(row).fields
    if not row_fields:
      raise ValueError(f'Invalid schema of item. Expected an object, but got: {row}')
    # Assign a fresh dict rather than mutating the schema's current one in place.
    merged = dict(inferred.fields)
    merged.update(row_fields)
    inferred.fields = merged
  return inferred
class TestDataMaker(Protocol):
  """Protocol for a callable that builds a test `Dataset` from a list of items.

  Any callable accepting `(items, schema=None)` and returning a `Dataset` satisfies it.
  """

  def __call__(self, items: list[Item], schema: Optional[Schema] = None) -> Dataset:
    """Create a test dataset from `items`, optionally with an explicit `schema`."""
    ...
class TestSource(Source):
  """Minimal `Source` subclass for tests.

  It only sets `name`; `make_dataset` passes an instance as the `source` of the test
  `DatasetConfig`.
  """
  # Name of this source. NOTE(review): presumably the identifier `Source` subclasses are looked
  # up or serialized by -- confirm against `Source`.
  name = 'test_source'
def make_dataset(dataset_cls: Type[Dataset],
                 tmp_path: pathlib.Path,
                 items: list[Item],
                 schema: Optional[Schema] = None) -> Dataset:
  """Write `items` to disk under `tmp_path` and return a dataset instance for them.

  The items are written through `_write_items`, then a `DatasetConfig` for the dataset is
  serialized as YAML to `CONFIG_FILENAME` inside the dataset output directory.

  Args:
    dataset_cls: The `Dataset` class to instantiate.
    tmp_path: Root directory under which the dataset output directory is resolved.
    items: The rows of the dataset.
    schema: Schema of the items. When falsy, it is inferred from `items`.

  Returns:
    `dataset_cls` instantiated with `TEST_NAMESPACE` and `TEST_DATASET_NAME`.
  """
  if not schema:
    schema = _infer_schema(items)
  _write_items(tmp_path, TEST_DATASET_NAME, items, schema)

  # The dataset has to exist before the config is built: its default settings come from it.
  dataset = dataset_cls(TEST_NAMESPACE, TEST_DATASET_NAME)
  config = DatasetConfig(
    namespace=TEST_NAMESPACE,
    name=TEST_DATASET_NAME,
    source=TestSource(),
    settings=default_settings(dataset))

  output_dir = get_dataset_output_dir(str(tmp_path), TEST_NAMESPACE, TEST_DATASET_NAME)
  config_path = os.path.join(output_dir, CONFIG_FILENAME)
  with open_file(config_path, 'w') as config_file:
    serializable = config.dict(exclude_defaults=True, exclude_none=True, exclude_unset=True)
    config_file.write(to_yaml(serializable))
  return dataset
def _write_items(tmpdir: pathlib.Path, dataset_name: str, items: list[Item],
                 schema: Schema) -> None:
  """Write the items to the dataset format: a manifest file plus a single parquet shard.

  Each item is deep-copied and given a `ROWID` equal to its 1-based position as a string, so
  the caller's items are left unmodified.

  Args:
    tmpdir: Root directory under which the dataset output directory is resolved.
    dataset_name: Name of the dataset, within `TEST_NAMESPACE`.
    items: The rows to write.
    schema: Schema stored in the manifest and passed to the parquet writer.

  Raises:
    FileExistsError: If the dataset output directory already exists (`os.makedirs` is called
      without `exist_ok`).
  """
  output_dir = get_dataset_output_dir(str(tmpdir), TEST_NAMESPACE, dataset_name)
  os.makedirs(output_dir)

  # Work on copies so that adding the row id never leaks into the caller's items.
  rows = []
  for row_number, original in enumerate(items, start=1):
    row = deepcopy(original)
    row[ROWID] = str(row_number)
    rows.append(row)

  parquet_file, _ = write_items_to_parquet(
    rows, output_dir, schema, filename_prefix=PARQUET_FILENAME_PREFIX, shard_index=0, num_shards=1)

  manifest = SourceManifest(files=[parquet_file], data_schema=schema)
  manifest_path = os.path.join(output_dir, MANIFEST_FILENAME)
  with open_file(manifest_path, 'w') as manifest_file:
    manifest_file.write(manifest.json(indent=2, exclude_none=True))
def enriched_item(value: Optional[Item] = None,
                  metadata: Optional[dict[str, Item]] = None) -> Item:
  """Wrap a value in a dict under `VALUE_KEY`, alongside optional metadata entries.

  Args:
    value: The value stored under `VALUE_KEY`.
    metadata: Extra key/value pairs merged into the result after the value. An entry keyed by
      `VALUE_KEY` therefore replaces `value`. Defaults to no metadata.

  Returns:
    A new dict with the value followed by the metadata entries.
  """
  # A `None` default replaces the former shared mutable `{}` default.
  if metadata is None:
    metadata = {}
  return {VALUE_KEY: value, **metadata}
def make_vector_index(vector_store: str, vector_dict: dict[PathKey,
                                                           list[list[float]]]) -> VectorDBIndex:
  """Make a vector index from a dictionary of path keys to their vectors.

  Args:
    vector_store: Passed to `VectorDBIndex` to select the vector store.
    vector_dict: Maps each path key to the list of vectors stored for it.

  Returns:
    A `VectorDBIndex` to which all vectors have been added, in dictionary order. Every vector
    is registered with the placeholder span `(0, 0)`.
  """
  # Flatten all vectors in dictionary order; the per-key span lists below follow the same order.
  all_vectors: list[np.ndarray] = [
    np.array(vector) for vectors in vector_dict.values() for vector in vectors
  ]
  all_spans: list[tuple[PathKey, list[tuple[int, int]]]] = [
    (path_key, [(0, 0) for _ in vectors]) for path_key, vectors in vector_dict.items()
  ]

  index = VectorDBIndex(vector_store)
  index.add(all_spans, np.array(all_vectors))
  return index