"""Tests utils of for dataset_test.""" import os import pathlib from copy import deepcopy from datetime import datetime from typing import Optional, Type, cast import numpy as np from typing_extensions import Protocol from ..config import CONFIG_FILENAME, DatasetConfig from ..embeddings.vector_store import VectorDBIndex from ..schema import ( MANIFEST_FILENAME, PARQUET_FILENAME_PREFIX, ROWID, VALUE_KEY, DataType, Field, Item, PathKey, Schema, SourceManifest, ) from ..sources.source import Source from ..utils import get_dataset_output_dir, open_file, to_yaml from .dataset import Dataset, default_settings from .dataset_utils import is_primitive, write_items_to_parquet TEST_NAMESPACE = 'test_namespace' TEST_DATASET_NAME = 'test_dataset' def _infer_dtype(value: Item) -> DataType: if isinstance(value, str): return DataType.STRING elif isinstance(value, bool): return DataType.BOOLEAN elif isinstance(value, bytes): return DataType.BINARY elif isinstance(value, float): return DataType.FLOAT32 elif isinstance(value, int): return DataType.INT32 elif isinstance(value, datetime): return DataType.TIMESTAMP else: raise ValueError(f'Cannot infer dtype of primitive value: {value}') def _infer_field(item: Item) -> Field: """Infer the schema from the items.""" if isinstance(item, dict): fields: dict[str, Field] = {} for k, v in item.items(): fields[k] = _infer_field(cast(Item, v)) dtype = None if VALUE_KEY in fields: dtype = fields[VALUE_KEY].dtype del fields[VALUE_KEY] return Field(fields=fields, dtype=dtype) elif is_primitive(item): return Field(dtype=_infer_dtype(item)) elif isinstance(item, list): return Field(repeated_field=_infer_field(item[0])) else: raise ValueError(f'Cannot infer schema of item: {item}') def _infer_schema(items: list[Item]) -> Schema: """Infer the schema from the items.""" schema = Schema(fields={}) for item in items: field = _infer_field(item) if not field.fields: raise ValueError(f'Invalid schema of item. Expected an object, but got: {item}') schema.fields = {**schema.fields, **field.fields} return schema class TestDataMaker(Protocol): """A function that creates a test dataset.""" def __call__(self, items: list[Item], schema: Optional[Schema] = None) -> Dataset: """Create a test dataset.""" ... class TestSource(Source): """Test source that does nothing.""" name = 'test_source' def make_dataset(dataset_cls: Type[Dataset], tmp_path: pathlib.Path, items: list[Item], schema: Optional[Schema] = None) -> Dataset: """Create a test dataset.""" schema = schema or _infer_schema(items) _write_items(tmp_path, TEST_DATASET_NAME, items, schema) dataset = dataset_cls(TEST_NAMESPACE, TEST_DATASET_NAME) config = DatasetConfig( namespace=TEST_NAMESPACE, name=TEST_DATASET_NAME, source=TestSource(), settings=default_settings(dataset)) config_filepath = os.path.join( get_dataset_output_dir(str(tmp_path), TEST_NAMESPACE, TEST_DATASET_NAME), CONFIG_FILENAME) with open_file(config_filepath, 'w') as f: f.write(to_yaml(config.dict(exclude_defaults=True, exclude_none=True, exclude_unset=True))) return dataset def _write_items(tmpdir: pathlib.Path, dataset_name: str, items: list[Item], schema: Schema) -> None: """Write the items JSON to the dataset format: manifest.json and parquet files.""" source_dir = get_dataset_output_dir(str(tmpdir), TEST_NAMESPACE, dataset_name) os.makedirs(source_dir) # Add rowids to the items. 
  items = [deepcopy(item) for item in items]
  for i, item in enumerate(items):
    item[ROWID] = str(i + 1)

  simple_parquet_files, _ = write_items_to_parquet(
    items,
    source_dir,
    schema,
    filename_prefix=PARQUET_FILENAME_PREFIX,
    shard_index=0,
    num_shards=1)

  manifest = SourceManifest(files=[simple_parquet_files], data_schema=schema)
  with open_file(os.path.join(source_dir, MANIFEST_FILENAME), 'w') as f:
    f.write(manifest.json(indent=2, exclude_none=True))


def enriched_item(value: Optional[Item] = None, metadata: dict[str, Item] = {}) -> Item:
  """Wrap a value in a dict with the value key."""
  return {VALUE_KEY: value, **metadata}


def make_vector_index(vector_store: str,
                      vector_dict: dict[PathKey, list[list[float]]]) -> VectorDBIndex:
  """Make a vector index from a dictionary of vector keys to vectors."""
  embeddings: list[np.ndarray] = []
  spans: list[tuple[PathKey, list[tuple[int, int]]]] = []
  for path_key, vectors in vector_dict.items():
    vector_spans: list[tuple[int, int]] = []
    for vector in vectors:
      embeddings.append(np.array(vector))
      vector_spans.append((0, 0))
    spans.append((path_key, vector_spans))

  vector_index = VectorDBIndex(vector_store)
  vector_index.add(spans, np.array(embeddings))
  return vector_index
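

# Example usage (illustrative sketch only, not executed as part of this module):
# tests typically expose a `TestDataMaker` through a pytest fixture that binds a
# concrete `Dataset` implementation and the `tmp_path` fixture, so individual
# tests only pass items. The `DatasetDuckDB` name, its import path, and the
# assumption that the project data path is redirected to `tmp_path` (e.g. by an
# autouse fixture) are placeholders; substitute whatever this repo actually uses.
#
#   import pytest
#   from .dataset_duckdb import DatasetDuckDB
#
#   @pytest.fixture
#   def make_test_data(tmp_path: pathlib.Path) -> TestDataMaker:
#     def _make(items: list[Item], schema: Optional[Schema] = None) -> Dataset:
#       return make_dataset(DatasetDuckDB, tmp_path, items, schema)
#     return _make
#
#   def test_example(make_test_data: TestDataMaker) -> None:
#     dataset = make_test_data([{'text': 'hello'}, {'text': 'world'}])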